diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000000..e694a9d33d04 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +ignore = E,F403,F405,F541,F841,W +select = E9,F,W6 +per-file-ignores = + __init__.py:F401 diff --git a/.github/ISSUE_TEMPLATE/ci_failure_report.md b/.github/ISSUE_TEMPLATE/ci_failure_report.md new file mode 100644 index 000000000000..6bf4c7762319 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/ci_failure_report.md @@ -0,0 +1,10 @@ +--- +name: CI failure report +about: Report a DeepSpeed CI failure +title: "{{ env.GITHUB_WORKFLOW }} CI test failure" +labels: ci-failure +assignees: '' + +--- + +The Nightly CI for {{ env.GITHUB_SERVER_URL }}/{{ env.GITHUB_REPOSITORY }}/actions/runs/{{ env.GITHUB_RUN_ID }} failed. diff --git a/.github/ISSUE_TEMPLATE/deepspeed_chat_bug_report.md b/.github/ISSUE_TEMPLATE/deepspeed_chat_bug_report.md new file mode 100644 index 000000000000..f27b1c6303eb --- /dev/null +++ b/.github/ISSUE_TEMPLATE/deepspeed_chat_bug_report.md @@ -0,0 +1,44 @@ +--- +name: Bug report (DeepSpeed-Chat) +about: Create a DeepSpeed-Chat related issue to help us improve +title: "[BUG]" +labels: bug,deepspeed-chat +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. Please include which training step you are using and which model you are training. + +**Log output** +If you used `train.py` to launch the application, please include the contents of the output log file. + +**To Reproduce** +Steps to reproduce the behavior: +1. Command/Script to reproduce +2. What packages are required and their versions +3. How to run the script +4. ... + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**ds_report output** +Please run `ds_report` to give us details about your setup. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**System info (please complete the following information):** + - OS: [e.g. Ubuntu 18.04] + - GPU count and types [e.g. two machines with x8 A100s each] + - (if applicable) what [DeepSpeed-MII](https://github.com/deepspeedai/deepspeed-mii) version are you using + - (if applicable) Hugging Face Transformers/Accelerate/etc. versions + - Python version + - Any other relevant info about your setup + +**Docker context** +Are you using a specific docker image that you can share? + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/inference_bug_report.md b/.github/ISSUE_TEMPLATE/inference_bug_report.md index bc5df17258b0..8a4144ce049a 100644 --- a/.github/ISSUE_TEMPLATE/inference_bug_report.md +++ b/.github/ISSUE_TEMPLATE/inference_bug_report.md @@ -29,7 +29,7 @@ If applicable, add screenshots to help explain your problem. **System info (please complete the following information):** - OS: [e.g. Ubuntu 18.04] - GPU count and types [e.g. two machines with x8 A100s each] - - (if applicable) what [DeepSpeed-MII](https://github.com/microsoft/deepspeed-mii) version are you using + - (if applicable) what [DeepSpeed-MII](https://github.com/deepspeedai/deepspeed-mii) version are you using - (if applicable) Hugging Face Transformers/Accelerate/etc. versions - Python version - Any other relevant info about your setup diff --git a/.github/workflows/amd-mi100.yml b/.github/workflows/amd-mi100.yml deleted file mode 100644 index 61c30bfb2cea..000000000000 --- a/.github/workflows/amd-mi100.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: amd-mi100 - -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - amd-tests: - # The type of runner that the job will run on - runs-on: [self-hosted, amd, mi100] - - # Steps represent a sequence of tasks that will be executed as part of the job - steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install transformers - run: | - git clone https://github.com/huggingface/transformers - cd transformers - # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 - git rev-parse --short HEAD - pip install . - - # Runs a set of commands using the runners shell - - name: Install deepspeed - run: | - pip install .[dev,1bit,autotuning] - #python -c "from deepspeed.env_report import cli_main; cli_main()" - ds_report - - - name: Python environment - run: | - pip list - - # Runs a set of commands using the runners shell - - name: Unit tests - run: | - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest -n 4 --verbose unit/ - TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'sequential' unit/ diff --git a/.github/workflows/amd-mi200.yml b/.github/workflows/amd-mi200.yml index 5dbab382e128..63d7348374fe 100644 --- a/.github/workflows/amd-mi200.yml +++ b/.github/workflows/amd-mi200.yml @@ -1,14 +1,11 @@ name: amd-mi200 on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' + workflow_dispatch: pull_request: - paths-ignore: - - 'docs/**' + paths: + - '.github/workflows/amd-mi200.yml' + - 'requirements/**' schedule: - cron: "0 0 * * *" @@ -16,22 +13,27 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +permissions: + contents: read + issues: write + jobs: amd-tests: + name: amd-mi200 / AMD MI200 tests # The type of runner that the job will run on runs-on: [self-hosted, amd, mi200] # Steps represent a sequence of tasks that will be executed as part of the job steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm6.0 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -44,11 +46,16 @@ jobs: git rev-parse --short HEAD pip install . - - name: Install apex + - name: Install (ROCm) apex run: | - pip install ninja - pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git' - + git clone https://github.com/ROCmSoftwarePlatform/apex.git + CURRENT_VER=$(git rev-parse HEAD) + INSTALLED_VER=$(cat /blob/amd-apex/.venv_installed_version) + if [[ "$CURRENT_VER" != "$INSTALLED_VER" ]]; then + pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings="--global-option=--cpp_ext" --config-settings="--global-option=--cuda_ext" --target=/blob/amd-apex/ --upgrade . + git rev-parse HEAD > /blob/amd-apex/.venv_installed_version + fi + echo PYTHONPATH=$PYTHONPATH:/blob/amd-apex/ >> $GITHUB_ENV # Runs a set of commands using the runners shell - name: Install deepspeed run: | @@ -63,7 +70,16 @@ jobs: # Runs a set of commands using the runners shell - name: Unit tests run: | - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest -n 4 --verbose unit/ - TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'sequential' unit/ + pytest $PYTEST_OPTS -n 4 --verbose unit/ + pytest $PYTEST_OPTS -m 'sequential' unit/ + + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/auto-sync.yml b/.github/workflows/auto-sync.yml deleted file mode 100644 index 5cc5dc02224f..000000000000 --- a/.github/workflows/auto-sync.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: AutoSync - -on: - push: - branches: - - 'master' - -jobs: - - Create-PR: - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v2 - with: - token: ${{ secrets.GHP_TOKEN }} - repository: ${{ secrets.DST_REPO }} - ref: ${{ secrets.DST_REPO_BRANCH }} - path: dst-repo - - - name: Get PR data - run: | - echo "REPO=${{ github.repository }}" >> $GITHUB_ENV - echo "COMMIT_SHA=${{ github.event.after }}" >> $GITHUB_ENV - echo "SHORT_SHA=$(echo ${{ github.event.after }} | cut -c1-8)" >> $GITHUB_ENV - echo "USERNAME=${{ github.event.head_commit.author.username }}" >> $GITHUB_ENV - echo "USER_EMAIL=${{ github.event.head_commit.author.username }}@users.noreply.github.com" >> $GITHUB_ENV - echo "PR_NAME=$(echo '${{ github.event.head_commit.message }}' | head -1 | sed 's|#|${{ github.repository }}#|g')" >> $GITHUB_ENV - - - name: Cherry pick commit - continue-on-error: true - run: | - cd dst-repo - git config --global user.name ${{ env.USERNAME }} - git config --global user.email ${{ env.USER_EMAIL }} - git fetch https://github.com/${{ env.REPO }}.git master - git cherry-pick FETCH_HEAD --strategy-option octopus - - - name: Add modified files - run: | - cd dst-repo - git add . - - - name: Create Pull Request - uses: peter-evans/create-pull-request@v4 - with: - path: dst-repo - token: ${{ secrets.GHP_TOKEN }} - body: | - **Auto-generated PR** - Repo - [${{ env.REPO }}](https://github.com/${{ env.REPO }}) - PR name - ${{ env.PR_NAME }} - Commit - ${{ env.REPO }}@${{ env.COMMIT_SHA }} - Author - @${{ env.USERNAME }} - branch: AutoPR/${{ env.SHORT_SHA }} - assignees: ${{ env.USERNAME }} - title: ${{ env.PR_NAME }} - labels: AutoPR - author: ${{ env.USERNAME }} <${{ env.USER_EMAIL }}> diff --git a/.github/workflows/aws-accelerate.yml b/.github/workflows/aws-accelerate.yml new file mode 100644 index 000000000000..f8397381d163 --- /dev/null +++ b/.github/workflows/aws-accelerate.yml @@ -0,0 +1,115 @@ +################################################################################ +# DeepSpeed CI - AWS L40S GPU Tests (HuggingFace Accelerate Integration) +# +# Runs the same tests as modal-accelerate.yml but on AWS self-hosted runners. +# Tests DeepSpeed integration with HuggingFace Accelerate library. +# Uses 4x NVIDIA L40S GPUs on g6e.12xlarge instances. +################################################################################ + +name: aws-accelerate + +on: + workflow_dispatch: + + push: + branches: + - master + + pull_request: + branches: + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + check-paths: + name: aws-accelerate / check paths + runs-on: ubuntu-latest + outputs: + should_run: ${{ steps.filter.outputs.run_tests }} + steps: + - uses: actions/checkout@v4 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + run_tests: + - '**' + - '!docs/**' + - '!blogs/**' + - '!deepspeed/inference/v2/**' + - '!tests/unit/inference/v2/**' + + accelerate-tests: + name: aws-accelerate / accelerate integration tests + needs: check-paths + if: needs.check-paths.outputs.should_run == 'true' + runs-on: [self-hosted, gpu-ci, gpu-l40s, l40s-1gpu, aws] + timeout-minutes: 60 + + container: + image: nvidia/cuda:12.6.3-devel-ubuntu22.04 + options: --gpus all --shm-size "32G" + + env: + TORCH_VER: "2.7" + CUDA_VER: "12.6" + + steps: + - name: Install system dependencies + run: | + apt-get update && apt-get install -y git git-lfs libaio-dev python3 python3-pip + git lfs install + ln -sf /usr/bin/python3 /usr/bin/python + + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Install PyTorch + run: | + pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu126 + + - name: Install Python dependencies + run: | + pip install --upgrade pip + pip install -r requirements/requirements.txt + pip install -r requirements/requirements-dev.txt + pip install datasets + + - name: Check environment + run: | + echo "=== GPU Information ===" + nvidia-smi + echo "" + echo "=== CUDA Version ===" + nvcc --version + echo "" + echo "=== Python/PyTorch Info ===" + python --version + python -c "import torch; print(f'PyTorch: {torch.__version__}')" + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + python -c "import torch; print(f'CUDA devices: {torch.cuda.device_count()}')" + python -c "import torch; print(f'BF16 support: {torch.cuda.is_bf16_supported()}')" + + - name: Install DeepSpeed + run: | + # Initialize CUDA before install so setup.py can detect NCCL version + python -c "import torch; torch.cuda.init(); print(f'NCCL version: {torch.cuda.nccl.version()}')" + # Use --no-build-isolation so setup.py can access pre-installed PyTorch + pip install --no-build-isolation . + ds_report + # Debug: Check captured torch_info values + python -c "from deepspeed.git_version_info import torch_info; print(f'torch_info: {torch_info}')" + + - name: Clone and install Accelerate + run: | + git clone https://github.com/huggingface/accelerate + pip install "./accelerate[testing]" + + - name: Run Accelerate DeepSpeed tests + run: | + pytest --verbose ./accelerate/tests/deepspeed diff --git a/.github/workflows/aws-torch-latest-full.yml b/.github/workflows/aws-torch-latest-full.yml new file mode 100644 index 000000000000..5f8b6183f968 --- /dev/null +++ b/.github/workflows/aws-torch-latest-full.yml @@ -0,0 +1,380 @@ +################################################################################ +# DeepSpeed CI - AWS L40S GPU Full Tests (PyTorch Latest) +# +# Runs the full DeepSpeed unit test suite on AWS self-hosted runners. +# Prefers 4x NVIDIA L40S GPUs on g6e.12xlarge instances, with AWS-side +# fallback to 8x A100 nodes when L40S capacity is unavailable. +# +# This workflow runs: +# - Parallel tests with pytest-xdist (-n 8) +# - Sequential tests marked with @pytest.mark.sequential +# - Nightly schedule: skips if no new commits since last successful run +################################################################################ + +name: aws-torch-latest-full + +on: + schedule: + - cron: '0 8 * * *' # Daily at 08:00 UTC (midnight PST) + workflow_dispatch: + inputs: + torch_preset: + description: PyTorch preset to install for manual runs + required: false + default: '2.10.0-cu126' + type: choice + options: + - '2.7.1-cu126' + - '2.8.0-cu126' + - '2.9.1-cu126' + - '2.10.0-cu126' + - '2.11.0-cu126' + transformers_version: + description: Hugging Face Transformers PyPI package version to install + required: false + default: '4.50.0' + type: string + transformers_source: + description: Hugging Face Transformers source for manual runs + required: false + default: 'git' + type: choice + options: + - 'pypi' + - 'git' + transformers_ref: + description: Hugging Face Transformers git ref to install when source is git + required: false + default: 'main' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + check-changes: + name: Check for new commits + runs-on: ubuntu-latest + if: github.event_name == 'schedule' + outputs: + has_changes: ${{ steps.check.outputs.has_changes }} + steps: + - name: Check for commits since last successful run + id: check + env: + GH_TOKEN: ${{ github.token }} + run: | + default_branch="${{ github.event.repository.default_branch }}" + + last_sha=$(gh api \ + "repos/${{ github.repository }}/actions/workflows/aws-torch-latest-full.yml/runs?status=success&event=schedule&branch=${default_branch}&per_page=1" \ + --jq '.workflow_runs[0].head_sha // empty') + + current_sha="${{ github.sha }}" + + if [ -z "$last_sha" ]; then + echo "No previous successful run found - running tests" + echo "has_changes=true" >> "$GITHUB_OUTPUT" + elif [ "$last_sha" = "$current_sha" ]; then + echo "No new commits since last successful run ($last_sha) - skipping" + echo "has_changes=false" >> "$GITHUB_OUTPUT" + else + echo "New commits detected: $last_sha -> $current_sha - running tests" + echo "has_changes=true" >> "$GITHUB_OUTPUT" + fi + + unit-tests: + name: Unit Tests (Full) + needs: [check-changes] + if: | + always() && + (github.event_name == 'workflow_dispatch' || needs.check-changes.outputs.has_changes == 'true') + runs-on: [self-hosted, gpu-ci, gpu-l40s, l40s-4gpu, aws] + timeout-minutes: 180 + + container: + image: nvidia/cuda:12.6.3-devel-ubuntu22.04 + # Mount /mnt/aio for async I/O tests (O_DIRECT requires native filesystem, not overlayfs) + options: --gpus all --shm-size "32G" -v /mnt/aio:/mnt/aio + + env: + DEFAULT_TORCH_PRESET: '2.10.0-cu126' + DEFAULT_TRANSFORMERS_SOURCE: 'git' + DEFAULT_TRANSFORMERS_VERSION: '4.50.0' + DEFAULT_TRANSFORMERS_REF: 'main' + CUTLASS_PATH: /opt/cutlass + # Disable reuse_dist_env to prevent pool worker cleanup hangs in full test runs + DS_DISABLE_REUSE_DIST_ENV: '1' + + steps: + - name: Install system dependencies + run: | + apt-get update && apt-get install -y git git-lfs libaio-dev pdsh python3 python3-pip + git lfs install + ln -sf /usr/bin/python3 /usr/bin/python + + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Resolve dependency inputs + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + MANUAL_TORCH_PRESET: ${{ github.event.inputs.torch_preset || '' }} + MANUAL_TRANSFORMERS_SOURCE: ${{ github.event.inputs.transformers_source || '' }} + MANUAL_TRANSFORMERS_VERSION: ${{ github.event.inputs.transformers_version || '' }} + MANUAL_TRANSFORMERS_REF: ${{ github.event.inputs.transformers_ref || '' }} + run: | + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TORCH_PRESET" ]; then + selected_preset="$MANUAL_TORCH_PRESET" + else + selected_preset="$DEFAULT_TORCH_PRESET" + fi + + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TRANSFORMERS_SOURCE" ]; then + transformers_source="$MANUAL_TRANSFORMERS_SOURCE" + else + transformers_source="$DEFAULT_TRANSFORMERS_SOURCE" + fi + + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TRANSFORMERS_VERSION" ]; then + transformers_version="$MANUAL_TRANSFORMERS_VERSION" + else + transformers_version="$DEFAULT_TRANSFORMERS_VERSION" + fi + + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TRANSFORMERS_REF" ]; then + transformers_ref="$MANUAL_TRANSFORMERS_REF" + else + transformers_ref="$DEFAULT_TRANSFORMERS_REF" + fi + + if [ "$transformers_source" = 'git' ] && [ -z "$transformers_ref" ]; then + transformers_ref='main' + fi + + case "$selected_preset" in + '2.7.1-cu126') + torch_install_version='2.7.1' + torchvision_install_version='0.22.1' + torchaudio_install_version='2.7.1' + torch_test_version='2.7' + cuda_test_version='12.6' + pytorch_index_url='https://download.pytorch.org/whl/cu126' + ;; + '2.8.0-cu126') + torch_install_version='2.8.0' + torchvision_install_version='0.23.0' + torchaudio_install_version='2.8.0' + torch_test_version='2.8' + cuda_test_version='12.6' + pytorch_index_url='https://download.pytorch.org/whl/cu126' + ;; + '2.9.1-cu126') + torch_install_version='2.9.1' + torchvision_install_version='0.24.1' + torchaudio_install_version='2.9.1' + torch_test_version='2.9' + cuda_test_version='12.6' + pytorch_index_url='https://download.pytorch.org/whl/cu126' + ;; + '2.10.0-cu126') + torch_install_version='2.10.0' + torchvision_install_version='0.25.0' + torchaudio_install_version='2.10.0' + torch_test_version='2.10' + cuda_test_version='12.6' + pytorch_index_url='https://download.pytorch.org/whl/cu126' + ;; + '2.11.0-cu126') + torch_install_version='2.11.0' + torchvision_install_version='0.26.0' + torchaudio_install_version='2.11.0' + torch_test_version='2.11' + cuda_test_version='12.6' + pytorch_index_url='https://download.pytorch.org/whl/cu126' + ;; + *) + echo "Unsupported torch_preset: $selected_preset" >&2 + exit 1 + ;; + esac + + { + echo "SELECTED_TORCH_PRESET=$selected_preset" + echo "TORCH_INSTALL_VERSION=$torch_install_version" + echo "TORCHVISION_INSTALL_VERSION=$torchvision_install_version" + echo "TORCHAUDIO_INSTALL_VERSION=$torchaudio_install_version" + echo "TORCH_TEST_VERSION=$torch_test_version" + echo "CUDA_TEST_VERSION=$cuda_test_version" + echo "PYTORCH_INDEX_URL=$pytorch_index_url" + echo "TRANSFORMERS_SOURCE=$transformers_source" + echo "TRANSFORMERS_VERSION=$transformers_version" + echo "TRANSFORMERS_REF=$transformers_ref" + } >> "$GITHUB_ENV" + + echo "Selected PyTorch preset: $selected_preset" + echo "Resolved install tuple: torch==$torch_install_version torchvision==$torchvision_install_version torchaudio==$torchaudio_install_version" + echo "Resolved test expectations: torch=$torch_test_version cuda=$cuda_test_version" + echo "Resolved PyTorch index: $pytorch_index_url" + echo "Resolved Transformers source: $transformers_source" + echo "Resolved Transformers version: $transformers_version" + echo "Resolved Transformers ref: $transformers_ref" + + - name: Install CUTLASS + run: | + git clone --depth 1 --branch v3.5.1 https://github.com/NVIDIA/cutlass.git /opt/cutlass + echo "CUTLASS installed at /opt/cutlass" + ls -la /opt/cutlass/include/ | head -10 + + - name: Install PyTorch + run: | + pip install \ + torch=="$TORCH_INSTALL_VERSION" \ + torchvision=="$TORCHVISION_INSTALL_VERSION" \ + torchaudio=="$TORCHAUDIO_INSTALL_VERSION" \ + --index-url "$PYTORCH_INDEX_URL" + + - name: Install Transformers + run: | + case "$TRANSFORMERS_SOURCE" in + 'pypi') + pip install "transformers==$TRANSFORMERS_VERSION" + ;; + 'git') + git clone --filter=blob:none https://github.com/huggingface/transformers /tmp/transformers + cd /tmp/transformers + git checkout "$TRANSFORMERS_REF" + resolved_ref="$(git rev-parse HEAD)" + echo "TRANSFORMERS_RESOLVED_REF=$resolved_ref" >> "$GITHUB_ENV" + echo "Resolved Transformers git ref: $resolved_ref" + pip install . + ;; + *) + echo "Unsupported TRANSFORMERS_SOURCE: $TRANSFORMERS_SOURCE" >&2 + exit 1 + ;; + esac + python -c "import transformers; print('transformers:', transformers.__version__, transformers)" + + - name: Install Python dependencies + run: | + pip install --upgrade pip + pip install -r requirements/requirements.txt + pip install -r requirements/requirements-dev.txt + pip install -r requirements/requirements-deepcompile.txt + pip install pytest-timeout pytest-instafail + + - name: Check environment + run: | + echo "=== Selected PyTorch Preset ===" + echo "Preset: $SELECTED_TORCH_PRESET" + echo "Install tuple: torch==$TORCH_INSTALL_VERSION torchvision==$TORCHVISION_INSTALL_VERSION torchaudio==$TORCHAUDIO_INSTALL_VERSION" + echo "PyTorch index URL: $PYTORCH_INDEX_URL" + echo "Expected test versions: torch=$TORCH_TEST_VERSION cuda=$CUDA_TEST_VERSION" + echo "Transformers source: $TRANSFORMERS_SOURCE" + echo "Transformers version: $TRANSFORMERS_VERSION" + echo "Transformers ref: $TRANSFORMERS_REF" + echo "Transformers resolved ref: ${TRANSFORMERS_RESOLVED_REF:-}" + echo "" + echo "=== GPU Information ===" + nvidia-smi + echo "" + echo "=== CUDA Version ===" + nvcc --version + echo "" + echo "=== Python/PyTorch Info ===" + python --version + python -c "import torch; print(f'PyTorch: {torch.__version__}')" + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + python -c "import torch; print(f'CUDA devices: {torch.cuda.device_count()}')" + python -c "import torch; print(f'BF16 support: {torch.cuda.is_bf16_supported()}')" + echo "" + echo "=== CUTLASS ===" + echo "CUTLASS_PATH: $CUTLASS_PATH" + ls -la "$CUTLASS_PATH"/include/ | head -5 + + - name: Detect GPU architecture + run: | + python - <<'PY' + import os + import torch + + torch.cuda.init() + major, minor = torch.cuda.get_device_capability(0) + arch = f"{major}.{minor}" + gpu_count = torch.cuda.device_count() + gpu_name = torch.cuda.get_device_name(0) + + with open(os.environ["GITHUB_ENV"], "a", encoding="utf-8") as env_file: + env_file.write(f"TORCH_CUDA_ARCH_LIST={arch}\n") + env_file.write(f"GPU_COUNT={gpu_count}\n") + + print(f"Detected GPU: {gpu_name}") + print(f"Detected compute capability: {arch}") + print(f"Detected GPU count: {gpu_count}") + PY + + - name: Install DeepSpeed + run: | + echo "Using TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST" + # Initialize CUDA before install so setup.py can detect NCCL version + python -c "import torch; torch.cuda.init(); print(f'NCCL version: {torch.cuda.nccl.version()}')" + # Use --no-build-isolation so setup.py can access pre-installed PyTorch + pip install --no-build-isolation .[dev,1bit,autotuning,deepcompile] + ds_report + + - name: Reinstall selected Transformers + run: | + case "$TRANSFORMERS_SOURCE" in + 'pypi') + pip install --no-deps --force-reinstall "transformers==$TRANSFORMERS_VERSION" + ;; + 'git') + cd /tmp/transformers + pip install --no-deps --force-reinstall . + ;; + *) + echo "Unsupported TRANSFORMERS_SOURCE: $TRANSFORMERS_SOURCE" >&2 + exit 1 + ;; + esac + python -c "import transformers; print('transformers:', transformers.__version__, transformers)" + + - name: Python environment + run: | + pip list + + - name: Unit tests (parallel) + run: | + echo "Running parallel tests with TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST on $GPU_COUNT GPUs" + cd tests + # Skip tests requiring unavailable hardware or known issues: + # - nvme checkpointing: no nvme device + # - GDS tests: no GPUDirect Storage support + # - launcher user_args: pdsh requires SSH server + # - zenflow: Stage 3 tests have pre-existing bugs + CUDA/fork issues + rm -rf /mnt/aio/pytest + pytest --instafail --timeout 600 --forked -n 8 --basetemp=/mnt/aio/pytest unit/ \ + --ignore=unit/runtime/zero/test_nvme_checkpointing.py \ + --ignore=unit/ops/aio/test_gds.py \ + --ignore=unit/launcher/test_user_args.py \ + --ignore=unit/runtime/zenflow \ + --ignore=unit/ops/adam/test_zf_torch_adam.py \ + --torch_ver="$TORCH_TEST_VERSION" --cuda_ver="$CUDA_TEST_VERSION" + + - name: Unit tests (sequential) + run: | + echo "Running sequential tests with TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST on $GPU_COUNT GPUs" + cd tests + rm -rf /mnt/aio/pytest + pytest --instafail --timeout 600 -m 'sequential' --basetemp=/mnt/aio/pytest unit/ \ + --ignore=unit/runtime/zero/test_nvme_checkpointing.py \ + --ignore=unit/ops/aio/test_gds.py \ + --ignore=unit/launcher/test_user_args.py \ + --ignore=unit/runtime/zenflow \ + --ignore=unit/ops/adam/test_zf_torch_adam.py \ + --ignore=unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py \ + --torch_ver="$TORCH_TEST_VERSION" --cuda_ver="$CUDA_TEST_VERSION" diff --git a/.github/workflows/aws-torch-latest.yml b/.github/workflows/aws-torch-latest.yml new file mode 100644 index 000000000000..e108dda8bed0 --- /dev/null +++ b/.github/workflows/aws-torch-latest.yml @@ -0,0 +1,109 @@ +################################################################################ +# DeepSpeed CI - AWS L40S GPU Tests (PyTorch Latest) +# +# Runs the same tests as modal-torch-latest.yml but on AWS self-hosted runners. +# Uses 4x NVIDIA L40S GPUs on g6e.12xlarge instances. +################################################################################ + +name: aws-torch-latest + +on: + workflow_dispatch: + + push: + branches: + - master + + pull_request: + branches: + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + check-paths: + name: aws-torch-latest / check paths + runs-on: ubuntu-latest + outputs: + should_run: ${{ steps.filter.outputs.run_tests }} + steps: + - uses: actions/checkout@v4 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + run_tests: + - '**' + - '!docs/**' + - '!blogs/**' + - '!deepspeed/inference/v2/**' + - '!tests/unit/inference/v2/**' + + unit-tests: + name: aws-torch-latest / unit tests (v1) + needs: check-paths + if: needs.check-paths.outputs.should_run == 'true' + runs-on: [self-hosted, gpu-ci, gpu-l40s, l40s-4gpu, aws] + timeout-minutes: 60 + + container: + image: nvidia/cuda:12.6.3-devel-ubuntu22.04 + options: --gpus all --shm-size "32G" + + env: + TORCH_VER: "2.7" + CUDA_VER: "12.6" + + steps: + - name: Install system dependencies + run: | + apt-get update && apt-get install -y git git-lfs libaio-dev python3 python3-pip + git lfs install + ln -sf /usr/bin/python3 /usr/bin/python + + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Install PyTorch + run: | + pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu126 + + - name: Install Python dependencies + run: | + pip install --upgrade pip + pip install -r requirements/requirements.txt + pip install -r requirements/requirements-dev.txt + pip install -r requirements/requirements-deepcompile.txt + + - name: Check environment + run: | + echo "=== GPU Information ===" + nvidia-smi + echo "" + echo "=== CUDA Version ===" + nvcc --version + echo "" + echo "=== Python/PyTorch Info ===" + python --version + python -c "import torch; print(f'PyTorch: {torch.__version__}')" + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + python -c "import torch; print(f'CUDA devices: {torch.cuda.device_count()}')" + python -c "import torch; print(f'BF16 support: {torch.cuda.is_bf16_supported()}')" + + - name: Install DeepSpeed + run: | + # Initialize CUDA before install so setup.py can detect NCCL version + python -c "import torch; torch.cuda.init(); print(f'NCCL version: {torch.cuda.nccl.version()}')" + # Use --no-build-isolation so setup.py can access pre-installed PyTorch + pip install --no-build-isolation . + ds_report + # Debug: Check captured torch_info values + python -c "from deepspeed.git_version_info import torch_info; print(f'torch_info: {torch_info}')" + + - name: Run unit tests + run: | + pytest -n 4 --forked --verbose tests/unit/v1/ --torch_ver=${{ env.TORCH_VER }} --cuda_ver=${{ env.CUDA_VER }} diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml new file mode 100644 index 000000000000..7a4952f18987 --- /dev/null +++ b/.github/workflows/cpu-torch-latest.yml @@ -0,0 +1,213 @@ +name: cpu-torch-latest + +on: + workflow_dispatch: + inputs: + torch_preset: + description: PyTorch CPU preset to install for manual runs + required: false + default: '2.10.0-cpu' + type: choice + options: + - '2.7.1-cpu' + - '2.8.0-cpu' + - '2.9.1-cpu' + - '2.10.0-cpu' + transformers_version: + description: Hugging Face Transformers PyPI package version to install + required: false + default: '4.50.0' + type: string + transformers_source: + description: Hugging Face Transformers source for manual runs + required: false + default: 'git' + type: choice + options: + - 'pypi' + - 'git' + transformers_ref: + description: Hugging Face Transformers git ref to install when source is git + required: false + default: 'main' + type: string + pull_request: + paths-ignore: + - 'docs/**' + - 'blogs/**' + - 'deepspeed/inference/v2/**' + - 'tests/unit/inference/v2/**' + merge_group: + branches: [ master ] + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + name: cpu-torch-latest / unit tests + runs-on: ubuntu-24.04 + + env: + DEFAULT_TORCH_PRESET: '2.10.0-cpu' + DEFAULT_TRANSFORMERS_SOURCE: 'git' + DEFAULT_TRANSFORMERS_VERSION: '4.50.0' + DEFAULT_TRANSFORMERS_REF: 'main' + + steps: + - uses: actions/checkout@v4 + + - id: setup-venv + uses: ./.github/workflows/setup-venv + + - name: Install system packages + run: | + sudo apt-get install -y numactl pdsh + + - name: Resolve dependency inputs + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + MANUAL_TORCH_PRESET: ${{ github.event.inputs.torch_preset || '' }} + MANUAL_TRANSFORMERS_SOURCE: ${{ github.event.inputs.transformers_source || '' }} + MANUAL_TRANSFORMERS_VERSION: ${{ github.event.inputs.transformers_version || '' }} + MANUAL_TRANSFORMERS_REF: ${{ github.event.inputs.transformers_ref || '' }} + run: | + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TORCH_PRESET" ]; then + selected_preset="$MANUAL_TORCH_PRESET" + else + selected_preset="$DEFAULT_TORCH_PRESET" + fi + + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TRANSFORMERS_SOURCE" ]; then + transformers_source="$MANUAL_TRANSFORMERS_SOURCE" + else + transformers_source="$DEFAULT_TRANSFORMERS_SOURCE" + fi + + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TRANSFORMERS_VERSION" ]; then + transformers_version="$MANUAL_TRANSFORMERS_VERSION" + else + transformers_version="$DEFAULT_TRANSFORMERS_VERSION" + fi + + if [ "$GITHUB_EVENT_NAME" = 'workflow_dispatch' ] && [ -n "$MANUAL_TRANSFORMERS_REF" ]; then + transformers_ref="$MANUAL_TRANSFORMERS_REF" + else + transformers_ref="$DEFAULT_TRANSFORMERS_REF" + fi + + if [ "$transformers_source" = 'git' ] && [ -z "$transformers_ref" ]; then + transformers_ref='main' + fi + + case "$selected_preset" in + '2.7.1-cpu') + torch_install_version='2.7.1' + torchvision_install_version='0.22.1' + torch_test_version='2.7' + ;; + '2.8.0-cpu') + torch_install_version='2.8.0' + torchvision_install_version='0.23.0' + torch_test_version='2.8' + ;; + '2.9.1-cpu') + torch_install_version='2.9.1' + torchvision_install_version='0.24.1' + torch_test_version='2.9' + ;; + '2.10.0-cpu') + torch_install_version='2.10.0' + torchvision_install_version='0.25.0' + torch_test_version='2.10' + ;; + *) + echo "Unsupported torch_preset: $selected_preset" >&2 + exit 1 + ;; + esac + + { + echo "SELECTED_TORCH_PRESET=$selected_preset" + echo "TORCH_INSTALL_VERSION=$torch_install_version" + echo "TORCHVISION_INSTALL_VERSION=$torchvision_install_version" + echo "TORCH_TEST_VERSION=$torch_test_version" + echo "PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cpu" + echo "TRANSFORMERS_SOURCE=$transformers_source" + echo "TRANSFORMERS_VERSION=$transformers_version" + echo "TRANSFORMERS_REF=$transformers_ref" + } >> "$GITHUB_ENV" + + echo "Selected PyTorch preset: $selected_preset" + echo "Resolved install tuple: torch==$torch_install_version torchvision==$torchvision_install_version" + echo "Resolved test expectation: torch=$torch_test_version" + echo "Resolved Transformers source: $transformers_source" + echo "Resolved Transformers version: $transformers_version" + echo "Resolved Transformers ref: $transformers_ref" + + - name: Install PyTorch + run: | + pip install \ + torch=="$TORCH_INSTALL_VERSION" \ + torchvision=="$TORCHVISION_INSTALL_VERSION" \ + --index-url "$PYTORCH_INDEX_URL" + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install Transformers + run: | + case "$TRANSFORMERS_SOURCE" in + 'pypi') + pip install "transformers==$TRANSFORMERS_VERSION" + ;; + 'git') + git clone --filter=blob:none https://github.com/huggingface/transformers /tmp/transformers + cd /tmp/transformers + git checkout "$TRANSFORMERS_REF" + resolved_ref="$(git rev-parse HEAD)" + echo "TRANSFORMERS_RESOLVED_REF=$resolved_ref" >> "$GITHUB_ENV" + echo "Resolved Transformers git ref: $resolved_ref" + pip install . + ;; + *) + echo "Unsupported TRANSFORMERS_SOURCE: $TRANSFORMERS_SOURCE" >&2 + exit 1 + ;; + esac + python -c "import transformers; print('transformers:', transformers.__version__, transformers)" + + - name: Install deepspeed + run: | + pip install .[dev,autotuning] + ds_report + + - name: Reinstall selected Transformers + run: | + case "$TRANSFORMERS_SOURCE" in + 'pypi') + pip install --no-deps --force-reinstall "transformers==$TRANSFORMERS_VERSION" + ;; + 'git') + cd /tmp/transformers + pip install --no-deps --force-reinstall . + ;; + *) + echo "Unsupported TRANSFORMERS_SOURCE: $TRANSFORMERS_SOURCE" >&2 + exit 1 + ;; + esac + python -c "import transformers; print('transformers:', transformers.__version__, transformers)" + + - name: Python environment + run: | + pip list + + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="$TORCH_TEST_VERSION" + HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="$TORCH_TEST_VERSION" diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 6a93b5b418a0..7734d97681d5 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -1,12 +1,12 @@ name: Formatting on: - push: - branches: - - 'staging**' + workflow_dispatch: pull_request: branches: '**' + merge_group: + branches: [ master ] schedule: - cron: "0 0 * * *" @@ -17,23 +17,25 @@ concurrency: jobs: # formatting and basic install on cpu-only machine - formatting: - runs-on: ubuntu-20.04 + unit-tests: + name: formatting / formatting checks + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: environment run: | which python python --version - - name: Install deepspeed + - name: Install dependencies run: | - pip install .[dev,autotuning] - ds_report + # Previously we would do pip install .[dev] but this is causing out of + # space errors start with torch 2.1.0 release + grep -E "clang-format|pre-commit" requirements/requirements-dev.txt | xargs pip install - name: Formatting checks run: | - pip show pre-commit clang-format - pre-commit run --all-files + pip show pre-commit clang-format + pre-commit run --all-files diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml new file mode 100644 index 000000000000..0f304d226a29 --- /dev/null +++ b/.github/workflows/hpu-gaudi2-nightly.yml @@ -0,0 +1,88 @@ +name: hpu-gaudi2-nightly + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - ".github/workflows/hpu-gaudi2-nightly.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + unit-tests: + name: hpu-gaudi2-nightly / unit tests + # The type of runner that the job will run on + runs-on: [self-hosted, intel, gaudi2] + container: + image: vault.habana.ai/gaudi-docker/1.21.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest + ports: + - 80 + options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice + + env: + PT_HPU_LAZY_MODE: 0 + TORCHINDUCTOR_COMPILE_THREADS: 1 + TEST_LIST: | + test_adamw.py + test_bf16.py + test_ds_config_dict.py + test_dynamic_loss_scale.py + test_latest_checkpoint.py + test_moe_checkpoint.py + test_multi_output_model.py + test_other_optimizer.py + test_pipe.py + test_pipeline.py + test_universal_checkpoint.py + test_zero_context_return.py + test_zero_leaf_module.py + test_zero_offloadpp.py + test_zero_tiled.py + test_autotp_training.py + test_ulysses.py + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v4 + + - name: Check container state + run: | + ldd --version + hl-smi -L + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + git rev-parse --short HEAD + pip install . + + - name: Install deepspeed + run: | + pip install .[dev,autotuning] + ds_report + + - name: Python environment + run: | + pip list + + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + export PT_HPU_LAZY_MODE=${PT_HPU_LAZY_MODE} + export TORCHINDUCTOR_COMPILE_THREADS=${TORCHINDUCTOR_COMPILE_THREADS} + TEST_LIST=$(echo "$TEST_LIST" | awk 'NF{printf "%s%s", (NR>1 ? " or " : ""), $0} END{if (NR>1) print ""}') + echo "TEST_LIST ${TEST_LIST}" + pytest --verbose unit/ -k "${TEST_LIST}" diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml new file mode 100644 index 000000000000..17e962b5a9e0 --- /dev/null +++ b/.github/workflows/hpu-gaudi2.yml @@ -0,0 +1,139 @@ +name: hpu-gaudi2 + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - ".github/workflows/hpu-gaudi2.yml" + - "accelerator/hpu_accelerator.py" + - "op_builder/hpu/**" + - "deepspeed/runtime/engine.py" + - "deepspeed/runtime/bf16_optimizer.py" + - "deepspeed/runtime/zero/stage_1_and_2.py" + - "deepspeed/runtime/zero/stage3.py" + - "deepspeed/runtime/zero/partition_parameters.py" + - "deepspeed/runtime/zero/partitioned_param_coordinator.py" + - "deepspeed/runtime/zero/parameter_offload.py" + - "deepspeed/runtime/pipe/engine.py" + - "deepspeed/runtime/utils.py" + - "deepspeed/inference/engine.py" + - "deepspeed/module_inject/auto_tp.py" + - "deepspeed/module_inject/replace_module.py" + - "deepspeed/module_inject/load_checkpoint.py" + - "deepspeed/module_inject/inject.py" + - "deepspeed/ops/transformer/**" + - "deepspeed/ops/adam/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + unit-tests: + name: hpu-gaudi2 / unit tests + # The type of runner that the job will run on + runs-on: [self-hosted, intel, gaudi2] + container: + image: vault.habana.ai/gaudi-docker/1.21.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest + ports: + - 80 + options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice + + env: + PT_HPU_LAZY_MODE: 0 + TORCHINDUCTOR_COMPILE_THREADS: 1 + TEST_LIST: | + test_accelerator.py + test_autotuning.py + test_compression.py + test_dist.py + test_elastic.py + test_ds_arguments.py + test_run.py + test_multinode_runner.py + test_moe_tp.py + test_monitor.py + (test_zero_optimizer.py and (TestSaveTensorClone or TestZeRONonDistributed)) + (test_latest_checkpoint.py and test_missing_latest) + test_reshape_checkpoint.py + test_shared_weights.py + test_sparse.py + test_tag_validation.py + test_pipe_module.py + (test_flops_profiler.py and test_flops_profiler_in_inference) + test_get_optim_files.py + test_groups.py + test_partition_balanced.py + (test_adamw.py and TestAdamConfigs) + test_coalesced_collectives.py + test_activation_checkpointing_non_reentrant.py + test_activation_checkpointing.py + test_data.py + (test_ds_config_dict.py and (TestBasicConfig or TestBatchConfig)) + test_ds_config_model.py + test_mup_optimizers.py + (test_pld.py and test_pld_schedule) + test_runtime_utils.py + test_pipe_schedule.py + test_topology.py + (test_ds_initialize.py and (TestClientOptimizer or TestClientLrScheduler)) + test_csr.py + (test_fp16.py and (TestZeroEmptyGrad or TestZeroAllowUntestedOptimizer)) + (test_bf16.py and TestZeroDtypeCocktail) + test_partition.py + test_ignore_unused_parameters.py + test_zero_config.py + test_zero_context_ancestry.py + (test_zero_context.py and not TestSerialContext) + test_zero_dynamic_class.py + test_zero_nesting_init.py + test_zeropp.py + (test_zero.py and (TestZero3ParamPartitioningLargeParam or TestZero3ParamPartitioningLargeParam)) + (test_linear.py and (TestLoRALinear or TestBasicLinear)) + (test_ctx.py and TestEngine) + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v4 + + - name: Check container state + run: | + ldd --version + hl-smi -L + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 981c276 + git rev-parse --short HEAD + pip install . + + - name: Install deepspeed + run: | + pip install .[dev,autotuning] + ds_report + + - name: Python environment + run: | + pip list + + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + export PT_HPU_LAZY_MODE=${PT_HPU_LAZY_MODE} + export TORCHINDUCTOR_COMPILE_THREADS=${TORCHINDUCTOR_COMPILE_THREADS} + TEST_LIST=$(echo "$TEST_LIST" | awk 'NF{printf "%s%s", (NR>1 ? " or " : ""), $0} END{if (NR>1) print ""}') + echo "TEST_LIST ${TEST_LIST}" + pytest --verbose unit/ -k "${TEST_LIST}" diff --git a/.github/workflows/modal-accelerate.yml b/.github/workflows/modal-accelerate.yml new file mode 100644 index 000000000000..bbf93694f6e6 --- /dev/null +++ b/.github/workflows/modal-accelerate.yml @@ -0,0 +1,104 @@ +name: modal-accelerate + +# This CI is running on modal.com's GPUs. +# +# It's set up here on github actions and then the cloned repo is sent to modal and everything +# happens on their hw - see ci/accelerate.py for where the actual vm is loaded, updated and the tests are +# run. +# +# Both files are annotated to what's important and how one might change or update things if needed. +# +# Note that since this is a Required job we can't use `on.push.path` file filter - we are using +# collect-tests job to do the filtering for us so that the job can be skipped and satisfy the +# Required status for PRs to pass. +# + + +on: + workflow_dispatch: + push: + branches: + - master + + # you have to switch to `pull_request` if you need to change the CI job's python script, + # otherwise GH will use a master version of the CI files, ignoring the modifications in the PR - + # the other way is to use modal cli to test this job from one's host - it'd require setting up + # modal secrets + # pull_request: + pull_request_target: + paths-ignore: + - 'docs/**' + - 'blogs/**' + - 'deepspeed/inference/v2/**' + - 'tests/unit/inference/v2/**' + types: [review_requested, ready_for_review, synchronize] + branches: + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + collect-tests: + name: modal-accelerate / collect tests + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + outputs: + deepspeed: ${{ steps.filter.outputs.deepspeed }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Filter changed files + uses: dorny/paths-filter@v2 + id: filter + with: + token: ${{ secrets.GITHUB_TOKEN }} + filters: | + deepspeed: + - 'deepspeed/**' + - '.github/workflows/modal*.yml' + - 'ci/**' + - 'tests/unit/**' + - 'csrc/**' + + deploy: + name: modal-accelerate / DeepSpeedAI CI + runs-on: ubuntu-latest + needs: collect-tests + env: + # these are created at https://modal.com/settings/deepspeedai/tokens + # they are then added to the repo's secrets at https://github.com/deepspeedai/deepspeed/settings/secrets/actions + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + # this one comes from https://huggingface.co/settings/profile of the bot user + # and it too is then updated at https://github.com/deepspeedai/deepspeed/settings/secrets/actions + HF_TOKEN: ${{ secrets.HF_TOKEN }} + + if: needs.collect-tests.outputs.deepspeed == 'true' + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: 'pip' # caching pip dependencies + + - name: Install build dependencies + run: | + pip install uv # much faster than pip + uv pip install --system modal + + - name: Run tests + run: | + modal run -m ci.accelerate diff --git a/.github/workflows/modal-torch-latest.yml b/.github/workflows/modal-torch-latest.yml new file mode 100644 index 000000000000..31a0bcd4ca1a --- /dev/null +++ b/.github/workflows/modal-torch-latest.yml @@ -0,0 +1,129 @@ +name: modal-torch-latest + +# This CI is running on modal.com's GPUs. +# +# It's set up here on github actions and then the cloned repo is sent to modal and everything +# happens on their hw - see ci/torch_latest.py for where the actual vm is loaded, updated and the tests are +# run. +# +# Both files are annotated to what's important and how one might change or update things if needed. +# +# Note that since this is a Required job we can't use `on.push.path` file filter - we are using +# collect-tests job to do the filtering for us so that the job can be skipped and satisfy the +# Required status for PRs to pass. +# + + +on: + workflow_dispatch: + inputs: + torch_preset: + description: Modal PyTorch/CUDA image preset for manual runs + required: false + default: '2.10.0-cuda12.8' + type: choice + options: + - '2.7.1-cuda12.8' + - '2.8.0-cuda12.8' + - '2.9.1-cuda12.8' + - '2.10.0-cuda12.8' + - '2.11.0-cuda12.8' + transformers_source: + description: Hugging Face Transformers source for manual runs + required: false + default: 'git' + type: choice + options: + - 'requirements' + - 'git' + transformers_ref: + description: Hugging Face Transformers git ref to install when source is git + required: false + default: 'main' + type: string + + push: + branches: + - master + + pull_request_target: + paths-ignore: + - 'docs/**' + - 'blogs/**' + - 'deepspeed/inference/v2/**' + - 'tests/unit/inference/v2/**' + types: [review_requested, ready_for_review, synchronize] + branches: + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + collect-tests: + name: modal-torch-latest / collect tests + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + outputs: + deepspeed: ${{ steps.filter.outputs.deepspeed }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Filter changed files + uses: dorny/paths-filter@v3 + id: filter + with: + token: ${{ secrets.GITHUB_TOKEN }} + filters: | + deepspeed: + - 'deepspeed/**' + - '.github/workflows/modal*.yml' + - 'ci/**' + - 'tests/unit/**' + - 'csrc/**' + + deploy: + name: modal-torch-latest / DeepSpeedAI CI + runs-on: ubuntu-latest + needs: collect-tests + env: + # these are created at https://modal.com/settings/deepspeedai/tokens + # they are then added to the repo's secrets at https://github.com/deepspeedai/deepspeed/settings/secrets/actions + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + # this one comes from https://huggingface.co/settings/profile of the bot user + # and it too is then updated at https://github.com/deepspeedai/deepspeed/settings/secrets/actions + HF_TOKEN: ${{ secrets.HF_TOKEN }} + + if: github.event_name == 'workflow_dispatch' || needs.collect-tests.outputs.deepspeed == 'true' + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: 'pip' # caching pip dependencies + + - name: Install build dependencies + run: | + pip install uv # much faster than pip + uv pip install --system modal + + - name: Run tests + env: + MODAL_TORCH_PRESET: ${{ github.event.inputs.torch_preset || '2.10.0-cuda12.8' }} + MODAL_TRANSFORMERS_SOURCE: ${{ github.event.inputs.transformers_source || 'git' }} + MODAL_TRANSFORMERS_REF: ${{ github.event.inputs.transformers_ref || 'main' }} + run: | + modal run -m ci.torch_latest diff --git a/.github/workflows/no-torch.yml b/.github/workflows/no-torch.yml new file mode 100644 index 000000000000..3caa04531967 --- /dev/null +++ b/.github/workflows/no-torch.yml @@ -0,0 +1,50 @@ +name: no-torch + +on: + workflow_dispatch: + pull_request: + paths: + - 'accelerator/**' + - '.github/workflows/no-torch.yml' + - 'op_builder/**' + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + unit-tests: + name: no-torch / source distribution build + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + + - id: setup-venv + uses: ./.github/workflows/setup-venv + + - name: Python environment + run: | + pip uninstall torch --yes + pip install setuptools + pip install build + pip list + + - name: Build deepspeed + run: | + DS_BUILD_STRING=" " python -m build --sdist + + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml new file mode 100644 index 000000000000..15ac017d15da --- /dev/null +++ b/.github/workflows/nv-a6000.yml @@ -0,0 +1,75 @@ +name: nv-a6000 + +on: + pull_request: + paths: + - 'accelerator/cuda_accelerator.py' + - 'deepspeed/inference/v2/**' + - 'tests/unit/inference/v2/**' + - '.github/workflows/nv-a6000.yml' + workflow_dispatch: + inputs: + mii_branch: + description: 'DeepSpeed-MII Branch' + required: false + default: 'main' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + name: nv-a6000 / inference and MII tests + runs-on: [self-hosted, nvidia, a6000] + container: + image: nvcr.io/nvidia/pytorch:25.01-py3 + ports: + - 80 + options: --gpus all --shm-size "8G" + + steps: + - uses: actions/checkout@v4 + + - name: Check container state + run: | + ldd --version + nvcc --version + nvidia-smi + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if you need to use an older transformers version temporarily in case of breakage + # git checkout 981c276 + git rev-parse --short HEAD + python -m pip install . + - name: Install deepspeed + run: | + python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja + python -m pip install .[dev,1bit,autotuning,inf] + ds_report + - name: Python environment + run: | + python -m pip list + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.6" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.6" --cuda_ver="12" + - name: MII unit tests + run: | + BRANCH="main" + if [[ ! -z "${{ github.event.inputs.mii_branch }}" ]]; then + BRANCH="${{ github.event.inputs.mii_branch }}" + fi + echo "Cloning DeepSpeed-MII branch: $BRANCH" + git clone -b $BRANCH --depth=1 https://github.com/deepspeedai/DeepSpeed-MII.git + cd DeepSpeed-MII + pip install .[dev] + cd tests + python -m pytest --color=yes --durations=0 --verbose -rF ./ diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index 534c79fa4a90..a1b7fd343e0b 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -1,34 +1,22 @@ name: nv-accelerate-v100 -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu111, v100] + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu111 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -43,15 +31,17 @@ jobs: - name: HF Accelerate tests run: | - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch git clone https://github.com/huggingface/accelerate cd accelerate git rev-parse --short HEAD + + # temp workaround until this is resolved https://github.com/huggingface/accelerate/issues/3676 + pip install datasets==3.6.0 + # installing dependencies pip install .[testing] # force protobuf version due to issues pip install "protobuf<4.21.0" - # tmp fix: force newer datasets version - #pip install "datasets>=2.0.0" pip list - HF_DATASETS_CACHE=/blob/datasets_cache/ TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose tests/deepspeed + pytest $PYTEST_OPTS --color=yes --durations=0 --verbose tests/deepspeed diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml new file mode 100644 index 000000000000..8543501ba074 --- /dev/null +++ b/.github/workflows/nv-ds-chat.yml @@ -0,0 +1,80 @@ +name: nv-ds-chat + +on: + workflow_dispatch: + inputs: + dse_branch: + description: 'DeepSpeedExamples Branch' + required: false + default: 'master' + type: string + pull_request: + paths: + - ".github/workflows/nv-ds-chat.yml" + - "deepspeed/runtime/zero/stage_1_and_2.py" + - "deepspeed/runtime/zero/stage3.py" + - "deepspeed/runtime/hybrid_engine.py" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + unit-tests: + name: nv-ds-chat / DeepSpeed-Chat tests + runs-on: [self-hosted, nvidia, cu124, v100] + + steps: + - uses: actions/checkout@v4 + + - id: setup-venv + uses: ./.github/workflows/setup-venv + + - name: Install pytorch + run: | + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu124 + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install deepspeed + run: | + pip install .[dev] + pip install transformers==4.48.3 + ds_report + + - name: Install deepspeed-chat + run: | + BRANCH="master" + if [[ ! -z "${{ github.event.inputs.dse_branch }}" ]]; then + BRANCH="${{ github.event.inputs.dse_branch }}" + fi + echo "DeepSpeedExamples Branch: $BRANCH" + git clone -b $BRANCH https://github.com/deepspeedai/DeepSpeedExamples.git + cd DeepSpeedExamples/applications/DeepSpeed-Chat + pip install -r requirements.txt + pip install -e . + + - name: Python environment + run: | + pip list + + - name: DS-Chat unit tests + run: | + cd DeepSpeedExamples/applications/DeepSpeed-Chat + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + unset NCCL_DEBUG + cd tests + pytest $PYTEST_OPTS ./ + + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml new file mode 100644 index 000000000000..69f6988aa415 --- /dev/null +++ b/.github/workflows/nv-flash-attn.yml @@ -0,0 +1,68 @@ +name: nv-flash-attn + +on: + workflow_dispatch: + pull_request: + paths: + - 'deepspeed/sequence/**' + - 'tests/unit/sequence_parallelism/**' + - '.github/workflows/nv-flash-attn.yml' + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + name: nv-flash-attn / sequence parallelism tests + runs-on: [self-hosted, nvidia, a6000] + container: + image: nvcr.io/nvidia/pytorch:24.12-py3 + ports: + - 80 + options: --gpus all --shm-size "8G" + + steps: + - uses: actions/checkout@v4 + + - name: Check container state + run: | + ldd --version + nvcc --version + nvidia-smi + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + + + - name: Install deepspeed + run: | + python -m pip install .[dev] + ds_report + + # install transformers after deepspeed so that the right version of transformers is installed + - name: Install transformers + run: | + python -m pip install transformers==4.50.0 + + - name: Install FlashAttention + run: | + python -m pip install flash-attn + - name: Python environment + run: | + python -m pip list + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + python -m pytest --color=yes --durations=0 --verbose -rF unit/sequence_parallelism/test_ulysses.py --torch_ver="2.6" --cuda_ver="12" + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index 3da31f4b7994..907c5296ec14 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -1,14 +1,18 @@ name: nv-inference on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' + workflow_dispatch: pull_request: - paths-ignore: - - 'docs/**' + paths: + - '.github/workflows/nv-inference.yml' + - 'requirements/**' + - 'deepspeed/__init__.py' + - 'deepspeed/inference/**' + - '!deepspeed/inference/v2/**' # exclude v2 dir + - 'tests/unit/inference/**' + - '!tests/unit/inference/v2/**' # exclude v2 tests dir + merge_group: + branches: [ master ] schedule: - cron: "0 0 * * *" @@ -18,17 +22,18 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + name: nv-inference / inference tests + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116 + pip install -U --cache-dir $TORCH_CACHE torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -36,12 +41,14 @@ jobs: run: | git clone https://github.com/huggingface/transformers cd transformers + #git checkout f370bebdc git rev-parse --short HEAD pip install . - name: Install deepspeed run: | - pip install .[dev,1bit,autotuning,inf] + DS_ACCELERATOR=cpu pip install .[dev,1bit,autotuning,inf] + #pip install .[dev,1bit,autotuning,inf,triton] ds_report - name: Python environment @@ -51,8 +58,9 @@ jobs: - name: Unit tests run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" - TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6" - TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" + #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="2.1" --cuda_ver="12.4" + pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="2.1" --cuda_ver="12.4" + pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="2.1" --cuda_ver="12.4" + # run ds_report again to check updated op list + ds_report diff --git a/.github/workflows/nv-lightning-v100.yml b/.github/workflows/nv-lightning-v100.yml index 68948cdd7296..eeb8516a324d 100644 --- a/.github/workflows/nv-lightning-v100.yml +++ b/.github/workflows/nv-lightning-v100.yml @@ -1,51 +1,53 @@ -name: nv-lightning-v100 - -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - unit-tests: - runs-on: [self-hosted, nvidia, cu111, v100] - - steps: - - uses: actions/checkout@v2 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install deepspeed - run: | - pip install .[dev,autotuning] - ds_report - - - name: Python environment - run: | - pip list - - - name: PyTorch Lightning Tests - run: | - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - # Pin pytorch-lightning version to latest pre-2.0.0+ as these require updating the pinned torch versions above. - pip install pytorch-lightning==1.9.4 - pip install "protobuf<4.21.0" - cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose lightning/ +# name: nv-lightning-v100 + +# disabled as the v100s are no more - need to port to modal while removing v100 + +# on: +# workflow_dispatch: +# pull_request: +# paths-ignore: +# - 'docs/**' +# - 'blogs/**' +# - 'deepspeed/inference/v2/**' +# - 'tests/unit/inference/v2/**' +# merge_group: +# branches: [ master ] +# schedule: +# - cron: "0 0 * * *" + +# concurrency: +# group: ${{ github.workflow }}-${{ github.ref }} +# cancel-in-progress: true + +# jobs: +# unit-tests: +# runs-on: [self-hosted, nvidia, cu124, v100] + +# steps: +# - uses: actions/checkout@v4 + +# - id: setup-venv +# uses: ./.github/workflows/setup-venv + +# - name: Install pytorch +# run: | +# pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu124 +# python -c "import torch; print('torch:', torch.__version__, torch)" +# python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + +# - name: Install deepspeed +# run: | +# pip install .[dev,autotuning] +# ds_report + +# - name: Python environment +# run: | +# pip list + +# - name: PyTorch Lightning Tests +# run: | +# unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch +# pip install pytorch-lightning +# pip install "protobuf<4.21.0" +# cd tests +# pytest $PYTEST_OPTS lightning/ diff --git a/.github/workflows/nv-megatron.yml b/.github/workflows/nv-megatron.yml deleted file mode 100644 index 638b3a68af5a..000000000000 --- a/.github/workflows/nv-megatron.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: nv-megatron - -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] - - steps: - - uses: actions/checkout@v2 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116 - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install deepspeed - run: | - pip install .[dev] - ds_report - - - name: Install apex - run: | - pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex.git - - - name: Python environment - run: | - pip list - - - name: Megatron unit tests - run: | - git clone https://github.com/microsoft/Megatron-DeepSpeed.git - cd Megatron-DeepSpeed - pip install . - unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - cd tests - MEGATRON_CKPT_DIR=/blob/megatron_ckpt/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose ./ diff --git a/.github/workflows/nv-mii.yml b/.github/workflows/nv-mii.yml index f93cf6c3b376..64b97b080ede 100644 --- a/.github/workflows/nv-mii.yml +++ b/.github/workflows/nv-mii.yml @@ -1,61 +1,54 @@ name: nv-mii -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116 + pip3 install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Install deepspeed + run: | + pip install .[dev] + ds_report + + # install transformers after deepspeed so that the right version of transformers is installed - name: Install transformers run: | git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 + git checkout v4.42.4 git rev-parse --short HEAD pip install . - - name: Install deepspeed - run: | - pip install .[dev] - ds_report - - name: Python environment run: | pip list - name: MII unit tests run: | - git clone https://github.com/microsoft/DeepSpeed-MII.git + BRANCH="main" + if [[ ! -z "${{ github.event.inputs.mii_branch }}" ]]; then + BRANCH="${{ github.event.inputs.mii_branch }}" + fi + echo "Cloning DeepSpeed-MII branch: $BRANCH" + git clone -b $BRANCH --depth=1 https://github.com/deepspeedai/DeepSpeed-MII.git cd DeepSpeed-MII pip install .[dev] unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - cd tests - TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m "CPU or local" ./ + cd tests/legacy + pytest $PYTEST_OPTS --forked -m "deepspeed" ./ diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml index f70da73916c3..670b0c4eda44 100644 --- a/.github/workflows/nv-nightly.yml +++ b/.github/workflows/nv-nightly.yml @@ -1,26 +1,26 @@ name: nv-nightly -on: - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +permissions: + contents: read + issues: write + jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -29,10 +29,14 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 + git checkout v4.42.4 git rev-parse --short HEAD pip install . + - name: Install datasets + run: | + pip install datasets + - name: Install deepspeed run: | pip install .[dev,1bit,autotuning,inf] @@ -45,6 +49,14 @@ jobs: - name: Unit tests run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.6" + pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="2.6" --cuda_ver="12.4" + + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml new file mode 100644 index 000000000000..3e82eddaba26 --- /dev/null +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -0,0 +1,52 @@ +name: nv-pre-compile-ops + +on: + workflow_dispatch: + pull_request: + branches: + '**' + paths-ignore: + - 'docs/**' + - 'blogs/**' + - 'deepspeed/inference/v2/**' + - 'tests/unit/inference/v2/**' + merge_group: + branches: [ master ] + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + name: nv-pre-compile-ops / precompile ops + runs-on: ubuntu-24.04 + container: + image: nvidia/cuda:12.6.3-devel-ubuntu22.04 + + steps: + - name: Install system dependencies + run: | + apt-get update && apt-get install -y git python3 python3-pip libaio-dev ninja-build + ln -sf /usr/bin/python3 /usr/bin/python + + - uses: actions/checkout@v4 + + - name: Install PyTorch + run: | + pip install torch==2.10.0 --index-url https://download.pytorch.org/whl/cu126 + + - name: environment + run: | + which python + python --version + python -c "import torch; print('torch:', torch.__version__, torch)" + #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Compile DeepSpeed Ops + run: | + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 DS_BUILD_DEEP_COMPILE=0 pip3 install . + - name: DS Report + run: | + DS_ACCELERATOR=cuda ds_report diff --git a/.github/workflows/nv-sd.yml b/.github/workflows/nv-sd.yml new file mode 100644 index 000000000000..f19ff8e72ab4 --- /dev/null +++ b/.github/workflows/nv-sd.yml @@ -0,0 +1,73 @@ +name: nv-sd + +on: + workflow_dispatch: + pull_request: + paths: + - "deepspeed/ops/transformer/inference/diffusers_**" + - "tests/unit/inference/test_stable_diffusion.py" + - "deepspeed/model_implementations/diffusers/unet.py" + - "deepspeed/model_implementations/diffusers/vae.py" + - "deepspeed/module_inject/containers/vae.py" + - "deepspeed/module_inject/containers/unet.py" + - ".github/workflows/nv-sd.yml" + - "requirements/requirements-sd.txt" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + sd-tests: + name: nv-sd / stable diffusion tests + runs-on: [self-hosted, nvidia, a6000] + container: + image: nvcr.io/nvidia/pytorch:24.03-py3 + ports: + - 80 + options: --gpus all --shm-size "8G" + + steps: + - uses: actions/checkout@v4 + + - name: Check container state + run: | + ldd --version + nvcc --version + nvidia-smi + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + git rev-parse --short HEAD + python -m pip install . + - name: Install deepspeed + run: | + pip install image-similarity-measures + python -m pip install opencv-python==4.6.* --force-reinstall + python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja + python -m pip install .[dev,1bit,autotuning,sd] + ds_report + - name: Python environment + run: | + python -m pip list + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + python -m pytest --color=yes --durations=0 --verbose -rF -m 'stable_diffusion' -k "TestStableDiffusion" unit/ --torch_ver="2.3" --cuda_ver="12" + + - name: Open GitHub issue if weekly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/nv-torch-latest-cpu.yml b/.github/workflows/nv-torch-latest-cpu.yml deleted file mode 100644 index 6e33212f3cba..000000000000 --- a/.github/workflows/nv-torch-latest-cpu.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: nv-torch-latest-cpu - -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - unit-tests: - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v2 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install deepspeed - run: | - pip install .[dev,autotuning] - ds_report - - - name: Python environment - run: | - pip list - - - name: Unit tests - run: | - unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest -n 4 unit/ --torch_ver="1.12" - TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'sequential' unit/ --torch_ver="1.12" diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index c856a3054bfa..9dbdb024ffac 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -1,34 +1,22 @@ name: nv-torch-latest-v100 -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -37,13 +25,14 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 + git checkout 981c276 git rev-parse --short HEAD pip install . - name: Install deepspeed run: | - pip install .[dev,1bit,autotuning] + pip install .[dev,1bit,autotuning,deepcompile] + pip install pytest-timeout pytest-instafail ds_report - name: Python environment @@ -53,7 +42,6 @@ jobs: - name: Unit tests run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ --torch_ver="2.0" --cuda_ver="11.7" - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'sequential' unit/ --torch_ver="2.0" --cuda_ver="11.7" + pytest -x $PYTEST_OPTS --instafail --timeout 600 --forked -n 8 unit/ --torch_ver="2.6" --cuda_ver="12.4" + pytest $PYTEST_OPTS --instafail --timeout 600 --forked -m 'sequential' unit/ --torch_ver="2.6" --cuda_ver="12.4" diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml index 87a861683f02..f88951b9b03e 100644 --- a/.github/workflows/nv-torch-nightly-v100.yml +++ b/.github/workflows/nv-torch-nightly-v100.yml @@ -1,26 +1,26 @@ name: nv-torch-nightly-v100 -on: - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +permissions: + contents: read + issues: write + jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv - name: Install pytorch run: | - pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -29,7 +29,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 + # git checkout 981c276 git rev-parse --short HEAD pip install . @@ -45,7 +45,15 @@ jobs: - name: Unit tests run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'sequential' unit/ + pytest $PYTEST_OPTS --forked -n 8 unit/ + pytest $PYTEST_OPTS --forked -m 'sequential' unit/ + + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true diff --git a/.github/workflows/nv-torch19-p40.yml b/.github/workflows/nv-torch19-p40.yml deleted file mode 100644 index 06363d3ae688..000000000000 --- a/.github/workflows/nv-torch19-p40.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: nv-torch19-p40 - -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - unit-tests: - runs-on: [self-hosted, nvidia, cu111, p40] - - steps: - - uses: actions/checkout@v2 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install transformers - run: | - git clone https://github.com/huggingface/transformers - cd transformers - # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 - git rev-parse --short HEAD - pip install . - - - name: Install deepspeed - run: | - pip install .[dev,1bit,autotuning] - ds_report - - - name: Python environment - run: | - pip list - - - name: Unit tests - run: | - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ --torch_ver="1.9" --cuda_ver="11.1" diff --git a/.github/workflows/nv-torch19-v100.yml b/.github/workflows/nv-torch19-v100.yml deleted file mode 100644 index be37e4069881..000000000000 --- a/.github/workflows/nv-torch19-v100.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: nv-torch19-v100 - -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - unit-tests: - runs-on: [self-hosted, nvidia, cu111, v100] - - steps: - - uses: actions/checkout@v2 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install transformers - run: | - git clone https://github.com/huggingface/transformers - cd transformers - # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 - git rev-parse --short HEAD - pip install . - - - name: Install deepspeed - run: | - pip install .[dev,1bit,autotuning] - ds_report - - - name: Python environment - run: | - pip list - - - name: Unit tests - run: | - unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ --torch_ver="1.9" --cuda_ver="11" - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'sequential' unit/ --torch_ver="1.9" --cuda_ver="11" diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml index 4d59bece9a44..e9326613273f 100644 --- a/.github/workflows/nv-transformers-v100.yml +++ b/.github/workflows/nv-transformers-v100.yml @@ -1,27 +1,15 @@ name: nv-transformers-v100 -on: - push: - branches: - - 'staging**' - paths-ignore: - - 'docs/**' - pull_request: - paths-ignore: - - 'docs/**' - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu111, v100] + runs-on: [self-hosted, nvidia, cu124, v100] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv @@ -29,10 +17,19 @@ jobs: - name: Install pytorch run: | # use the same pytorch version as transformers CI - pip install torch torchvision torchaudio -f https://download.pytorch.org/whl/torch_stable.html + pip install -U --cache-dir $TORCH_CACHE torch==2.0.1+cu124 --index-url https://download.pytorch.org/whl/cu124 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + git checkout e7e9261a2 + git rev-parse --short HEAD + pip install . + - name: Install deepspeed run: | pip install .[dev,autotuning] @@ -44,19 +41,12 @@ jobs: - name: HF transformers tests run: | - if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi - git clone https://github.com/huggingface/transformers + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd transformers - # if needed switch to the last known good SHA until transformers@master is fixed - #git checkout 6268694e2 - git rev-parse --short HEAD - # scipy/sklearn required for tests, using the 'dev' extra forces torch re-install pip install .[testing] # find reqs used in ds integration tests find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec grep -v 'torch' {} \; | xargs -I {} pip install --upgrade {} - # force datasets version due to issues - pip install datasets==2.2.2 # force protobuf version due to issues pip install "protobuf<4.21.0" pip list - HF_DATASETS_CACHE=/blob/datasets_cache/ TRANSFORMERS_CACHE=/blob/transformers_cache/ WANDB_DISABLED=true TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --color=yes --durations=0 --verbose tests/deepspeed + WANDB_DISABLED=true RUN_SLOW=1 pytest $PYTEST_OPTS tests/deepspeed diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index cd79cd4aab6c..f9a2d9f3fb2d 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -1,12 +1,15 @@ name: python on: - push: - branches: - - 'staging**' + workflow_dispatch: pull_request: branches: '**' + paths-ignore: + - 'docs/**' + - 'blogs/**' + merge_group: + branches: [ master ] schedule: - cron: "0 0 * * *" @@ -15,26 +18,33 @@ concurrency: cancel-in-progress: true jobs: - version-check: + unit-tests: + name: python / install smoke (Python ${{ matrix.pyVersion }}) strategy: matrix: - pyVersion: ["3.6", "3.7", "3.8", "3.9", "3.10"] + pyVersion: ["3.10", "3.11", "3.12"] fail-fast: false - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 container: - image: deepspeed/gh-builder:py${{ matrix.pyVersion }} + image: python:${{ matrix.pyVersion }}-slim steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: Install build dependencies + run: | + apt-get update && apt-get install -y build-essential ninja-build - name: environment run: | which python python --version + - name: Install PyTorch (CPU) + run: | + pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Install deepspeed run: | - pip3 install . + pip install . - name: DS Report run: | ds_report diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000000..4bddbc26be4a --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,52 @@ +name: Build and publish DeepSpeed release + +on: + push: + tags: + - 'v*.*.*' + +jobs: + deploy: + runs-on: ubuntu-24.04 + environment: release-env + + steps: + - uses: actions/checkout@v4 + with: + ref: "master" + - id: setup-venv + uses: ./.github/workflows/setup-venv + - name: Get release version from tag + run: | + echo "RELEASE_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV + - name: Check release version + run: | + pip install packaging + python release/check_release_version.py --release_version ${{ env.RELEASE_VERSION }} + - name: Build DeepSpeed + run: | + pip install setuptools + pip install build + DS_BUILD_STRING=" " python -m build --sdist + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} + repository-url: https://upload.pypi.org/legacy/ + - name: Bump version + run: | + python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }} + - name: Create Pull Request + uses: peter-evans/create-pull-request@v6 + with: + token: ${{ secrets.GH_PAT }} + add-paths: | + version.txt + body: | + **Auto-generated PR to update version.txt after a DeepSpeed release** + Released version - ${{ env.RELEASE_VERSION }} + Author - @${{ github.actor }} + branch: AutoPR/${{ env.RELEASE_VERSION }} + assignees: ${{ github.actor }} + title: "Update version.txt after ${{ env.RELEASE_VERSION }} release" + author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com> diff --git a/.github/workflows/setup-venv/action.yml b/.github/workflows/setup-venv/action.yml index dacd50b8d471..2556d0264efd 100644 --- a/.github/workflows/setup-venv/action.yml +++ b/.github/workflows/setup-venv/action.yml @@ -6,17 +6,32 @@ runs: - id: update-env run: | sudo apt-get update - sudo apt-get install -y libaio-dev + # Temporary disable nvme UTs + # sudo apt-get install -y libaio-dev + sudo apt remove -y libaio-dev python -m pip install --user --upgrade pip python -m pip install --user --upgrade virtualenv shell: bash - id: create-venv run: | + rm -rf ./unit-test-venv python -m venv unit-test-venv source ./unit-test-venv/bin/activate python -m pip install --upgrade pip + pip install wheel # required after pip>=23.1 echo PATH=$PATH >> $GITHUB_ENV # Make it so venv is inherited for other steps shell: bash + - id: set-env-vars + run: | + echo TEST_DATA_DIR=/blob/ >> $GITHUB_ENV + echo HF_HOME=/blob/hf_home/ >> $GITHUB_ENV + echo TORCH_EXTENSIONS_DIR=./torch-extensions/ >> $GITHUB_ENV + echo TORCH_CACHE=/blob/torch_cache/ >> $GITHUB_ENV + echo HF_DATASETS_CACHE=/blob/datasets_cache/ >> $GITHUB_ENV + echo MEGATRON_CKPT_DIR=/blob/megatron_ckpt/ >> $GITHUB_ENV + echo CRITIC_CKPT_DIR=/blob/step2_opt_125m_ckpt/ >> $GITHUB_ENV + echo PYTEST_OPTS="--maxfail=100 --color=yes --durations=0 --verbose -rF" >> $GITHUB_ENV + shell: bash - id: print-env run: | which python diff --git a/.github/workflows/xpu-compile.yml b/.github/workflows/xpu-compile.yml new file mode 100644 index 000000000000..d33d22faa4a4 --- /dev/null +++ b/.github/workflows/xpu-compile.yml @@ -0,0 +1,62 @@ +name: xpu-compile + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - ".github/workflows/xpu-compile.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + compile-tests: + name: xpu-compile / compile tests + runs-on: [self-hosted, intel, xpu] + container: + image: intel/oneapi-basekit:2025.0.2-0-devel-ubuntu22.04 + ports: + - 80 + options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL + + steps: + - uses: actions/checkout@v4 + - name: Install prerequisite + run: | + apt-get update + apt-get install clinfo libaio-dev python3-pip -y + pip install torch==2.10.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu + pip install py-cpuinfo numpy + pip install .[dev,autotuning] + + - name: Check container state + run: | + ldd --version + ds_report + python3 -c "import torch; print('torch:', torch.__version__, torch)" + python3 -c "import torch; print('XPU available:', torch.xpu.is_available())" + python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)" + pip list + + - name: Compile Status + shell: bash + run: | + echo "# torch.compile graph breaks" >> $GITHUB_STEP_SUMMARY + export FI_HMEM=system + ulimit -n 1048575 + cd tests/torch_compile + export ZE_AFFINITY_MASK=0,1 + echo "## ZeRO stage 3" >> $GITHUB_STEP_SUMMARY + deepspeed test_compile.py --deepspeed_config ds_config_z3.json 2>&1 | tee log_z3.txt + # for each line start with 'dynamo_output', extract the second field and following fields and append to GITHUB_STEP_SUMMARY using awk + cat log_z3.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY + echo "## ZeRO stage 2" >> $GITHUB_STEP_SUMMARY + deepspeed test_compile.py --deepspeed_config ds_config_z2.json 2>&1 | tee log_z2.txt + cat log_z2.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/xpu-max1100.yml b/.github/workflows/xpu-max1100.yml new file mode 100644 index 000000000000..b9768caf6d26 --- /dev/null +++ b/.github/workflows/xpu-max1100.yml @@ -0,0 +1,85 @@ +name: xpu-max1100 + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - ".github/workflows/xpu-max1100.yml" + - "accelerator/xpu_accelerator.py" + - "accelerator/abstract_accelerator.py" + - "accelerator/cpu_accelerator.py" + - "accelerator/real_accelerator.py" + - "csrc/xpu/**" + - "deepspeed/runtime/engine.py" + - "deepspeed/runtime/bf16_optimizer.py" + - "deepspeed/runtime/zero/stage_1_and_2.py" + - "deepspeed/runtime/zero/stage3.py" + - "deepspeed/runtime/zero/partition_parameters.py" + - "deepspeed/runtime/zero/partitioned_param_coordinator.py" + - "deepspeed/runtime/zero/parameter_offload.py" + - "deepspeed/runtime/pipe/engine.py" + - "deepspeed/runtime/utils.py" + - "op_builder/xpu/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + + +jobs: + unit-tests: + name: xpu-max1100 / unit tests + runs-on: [self-hosted, intel, xpu] + container: + image: intel/oneapi-basekit:2025.0.2-0-devel-ubuntu22.04 + ports: + - 80 + options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL + + steps: + - uses: actions/checkout@v4 + - name: Install prerequisite + shell: bash + run: | + apt-get update + apt-get install -y python3.11 python3.11-dev python3-pip clinfo libaio-dev + pip install --upgrade pip + pip install py-cpuinfo + pip install torch==2.10.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu + pip install .[dev,autotuning] + + - name: Check container state + shell: bash + run: | + ldd --version + ds_report + python3 -c "import torch; print('torch:', torch.__version__, torch)" + python3 -c "import torch; print('XPU available:', torch.xpu.is_available())" + python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)" + pip list + + - name: Unit tests + shell: bash + run: | + cd tests/unit + export FI_PROVIDER="tcp" + export I_MPI_SHM=off + pytest --verbose accelerator/* + pytest --verbose autotuning/* + pytest --verbose model_parallelism/* + pytest --verbose monitor/* + pytest --verbose utils/* + pytest --verbose runtime/test_ds_config_model.py + pytest --verbose runtime/pipe/test_pipe_schedule.py + pytest --verbose runtime/zero/test_zero_config.py + pytest --verbose runtime/zero/test_zero_tiled.py + pytest --verbose runtime/zero/test_zeropp.py + pytest --verbose runtime/test_autocast.py + pytest --verbose runtime/test_data.py + pytest --verbose runtime/zero/test_zero_dynamic_class.py diff --git a/.gitignore b/.gitignore index ab364ad8a7e7..8c29d0c23166 100644 --- a/.gitignore +++ b/.gitignore @@ -1,31 +1,65 @@ +## Ignore Python compiled files *.pyc + +## Ignore IDE-specific files and directories +# JetBrains IDE settings .idea/ +# Visual Studio Code settings +.vscode/ +# Theia IDE settings +.theia/ + +## Ignore temporary and backup files +# General backup files *~ +# Vim swap files *.swp + +## Ignore log files *.log + +## Ignore a specific generated file deepspeed/git_version_info_installed.py + +## Ignore Python bytecode cache __pycache__ -# Build + installation data +## Build + installation data +# Build artifacts build/ +# Distribution files dist/ +# Compiled shared objects *.so +# Deepspeed package info deepspeed.egg-info/ +# Build information build.txt -# Website +## Website generated files +# Jekyll generated site docs/_site/ +# Generated documentation docs/build docs/code-docs/source/_build docs/code-docs/_build docs/code-docs/build +# SASS cache .sass-cache/ +# Jekyll cache .jekyll-cache/ .jekyll-metadata -# Testing data +## Testing data +# Saved checkpoints for testing tests/unit/saved_checkpoint/ -# Dev/IDE data -.vscode -.theia +# HIP files created during AMD compilation +*_hip.cpp +*_hip.h +*.hip +*.cuh +*hip_layers.h + +# virtual env directory for format +venv diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c0250f243178..9a7bb1c9b371 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: check-useless-excludes - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v5.0.0 hooks: - id: check-case-conflict - id: check-json @@ -23,11 +23,11 @@ repos: - id: trailing-whitespace - repo: https://github.com/google/yapf - rev: v0.32.0 + rev: v0.40.0 hooks: - id: yapf -- repo: https://gitlab.com/daverona/pre-commit-cpp +- repo: https://gitlab.com/daverona/pre-commit/cpp rev: 0.8.0 hooks: - id: clang-format # formatter of C/C++ code based on a style guide: LLVM, Google, Chromium, Mozilla, and WebKit available @@ -38,7 +38,7 @@ repos: - id: check-torchdist name: check-torchdist entry: ./scripts/check-torchdist.py - language: script + language: python exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py) # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm @@ -47,8 +47,9 @@ repos: - id: check-license name: check-license entry: ./scripts/check-license.py - language: script + language: python files: \.(py|c|cpp|cu|cc|h|hpp|cuh|hip|tr)$ + exclude: ^(deepspeed/inference/v2/kernels/ragged_ops/blocked_flash|deepspeed/inference/v2/kernels/cutlass_ops/grouped_gemm) - repo: https://github.com/codespell-project/codespell rev: v2.1.0 @@ -58,22 +59,31 @@ repos: # Do not check files that are automatically generated '--skip=docs/Gemfile.lock,tests/unit/gpt2-merges.txt,tests/unit/gpt2-vocab.json', '--ignore-regex=\\n', # Do not count the 'n' in an escaped newline as part of a word - '--ignore-words-list=unsupport', # Word used in error messages that need rewording + '--ignore-words-list=youn,unsupport,noe,cann', # Word used in error messages that need rewording --check-filenames, --check-hidden ] - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 - args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401'] + args: ['--config=.flake8'] - repo: local hooks: - id: check-torchcuda name: check-torchcuda entry: ./scripts/check-torchcuda.py - language: script - exclude: ^(.github/workflows/|scripts/check-torchcuda.py|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py) + language: python + exclude: ^(.github/workflows/|scripts/check-torchcuda.py|docs/_tutorials/accelerator-abstraction-interface.md|docs/_tutorials/deepnvme.md|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py) # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm + +- repo: local + hooks: + - id: check-extraindexurl + name: check-extraindexurl + entry: ./scripts/check-extraindexurl.py + language: python + files: \.(yml|yaml|sh|py)$ + exclude: ^(scripts/check-extraindexurl.py) diff --git a/.readthedocs.yml b/.readthedocs.yml index a2da36620152..91102a7de54b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,6 +1,9 @@ - # Required version: 2 +build: + os: "ubuntu-22.04" + tools: + python: "3.8" # Build documentation in the docs/ directory with Sphinx sphinx: @@ -13,6 +16,5 @@ formats: # Optionally set the version of Python and requirements required to build your docs python: - version: 3.7 install: - requirements: requirements/requirements-readthedocs.txt diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000000..1891c38295be --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,32 @@ + +# AGENTS.md — Workspace-level instructions for AI coding agents + +## DeepSpeed Project Rules + +### Commit & CI requirements + +- All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`. +- Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`). +- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files `. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`. +- `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead. +- New files require license header: + ``` + # SPDX-License-Identifier: Apache-2.0 + # DeepSpeed Team + ``` + +### Code change discipline + +- NEVER make cosmetic/formatting-only changes to existing code. Only add/modify lines that are functionally necessary. Minimizing diff noise is critical for code review. +- Delete dead code decisively — if code is unused at runtime (only referenced in tests), remove it along with its tests. +- Prefer consolidating tests over proliferating test files. +- Blend in: when modifying code, read the surrounding context and match the style of neighboring code (naming, spacing, patterns, idioms). +- Write beginner-friendly code: avoid deeply nested expressions or chained logic. Break complex expressions into clear, named intermediate steps. +- Comments should explain **why**, not **what**. Describe the purpose and reasoning, not the mechanics that the code already shows. +- New features must include corresponding tests and documentation updates. + +## Tool Caveats + +### Edit tool auto-formatter + +The Edit tool has a hidden auto-formatter that silently changes quotes, whitespace, blank lines, and line wrapping. For format-sensitive modifications (e.g., when exact formatting matters for pre-commit), use `bash` with `sed`, `python`, or `cat` instead. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000000..1891c38295be --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,32 @@ + +# AGENTS.md — Workspace-level instructions for AI coding agents + +## DeepSpeed Project Rules + +### Commit & CI requirements + +- All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`. +- Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`). +- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files `. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`. +- `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead. +- New files require license header: + ``` + # SPDX-License-Identifier: Apache-2.0 + # DeepSpeed Team + ``` + +### Code change discipline + +- NEVER make cosmetic/formatting-only changes to existing code. Only add/modify lines that are functionally necessary. Minimizing diff noise is critical for code review. +- Delete dead code decisively — if code is unused at runtime (only referenced in tests), remove it along with its tests. +- Prefer consolidating tests over proliferating test files. +- Blend in: when modifying code, read the surrounding context and match the style of neighboring code (naming, spacing, patterns, idioms). +- Write beginner-friendly code: avoid deeply nested expressions or chained logic. Break complex expressions into clear, named intermediate steps. +- Comments should explain **why**, not **what**. Describe the purpose and reasoning, not the mechanics that the code already shows. +- New features must include corresponding tests and documentation updates. + +## Tool Caveats + +### Edit tool auto-formatter + +The Edit tool has a hidden auto-formatter that silently changes quotes, whitespace, blank lines, and line wrapping. For format-sensitive modifications (e.g., when exact formatting matters for pre-commit), use `bash` with `sed`, `python`, or `cat` instead. diff --git a/CODEOWNERS b/CODEOWNERS index 2410b3ebc09b..b0d3b8b0d77b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,50 +7,53 @@ # top-level repo folders -/.github/ @jeffra @mrwyattii @loadams -/azure/ @jeffra @awan-10 -/benchmarks/ @jeffra @awan-10 @mrwyattii @molly-smith -/bin/ @jeffra -/csrc/ @RezaYazdaniAminabadi @awan-10 @jeffra @cmikeh2 @arashb -/deepspeed/ @jeffra -/docker/ @jeffra @awan-10 -/docs/ @jeffra @mrwyattii -/examples/ @jeffra @awan-10 @mrwyattii -/op_builder/ @jeffra @RezaYazdaniAminabadi @cmikeh2 -/release/ @jeffra @mrwyattii -/requirements/ @jeffra @mrwyattii -/scripts/ @jeffra @awan-10 -/tests/ @jeffra @mrwyattii @tjruwase +/.github/ @loadams +/azure/ @loadams +/benchmarks/ @guanhuawang @tjruwase +/bin/ @loadams +/csrc/ @tjruwase +/deepspeed/ @loadams @tjruwase +/docker/ @loadams @guanhuawang +/docs/ @loadams @tjruwase +/examples/ @jomayeri @tohtana +/op_builder/ @loadams @tjruwase @jomayeri +/release/ @loadams @jomayeri +/requirements/ @loadams +/scripts/ @loadams @tjruwase +/tests/ @tjruwase @loadams @tohtana # deepspeed -/deepspeed/autotuning/ @cli99 +/deepspeed/autotuning/ @loadams /deepspeed/checkpoint/ @tjruwase -/deepspeed/comm/ @awan-10 -/deepspeed/compression/ @yaozhewei @minjiaz @xiaoxiawu-microsoft @conglongli -/deepspeed/elasticity/ @jeffra @awan-10 -/deepspeed/launcher/ @jeffra @awan-10 -/deepspeed/module_inject/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb -/deepspeed/moe/ @awan-10 -/deepspeed/monitor/ @awan-10 @jeffra -/deepspeed/nebula/ @tjruwase @jeffra -/deepspeed/ops/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb -/deepspeed/pipe/ @ShadenSmith @duli2012 -/deepspeed/profiling/ @cli99 -/deepspeed/utils/ @jeffra @tjruwase @awan-10 +/deepspeed/comm/ @guanhuawang +/deepspeed/compression/ @tjruwase +/deepspeed/elasticity/ @tjruwase +/deepspeed/launcher/ @loadams +/deepspeed/module_inject/ @hwchen2017 @loadams +/deepspeed/moe/ @tohtana +/deepspeed/monitor/ @tjruwase +/deepspeed/nebula/ @tjruwase +/deepspeed/nvme/ @tjruwase @jomayeri +/deepspeed/ops/ @tohtana +/deepspeed/pipe/ @tohtana @loadams +/deepspeed/profiling/ @loadams +/deepspeed/sequence/ @tohtana +/deepspeed/utils/ @tjruwase @tohtana # inference -/deepspeed/inference/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb -/deepspeed/model_implementations/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb +/deepspeed/inference/ @hwchen2017 @tohtana +/deepspeed/model_implementations/@tohtana @loadams # training -/deepspeed/runtime/ @jeffra @tjruwase -/deepspeed/runtime/activation_checkpointing/ @jeffra @tjruwase -/deepspeed/runtime/checkpoint_engine/ @tjruwase @jeffra -/deepspeed/runtime/comm/ @awan-10 -/deepspeed/runtime/compression/ @awan-10 @conglongli -/deepspeed/runtime/data_pipeline/ @conglongli -/deepspeed/runtime/fp16/ @jeffra @tjruwase -/deepspeed/runtime/fp16/onebit/ @conglongli @awan-10 -/deepspeed/runtime/pipe/ @ShadenSmith @duli2012 -/deepspeed/runtime/swap_tensor/ @tjruwase @mrwyattii -/deepspeed/runtime/zero/ @jeffra @tjruwase @samyam @mrwyattii +/deepspeed/runtime/ @tjruwase @tohtana +/deepspeed/runtime/activation_checkpointing/ @tjruwase +/deepspeed/runtime/checkpoint_engine/ @tjruwase +/deepspeed/runtime/comm/ @guanhuawang +/deepspeed/runtime/compression/ @tjruwase +/deepspeed/runtime/data_pipeline/ @tjruwase +/deepspeed/runtime/domino/ @guanhuawang @hwchen2017 +/deepspeed/runtime/fp16/ @tjruwase @tohtana +/deepspeed/runtime/fp16/onebit/ @tjruwase +/deepspeed/runtime/pipe/ @loadams @tohtana +/deepspeed/runtime/swap_tensor/ @tjruwase @jomayeri +/deepspeed/runtime/zero/ @tjruwase @tohtana diff --git a/COMMITTERS.md b/COMMITTERS.md new file mode 100644 index 000000000000..b97bb9d22c4c --- /dev/null +++ b/COMMITTERS.md @@ -0,0 +1,13 @@ +# DeepSpeed TSC Committers # + +| Name | GitHub ID | Affiliation +|--- | ---- | --- | +| Olatunji Ruwase | [tjruwase](https://github.com/tjruwase) | SnowFlake | +| Logan Adams | [loadams](https://github.com/loadams) | Microsoft | +| Masahiro Tanaka | [tohtana](https://github.com/tohtana) | Anyscale | +| Jeff Rasley | [jeffra](https://github.com/jeffra) | SnowFlake | +| Minjia Zhang | [minjiazhang](https://github.com/minjiazhang) | UIUC | +| Ashwin Aji | [ashwinma](https://github.com/ashwinma) | AMD | +| Sam Foreman | [saforem2](https://github.com/saforem2) | Argonne National Laboratory | +| Zhipeng Wang | [PKUWZP](https://github.com/PKUWZP) | LinkedIn | +| Guokai Ma | [delock](https://github.com/delock) | Intel | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f6e5f39869eb..b03a498144a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,17 +13,23 @@ pre-commit install Afterwards, our suite of formatting tests run automatically before each `git commit`. You can also run these manually: ```bash -pre-commit run --all-files +pre-commit run --files $(git diff --name-only master) ``` If a formatting test fails, it will fix the modified code in place and abort the `git commit`. After looking over the changes, you can `git add ` and then repeat the previous `git commit` command. +You can also run: +``` +make format +``` +which will do the same as above, and it'll also automatically build a `venv` python environment if you +don't already have one, which will isolate the requirements of this project from requirements of other projects. ## Testing DeepSpeed tracks two types of tests: unit tests and more costly model convergence tests. The model convergence tests train -[DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/) and measure +[DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/) and measure end-to-end convergence and related metrics. Unit tests are found in `tests/unit/` and the model convergence tests are found in `tests/model/`. @@ -38,9 +44,14 @@ You can also provide the `-v` flag to `pytest` to see additional information abo tests. Note that [pytest-forked](https://github.com/pytest-dev/pytest-forked) and the `--forked` flag are required to test CUDA functionality in distributed tests. +You can also run: +``` +make test +``` + ### Model Tests To execute model tests, first [install DeepSpeed](#installation). The -[DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/) repository is cloned +[DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/) repository is cloned as part of this process. Next, execute the model test driver: ```bash cd tests/model/ @@ -48,16 +59,15 @@ pytest run_sanity_check.py ``` Note that the `--forked` flag is not necessary for the model tests. -## Contributor License Agreement -This project welcomes contributions and suggestions. Most contributions require you to -agree to a Contributor License Agreement (CLA) declaring that you have the right to, and -actually do, grant us the rights to use your contribution. For details, visit -https://cla.opensource.microsoft.com. +## Developer Certificate of Origin +This project welcomes contributions and suggestions. All contributions to deepspeedai projects +require commits to be signed off with a [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) +(DCO) declaring that you have the right to, and actually do, grant us the rights to use your contribution. + +When you submit a pull request, the DCO app will check for the presence of signed commits. +Information about how this check works is here: https://github.com/dcoapp/app?tab=readme-ov-file#how-it-works -When you submit a pull request, a CLA bot will automatically determine whether you need -to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply -follow the instructions provided by the bot. You will only need to do this once across -all repos using our CLA. +To sign commits, you will need to include `-s` when running `git commit`. For example, `git commit -s -m "Commit message"`. One note, creating PRs via the GitHub interface do not appear to include this option. If you forget this, clicking on the failing check in your PR will point you to commands you can run to rebase and sign previous commits. ## Code of Conduct This project has adopted the [Microsoft Open Source Code of @@ -85,8 +95,8 @@ Based on the issue we shall discuss the merit of the new feature and decide whet ### Step 2: implementation and verification Contributor will go ahead and implement the feature, and the DeepSpeed team will provide guidance/helps as needed. The required deliverables include: -* A PR to [microsoft/DeepSpeed](https://github.com/microsoft/DeepSpeed) including (1) the feature implementation (2) unit tests (3) documentation (4) tutorial -* A PR to [microsoft/DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) or [microsoft/Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed) including the examples of how to use the feature (this is related to the planned testing experiments in proposal) +* A PR to [deepspeedai/DeepSpeed](https://github.com/deepspeedai/DeepSpeed) including (1) the feature implementation (2) unit tests (3) documentation (4) tutorial +* A PR to [deepspeedai/DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) or [deepspeedai/Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed) including the examples of how to use the feature (this is related to the planned testing experiments in proposal) * In the implementation (code, documentation, tutorial), we require the feature author to record their GitHub username as a contact method for future questions/maintenance. After receiving the PRs, we will review them and merge them after necessary tests/fixes. diff --git a/GOVERNANCE.md b/GOVERNANCE.md new file mode 100644 index 000000000000..d488ec55114e --- /dev/null +++ b/GOVERNANCE.md @@ -0,0 +1,101 @@ + +# DeepSpeed Project Charter and Governance + +This charter sets forth the responsibilities and procedures for technical contribution to, and oversight of, the DeepSpeed open source project. All contributors (including committers, maintainers, and other technical positions) and other participants in the Project (collectively, "Collaborators") must comply with the terms of this Charter. + +## Mission and Scope of the Project + +The mission of the Project is to DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective. + +The scope of the Project includes collaborative development under the Project License (as defined herein) supporting the mission, including documentation, testing, integration, and the creation of other artifacts that aid the development, deployment, operation, or adoption of the open source project. + +## Technical Steering Committee + +1. The Technical Steering Committee (the "TSC") will be responsible for all technical oversight of the open source Project. + +2. The TSC voting members are initially the Project's Committers. At the inception of the project, the Committers of the Project will be as set forth within the "CONTRIBUTING" file within the Project's code repository. The TSC may choose an alternative approach for determining the voting members of the TSC, and any such alternative approach will be documented in the CONTRIBUTING file. Any meetings of the Technical Steering Committee are intended to be open to the public, and can be conducted electronically, via teleconference, or in person. + +3. TSC projects generally will involve Contributors and Committers. The TSC may adopt or modify roles so long as the roles are documented in the CONTRIBUTING file. Unless otherwise documented: + + - **Contributors** include anyone in the technical community that contributes code, documentation, or other technical artifacts to the Project. + - **Committers** are Contributors who have earned the ability to modify ("commit") source code, documentation, or other technical artifacts in a project's repository. + + - A Contributor may become a Committer by a majority approval of the existing Committers. A Committer may be removed by a majority approval of the other existing Committers. + +4. Participation in the Project through becoming a Contributor and Committer is open to anyone so long as they abide by the terms of this Charter. + +5. The TSC may: + - Establish workflow procedures for the submission, approval, and closure/archiving of projects. + - Set requirements for the promotion of Contributors to Committer status, as applicable. + - Amend, adjust, refine and/or eliminate the roles of Contributors and Committers, and create new roles, and publicly document any TSC roles, as it sees fit. + +6. The TSC may elect a TSC Chair, who will preside over meetings of the TSC and will serve until their resignation or replacement by the TSC. The TSC Chair, or any other TSC member so designated by the TSC, will serve as the primary communication contact between the Project and AI & Data, a directed fund of The Linux Foundation. + +7. Responsibilities: The TSC will be responsible for all aspects of oversight relating to the Project, which may include: + + - Coordinating the technical direction of the Project. + - Approving project or system proposals (including, but not limited to, incubation, deprecation, and changes to a sub-project's scope). + - Organizing sub-projects and removing sub-projects. + - Creating sub-committees or working groups to focus on cross-project technical issues and requirements. + - Appointing representatives to work with other open source or open standards communities. + - Establishing community norms, workflows, issuing releases, and security issue reporting policies. + - Approving and implementing policies and processes for contributing (to be published in the CONTRIBUTING file) and coordinating with the series manager of the Project (as provided for in the Series Agreement, the "Series Manager") to resolve matters or concerns that may arise as set forth in Section 7 of this Charter. + - Discussions, seeking consensus, and where necessary, voting on technical matters relating to the code base that affect multiple projects. + - Coordinating any marketing, events, or communications regarding the Project. + +## TSC Voting + +1. While the Project aims to operate as a consensus-based community, if any TSC decision requires a vote to move the Project forward, the voting members of the TSC will vote on a one vote per voting member basis. + +2. Quorum for TSC meetings requires at least fifty percent of all voting members of the TSC to be present. The TSC may continue to meet if quorum is not met but will be prevented from making any decisions at the meeting. + +3. Except as provided in Section 7.c. and 8.a, decisions by vote at a meeting require a majority vote of those in attendance, provided quorum is met. Decisions made by electronic vote without a meeting require a majority vote of all voting members of the TSC. + +4. In the event a vote cannot be resolved by the TSC, any voting member of the TSC may refer the matter to the Series Manager for assistance in reaching a resolution. + +## Compliance with Policies + +1. This Charter is subject to the Series Agreement for the Project and the Operating Agreement of LF Projects. Contributors will comply with the policies of LF Projects as may be adopted and amended by LF Projects, including, without limitation, the policies listed at https://lfprojects.org/policies/. + +2. The TSC may adopt a code of conduct ("CoC") for the Project, which is subject to approval by the Series Manager. In the event that a Project-specific CoC has not been approved, the LF Projects Code of Conduct listed at https://lfprojects.org/policies will apply for all Collaborators in the Project. + +3. When amending or adopting any policy applicable to the Project, LF Projects will publish such policy, as to be amended or adopted, on its website at least 30 days prior to such policy taking effect; provided, however, that in the case of any amendment of the Trademark Policy or Terms of Use of LF Projects, any such amendment is effective upon publication on LF Project's website. + +4. All Collaborators must allow open participation from any individual or organization meeting the requirements for contributing under this Charter and any policies adopted for all Collaborators by the TSC, regardless of competitive interests. Put another way, the Project community must not seek to exclude any participant based on any criteria, requirement, or reason other than those that are reasonable and applied on a non-discriminatory basis to all Collaborators in the Project community. + +5. The Project will operate in a transparent, open, collaborative, and ethical manner at all times. The output of all Project discussions, proposals, timelines, decisions, and status should be made open and easily visible to all. Any potential violations of this requirement should be reported immediately to the Series Manager. + +## Community Assets + +1. LF Projects will hold title to all trade or service marks used by the Project ("Project Trademarks"), whether based on common law or registered rights. Project Trademarks will be transferred and assigned to LF Projects to hold on behalf of the Project. Any use of any Project Trademarks by Collaborators in the Project will be in accordance with the license from LF Projects and inure to the benefit of LF Projects. + +2. The Project will, as permitted and in accordance with such license from LF Projects, develop and own all Project GitHub and social media accounts, and domain name registrations created by the Project community. + +3. Under no circumstances will LF Projects be expected or required to undertake any action on behalf of the Project that is inconsistent with the tax-exempt status or purpose, as applicable, of the Joint Development Foundation or LF Projects, LLC. + +## General Rules and Operations + +The Project will: + +1. Engage in the work of the Project in a professional manner consistent with maintaining a cohesive community, while also maintaining the goodwill and esteem of LF Projects, Joint Development Foundation, and other partner organizations in the open source community. +2. Respect the rights of all trademark owners, including any branding and trademark usage guidelines. + +## Intellectual Property Policy + +1. Collaborators acknowledge that the copyright in all new contributions will be retained by the copyright holder as independent works of authorship and that no contributor or copyright holder will be required to assign copyrights to the Project. + +2. Except as described in Section 7.c., all contributions to the Project are subject to the following: + + - All new inbound code contributions to the Project must be made using Apache License, Version 2.0 available at http://www.apache.org/licenses/LICENSE-2.0 (the "Project License"). + - All new inbound code contributions must also be accompanied by a Developer Certificate of Origin (http://developercertificate.org) sign-off in the source code system that is submitted through a TSC-approved contribution process which will bind the authorized contributor and, if not self-employed, their employer to the applicable license. + - All outbound code will be made available under the Project License. + - Documentation will be received and made available by the Project under the Creative Commons Attribution 4.0 International License (available at http://creativecommons.org/licenses/by/4.0/). + - The Project may seek to integrate and contribute back to other open source projects ("Upstream Projects"). In such cases, the Project will conform to all license requirements of the Upstream Projects, including dependencies, leveraged by the Project. Upstream Project code contributions not stored within the Project's main code repository will comply with the contribution process and license terms for the applicable Upstream Project. + +3. The TSC may approve the use of an alternative license or licenses for inbound or outbound contributions on an exception basis. To request an exception, please describe the contribution, the alternative open source license(s), and the justification for using an alternative open source license for the Project. License exceptions must be approved by a two-thirds vote of the entire TSC. + +4. Contributed files should contain license information, such as SPDX short form identifiers, indicating the open source license or licenses pertaining to the file. + +## Amendments + +1. This charter may be amended by a two-thirds vote of the entire TSC and is subject to approval by LF Projects. diff --git a/MANIFEST.in b/MANIFEST.in index 2fec750c6644..8d84aee0faf4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,9 @@ include *.txt README.md +include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so +include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so recursive-include requirements *.txt -recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json -recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc +recursive-include deepspeed *.cpp *.h *.hpp *.cu *.hip *.tr *.cuh *.cc *.json +recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc recursive-include op_builder *.py recursive-include benchmarks *.py recursive-include accelerator *.py diff --git a/Makefile b/Makefile new file mode 100644 index 000000000000..8756897ebedf --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +# usage: make help + +.PHONY: help test format +.DEFAULT_GOAL := help + +help: ## this help + @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[0-9a-zA-Z_-]+:.*?##/ { printf " \033[36m%-22s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) + echo $(MAKEFILE_LIST) + +test: ## run tests + pytest --forked tests/unit/ + +format: ## fix formatting + @if [ ! -d "venv" ]; then \ + python -m venv venv; \ + . venv/bin/activate; \ + pip install pre-commit -U; \ + pre-commit clean; \ + pre-commit uninstall; \ + pre-commit install; \ + deactivate; \ + fi + . venv/bin/activate && pre-commit run --files $$(git diff --name-only master) && deactivate diff --git a/README.md b/README.md index b708e6258abb..ddd9def3ef8f 100755 --- a/README.md +++ b/README.md @@ -1,7 +1,12 @@ -[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE) +[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/deepspeedai/DeepSpeed/blob/master/LICENSE) [![PyPI version](https://badge.fury.io/py/deepspeed.svg)](https://pypi.org/project/deepspeed/) -[![Downloads](https://pepy.tech/badge/deepspeed)](https://pepy.tech/project/deepspeed) +[![Downloads](https://static.pepy.tech/badge/deepspeed)](https://pepy.tech/project/deepspeed) [![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status) +[![OpenSSF Best Practices](https://www.bestpractices.dev/projects/9530/badge)](https://www.bestpractices.dev/projects/9530) +[![Twitter](https://img.shields.io/twitter/follow/DeepSpeedAI)](https://twitter.com/intent/follow?screen_name=DeepSpeedAI) +[![Japanese Twitter](https://img.shields.io/badge/%E6%97%A5%E6%9C%AC%E8%AA%9ETwitter-%40DeepSpeedAI_JP-blue)](https://twitter.com/DeepSpeedAI_JP) +[![Chinese Zhihu](https://img.shields.io/badge/%E7%9F%A5%E4%B9%8E-%E5%BE%AE%E8%BD%AFDeepSpeed-blue)](https://www.zhihu.com/people/deepspeed) +[![Slack](https://img.shields.io/badge/Slack-4A154B?style=for-the-badge&logo=slack&logoColor=white)](https://join.slack.com/t/deepspeedworkspace/shared_invite/zt-3a8pjd8dd-PCj2hMvR4Y2syPwVnjEoww)
@@ -9,71 +14,65 @@
-## Latest News - DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). -* ***[2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)*** [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀 -* [2023/03] [Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE](https://www.deepspeed.ai/2023/03/30/multi-modal.html) -* [2023/02] [Automatic Tensor Parallelism: Enables tensor parallelism by default without an injection policy](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) -* [2022/12] [DeepSpeed Data Efficiency: A composable library that makes better use of data, increases training efficiency, and improves model quality](https://www.deepspeed.ai/2022/12/11/data-efficiency.html) -* [2022/11] [Stable Diffusion Image Generation under 1 second w. DeepSpeed MII](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/benchmark/txt2img) -* [2022/10] [DeepSpeed-MII: instant speedup on 24,000+ open-source DL models with up to 40x cheaper inference](https://www.deepspeed.ai/2022/10/10/mii.html) -* [2022/09] [ZeRO-Inference: Democratizing massive model inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html) -* [2022/07] [Azure and DeepSpeed empower easy-to-use and high-performance model training](https://azure.microsoft.com/en-us/blog/azure-empowers-easytouse-highperformance-and-hyperscale-model-training-using-deepspeed/) +## Office Hours ---- +DeepSpeed hosts regular office hours on the last Tuesday of each month at 12:00 America/New_York to discuss development plans, features, etc. This meeting is public for anyone to join and ask questions. +The meeting is hosted on Zoom and can be joined [here](https://zoom-lfx.platform.linuxfoundation.org/meeting/93902569995?password=7d9c4fc9-3efa-4715-88f0-df8a6deb008b). -# Extreme Speed and Scale for DL Training and Inference +## Latest News -***[DeepSpeed](https://www.deepspeed.ai/) enables world's most powerful language models like [MT-530B](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/) and [BLOOM](https://huggingface.co/blog/bloom-megatron-deepspeed)***. It is an easy-to-use deep learning optimization software suite that powers unprecedented scale and speed for both training and inference. With DeepSpeed you can: +* [2026/05] [Using Muon Optimizer with DeepSpeed](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/muon-optimizer/README.md) -* Train/Inference dense or sparse models with billions or trillions of parameters -* Achieve excellent system throughput and efficiently scale to thousands of GPUs -* Train/Inference on resource constrained GPU systems -* Achieve unprecedented low latency and high throughput for inference -* Achieve extreme compression for an unparalleled inference latency and model size reduction with low costs +* [2026/05] [System DMA (SDMA) for ZeRO-3: offload collectives off compute units on AMD GPUs for better overlap](https://github.com/deepspeedai/DeepSpeed/blob/master/examples/sdma_allgather/README.md) ---- +* [2026/03] DeepSpeed Team gave a tutorial at ASPLOS 2026 titled ["Building Efficient Large-Scale Model Systems with DeepSpeed: From Open-Source Foundations to Emerging Research" ](https://supercomputing-system-ai-lab.github.io/events/asplos2026-llm-tutorial/index.html) -# DeepSpeed's three innovation pillars +* [2026/03] [Our SuperOffload work received an Honorable Mention for the ASPLOS 2026 Best Paper Award](https://dl.acm.org/doi/10.1145/3760250.3762217) - +* [2025/12] [DeepSpeed Core API updates: PyTorch-style backward and low-precision master states](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/core_api_update/README.md) +* [2025/11] [DeepSpeed ZeRO++ powers large-scale distillation training of LLMs for Recommendation Systems at LinkedIn](https://aclanthology.org/2025.emnlp-industry.119/) -## DeepSpeed-Training +* [2025/10] We hosted the [Ray x DeepSpeed Meetup](https://luma.com/3wctqteh) at Anyscale. We shared our most recent work on SuperOffload, ZenFlow, Muon Optimizer Support, Arctic Long Sequence Training and DeepCompile. Please find the meetup slides [here](https://docs.google.com/presentation/d/1eM3mY6oW9GYkRy1Xz0iOnbbEr5T1t0JJXOM5BKtR-Ks/edit?slide=id.g38615d6b4c2_0_87#slide=id.g38615d6b4c2_0_87). -DeepSpeed offers a confluence of system innovations, that has made large scale DL training effective, and efficient, greatly improved ease of use, and redefined the DL training landscape in terms of scale that is possible. These innovations such as ZeRO, 3D-Parallelism, DeepSpeed-MoE, ZeRO-Infinity, etc. fall under the training pillar. Learn more: [DeepSpeed-Training](https://www.deepspeed.ai/training/) +* [2025/10] [SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips](https://pytorch.org/blog/superoffload-unleashing-the-power-of-large-scale-llm-training-on-superchips/) -## DeepSpeed-Inference +* [2025/10] [Study of ZenFlow and ZeRO offload performance with DeepSpeed CPU core binding](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/zenflow-corebinding/README.md) -DeepSpeed brings together innovations in parallelism technology such as tensor, pipeline, expert and ZeRO-parallelism, and combines them with high performance custom inference kernels, communication optimizations and heterogeneous memory technologies to enable inference at an unprecedented scale, while achieving unparalleled latency, throughput and cost reduction. This systematic composition of system technologies for inference falls under the inference pillar. Learn more: [DeepSpeed-Inference](https://www.deepspeed.ai/inference) +* [2025/08] [ZenFlow: Stall-Free Offloading Engine for LLM Training](https://pytorch.org/blog/zenflow-stall-free-offloading-engine-for-llm-training/) +* [2025/06] [Arctic Long Sequence Training (ALST) with DeepSpeed: Scalable And Efficient Training For Multi-Million Token Sequences](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/) -## DeepSpeed-Compression +* [2025/06] [DeepNVMe: Affordable I/O scaling for Deep Learning Applications](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepnvme/06-2025/README.md) -To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression) ---- + +
+ + More news + +
- [Model Implementations for Inference (MII)](https://github.com/microsoft/deepspeed-mii) is an open-sourced repository for making low-latency and high-throughput inference accessible to all data scientists by alleviating the need to apply complex system optimization techniques themselves. Out-of-box, MII offers support for thousands of widely used DL models, optimized using DeepSpeed-Inference, that can be deployed with a few lines of code, while achieving significant latency reduction compared to their vanilla open-sourced versions. +--- -## DeepSpeed on Azure +# Extreme Speed and Scale for DL Training - DeepSpeed users are diverse and have access to different environments. We recommend to try DeepSpeed on Azure as it is the simplest and easiest method. The recommended method to try DeepSpeed on Azure is through AzureML [recipes](https://github.com/Azure/azureml-examples/tree/main/v1/python-sdk/workflows/train/deepspeed). The job submission and data preparation scripts have been made available [here](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azureml). For more details on how to use DeepSpeed on Azure, please follow the [Azure tutorial](https://www.deepspeed.ai/tutorials/azure/). +***[DeepSpeed](https://www.deepspeed.ai/) enabled the world's most powerful language models (at the time of this writing) such as [MT-530B](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/) and [BLOOM](https://huggingface.co/blog/bloom-megatron-deepspeed)***. DeepSpeed offers a confluence of [system innovations](https://www.deepspeed.ai/training/), that has made large scale DL training effective, and efficient, greatly improved ease of use, and redefined the DL training landscape in terms of scale that is possible. These innovations include ZeRO, ZeRO-Infinity, 3D-Parallelism, Ulysses Sequence Parallelism, DeepSpeed-MoE, etc. --- # DeepSpeed Adoption -DeepSpeed is an important part of Microsoft’s new +DeepSpeed was an important part of Microsoft’s [AI at Scale](https://www.microsoft.com/en-us/research/project/ai-at-scale/) initiative to enable next-generation AI capabilities at scale, where you can find more information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale). @@ -84,6 +83,7 @@ DeepSpeed has been used to train many different large-scale models, below is a l * [Jurassic-1 (178B)](https://uploads-ssl.webflow.com/60fd4503684b466578c0d307/61138924626a6981ee09caf6_jurassic_tech_paper.pdf) * [BLOOM (176B)](https://huggingface.co/blog/bloom-megatron-deepspeed) * [GLM (130B)](https://github.com/THUDM/GLM-130B) + * [xTrimoPGLM (100B)](https://www.biorxiv.org/content/10.1101/2023.07.05.547496v2) * [YaLM (100B)](https://github.com/yandex/YaLM-100B) * [GPT-NeoX (20B)](https://github.com/EleutherAI/gpt-neox) * [AlexaTM (20B)](https://www.amazon.science/blog/20b-parameter-alexa-model-sets-new-marks-in-few-shot-learning) @@ -94,11 +94,12 @@ DeepSpeed has been integrated with several different popular open-source DL fram | | Documentation | | ---------------------------------------------------------------------------------------------- | -------------------------------------------- | - | [Transformers with DeepSpeed](https://huggingface.co/docs/transformers/main/main_classes/deepspeed) | + | [Transformers with DeepSpeed](https://huggingface.co/docs/transformers/deepspeed) | | | [Accelerate with DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) | -| | [Lightning with DeepSpeed](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html) | -| | [MosaicML with DeepSpeed](https://docs.mosaicml.com/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration) | +| | [Lightning with DeepSpeed](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed) | +| | [MosaicML with DeepSpeed](https://docs.mosaicml.com/projects/composer/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration) | | | [Determined with DeepSpeed](https://docs.determined.ai/latest/training/apis-howto/deepspeed/overview.html) | +| | [MMEngine with DeepSpeed](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#deepspeed) | --- @@ -106,12 +107,14 @@ DeepSpeed has been integrated with several different popular open-source DL fram | Description | Status | | ----------- | ------ | -| NVIDIA | [![nv-torch19-p40](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch19-p40.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch19-p40.yml) [![nv-torch19-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch19-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch19-v100.yml) [![nv-torch-latest-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml) [![nv-inference](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-inference.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-inference.yml) [![nv-nightly](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-nightly.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-nightly.yml) | -| AMD | [![amd-mi100](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi100.yml) [![amd-mi200](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml) | -| CPU | [![nv-torch-latest-cpu](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-cpu.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-cpu.yml) | -| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) | -| Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml)[![nv-megatron](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-megatron.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-megatron.yml)[![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) | -| Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) | +| NVIDIA | [![nv-pre-compile-ops](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-pre-compile-ops.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-pre-compile-ops.yml) [![aws-torch-latest](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-torch-latest.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-torch-latest.yml) | +| AMD | [![amd-mi200](https://github.com/deepspeedai/DeepSpeed/actions/workflows/amd-mi200.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/amd-mi200.yml) | +| CPU | [![torch-latest-cpu](https://github.com/deepspeedai/DeepSpeed/actions/workflows/cpu-torch-latest.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/cpu-torch-latest.yml) | +| Intel Gaudi | [![hpu-gaudi2](https://github.com/deepspeedai/DeepSpeed/actions/workflows/hpu-gaudi2.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/hpu-gaudi2.yml) | +| Intel XPU | [![xpu-max1100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/xpu-max1100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/xpu-max1100.yml) | +| Integrations | [![aws-accelerate](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-accelerate.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-accelerate.yml) | +| Misc | [![Formatting](https://github.com/deepspeedai/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/deepspeedai/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/deepspeedai/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/python.yml) | +| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml/badge.svg?branch=main)](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml) | # Installation @@ -125,12 +128,23 @@ dynamically link them at runtime. ## Requirements * [PyTorch](https://pytorch.org/) must be installed _before_ installing DeepSpeed. -* For full feature support we recommend a version of PyTorch that is >= 1.9 and ideally the latest PyTorch stable release. +* For full feature support we recommend a version of PyTorch that is >= 2.0 and ideally the latest PyTorch stable release. * A CUDA or ROCm compiler such as [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#introduction) or [hipcc](https://github.com/ROCm-Developer-Tools/HIPCC) used to compile C++/CUDA/HIP extensions. * Specific GPUs we develop and test against are listed below, this doesn't mean your GPU will not work if it doesn't fall into this category it's just DeepSpeed is most well tested on the following: * NVIDIA: Pascal, Volta, Ampere, and Hopper architectures * AMD: MI100 and MI200 +## Contributed HW support +* DeepSpeed now support various HW accelerators. + +| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated | +|-------------|-------------------------------------|------------------| --------------------- |--------------------| +| Huawei | Huawei Ascend NPU | npu | Yes | No | +| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes | +| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes | +| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes | +| Tecorigin | Scalable Data Analytics Accelerator | sdaa | Yes | No | + ## PyPI We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases. @@ -150,15 +164,12 @@ of JIT compiling) or install pre-compiled ops via PyPI please see our [advanced installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/). ## Windows -Windows support is partially supported with DeepSpeed. On Windows you can build wheel with following steps, currently only inference mode is supported. -1. Install pytorch, such as pytorch 1.8 + cuda 11.1 -2. Install visual cpp build tools, such as VS2019 C++ x64/x86 build tools -3. Launch cmd console with Administrator privilege for creating required symlink folders -4. Run `python setup.py bdist_wheel` to build wheel in `dist` folder - -# Features +Many DeepSpeed features are supported on Windows for both training and inference. You can read more about this in the original blog post [here](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/windows/08-2024/README.md). Among features that are currently not supported are async io (AIO) and GDS (which does not support Windows). +1. Install PyTorch, such as pytorch 2.3+cu121. +2. Install Visual C++ build tools, such as VS2022 C++ x64/x86 build tools. +3. Launch Cmd console with Administrator permissions for creating required symlink folders and ensure MSVC tools are added to your PATH or launch the Developer Command Prompt for Visual Studio 2022 with administrator permissions. +4. Run `build_win.bat` to build wheel in `dist` folder. -Please checkout [DeepSpeed-Training](https://www.deepspeed.ai/training), [DeepSpeed-Inference](https://www.deepspeed.ai/inference) and [DeepSpeed-Compression](https://www.deepspeed.ai/compression) pages for full set of features offered along each of these three pillars. # Further Reading @@ -174,21 +185,29 @@ All DeepSpeed documentation, tutorials, and blogs can be found on our website: [ | [Blogs](https://www.deepspeed.ai/posts/) | Blogs | +# CI funding + +This being an open source project we rely on others to provide us resources for CI hardware. At this moment Modal is kindly supporting our GPU CI runs by funding the hardware for us. Modal is an AI infrastructure platform for inference, fine-tuning, batch jobs and more. Get started with $30/mo in free credits today at https://modal.com. We have been getting an amazing support from Modal's team and will surely recommend them to your business. + # Contributing DeepSpeed welcomes your contributions! Please see our [contributing](CONTRIBUTING.md) guide for more details on formatting, testing, -etc. +etc.
+Thanks so much to all of our amazing contributors! + + + + -## Contributor License Agreement +## Developer Certificate of Origin This project welcomes contributions and suggestions. Most contributions require you to -agree to a Contributor License Agreement (CLA) declaring that you have the right to, and -actually do, grant us the rights to use your contribution. For details, visit -https://cla.opensource.microsoft.com. +agree to a Developer Certificate of Origin [DCO](https://wiki.linuxfoundation.org/dco) +stating that they agree to the terms published at https://developercertificate.org for +that *particular* contribution. -When you submit a pull request, a CLA bot will automatically determine whether you need -to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply -follow the instructions provided by the bot. You will only need to do this once across -all repos using our CLA. +DCOs are per-commit, so each commit needs to be signed off. These can be signed in +the commit by adding the `-s` flag. DCO enforcement can also be signed off in the PR +itself by clicking on the DCO enforcement check. ## Code of Conduct This project has adopted the [Microsoft Open Source Code of @@ -200,24 +219,41 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information 1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: memory optimizations toward training trillion parameter models. [arXiv:1910.02054](https://arxiv.org/abs/1910.02054) and [In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis (SC '20)](https://dl.acm.org/doi/10.5555/3433701.3433727). 2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703). 3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html). -4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840) and [USENIX ATC 2021](https://www.usenix.org/conference/atc21/presentation/ren-jie). +4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840) and [USENIX ATC 2021](https://www.usenix.org/conference/atc21/presentation/ren-jie). [[paper]](https://arxiv.org/abs/2101.06840) [[slides]](https://www.usenix.org/system/files/atc21_slides_ren-jie.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/) 5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888) and [ICML 2021](http://proceedings.mlr.press/v139/tang21a.html). -6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857) and [SC 2021](https://dl.acm.org/doi/abs/10.1145/3458817.3476205). +6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857) and [SC 2021](https://dl.acm.org/doi/abs/10.1145/3458817.3476205). [[paper]](https://arxiv.org/abs/2104.07857) [[slides]](docs/assets/files/SC21-ZeRO-Infinity.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) 7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069) and [HiPC 2022](https://hipc.org/advance-program/). 8. Conglong Li, Minjia Zhang, Yuxiong He. (2021) The Stability-Efficiency Dilemma: Investigating Sequence Length Warmup for Training GPT Models. [arXiv:2108.06084](https://arxiv.org/abs/2108.06084) and [NeurIPS 2022](https://openreview.net/forum?id=JpZ5du_Kdh). 9. Yucheng Lu, Conglong Li, Minjia Zhang, Christopher De Sa, Yuxiong He. (2022) Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam. [arXiv:2202.06009](https://arxiv.org/abs/2202.06009). -10. Samyam Rajbhandari, Conglong Li, Zhewei Yao, Minjia Zhang, Reza Yazdani Aminabadi, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He. (2022) DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale [arXiv:2201.05596](https://arxiv.org/abs/2201.05596) and [ICML 2022](https://proceedings.mlr.press/v162/rajbhandari22a.html). +10. Samyam Rajbhandari, Conglong Li, Zhewei Yao, Minjia Zhang, Reza Yazdani Aminabadi, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He. (2022) DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale [arXiv:2201.05596](https://arxiv.org/abs/2201.05596) and [ICML 2022](https://proceedings.mlr.press/v162/rajbhandari22a.html). [[pdf]](https://arxiv.org/abs/2201.05596) [[slides]](docs/assets/files/ICML-5mins.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/) 11. Shaden Smith, Mostofa Patwary, Brandon Norick, Patrick LeGresley, Samyam Rajbhandari, Jared Casper, Zhun Liu, Shrimai Prabhumoye, George Zerveas, Vijay Korthikanti, Elton Zhang, Rewon Child, Reza Yazdani Aminabadi, Julie Bernauer, Xia Song, Mohammad Shoeybi, Yuxiong He, Michael Houston, Saurabh Tiwary, Bryan Catanzaro. (2022) Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model [arXiv:2201.11990](https://arxiv.org/abs/2201.11990). 12. Xiaoxia Wu, Zhewei Yao, Minjia Zhang, Conglong Li, Yuxiong He. (2022) Extreme Compression for Pre-trained Transformers Made Simple and Efficient. [arXiv:2206.01859](https://arxiv.org/abs/2206.01859) and [NeurIPS 2022](https://openreview.net/forum?id=xNeAhc2CNAl). -13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861) and [NeurIPS 2022](https://openreview.net/forum?id=f-fVCElZ-G1). -14. Reza Yazdani Aminabadi, Samyam Rajbhandari, Minjia Zhang, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Jeff Rasley, Shaden Smith, Olatunji Ruwase, Yuxiong He. (2022) DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale. [arXiv:2207.00032](https://arxiv.org/abs/2207.00032) and [SC 2022](https://dl.acm.org/doi/abs/10.5555/3571885.3571946). +13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861) and [NeurIPS 2022](https://openreview.net/forum?id=f-fVCElZ-G1) [[slides]](docs/assets/files/zeroquant_series.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-compression-a-composable-library-for-extreme-compression-and-zero-cost-quantization/) +14. Reza Yazdani Aminabadi, Samyam Rajbhandari, Minjia Zhang, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Jeff Rasley, Shaden Smith, Olatunji Ruwase, Yuxiong He. (2022) DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale. [arXiv:2207.00032](https://arxiv.org/abs/2207.00032) and [SC 2022](https://dl.acm.org/doi/abs/10.5555/3571885.3571946). [[paper]](https://arxiv.org/abs/2207.00032) [[slides]](docs/assets/files/sc22-ds-inference.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/) 15. Zhewei Yao, Xiaoxia Wu, Conglong Li, Connor Holmes, Minjia Zhang, Cheng Li, Yuxiong He. (2022) Random-LTD: Random and Layerwise Token Dropping Brings Efficient Training for Large-scale Transformers. [arXiv:2211.11586](https://arxiv.org/abs/2211.11586). -16. Conglong Li, Zhewei Yao, Xiaoxia Wu, Minjia Zhang, Yuxiong He. (2022) DeepSpeed Data Efficiency: Improving Deep Learning Model Quality and Training Efficiency via Efficient Data Sampling and Routing. [arXiv:2212.03597](https://arxiv.org/abs/2212.03597). -17. Xiaoxia Wu, Cheng Li, Reza Yazdani Aminabadi, Zhewei Yao, Yuxiong He. (2023) Understanding INT4 Quantization for Transformer Models: Latency Speedup, Composability, and Failure Cases. [arXiv:2301.12017](https://arxiv.org/abs/2301.12017). +16. Conglong Li, Zhewei Yao, Xiaoxia Wu, Minjia Zhang, Yuxiong He. (2022) DeepSpeed Data Efficiency: Improving Deep Learning Model Quality and Training Efficiency via Efficient Data Sampling and Routing. [arXiv:2212.03597](https://arxiv.org/abs/2212.03597) [ENLSP2023 Workshop at NeurIPS2023](https://neurips2023-enlsp.github.io/) +17. Xiaoxia Wu, Cheng Li, Reza Yazdani Aminabadi, Zhewei Yao, Yuxiong He. (2023) Understanding INT4 Quantization for Transformer Models: Latency Speedup, Composability, and Failure Cases. [arXiv:2301.12017](https://arxiv.org/abs/2301.12017) and [ICML2023](https://icml.cc/Conferences/2023). 18. Syed Zawad, Cheng Li, Zhewei Yao, Elton Zheng, Yuxiong He, Feng Yan. (2023) DySR: Adaptive Super-Resolution via Algorithm and System Co-design. [ICLR:2023](https://openreview.net/forum?id=Pgtn4l6eKjv). -19. Sheng Shen, Zhewei Yao, Chunyuan Li, Trevor Darrell, Kurt Keutzer, Yuxiong He. (2023) Scaling Vision-Language Models with Sparse Mixture of Experts. [arXiv:2303.07226](https://arxiv.org/abs/2303.07226). +19. Sheng Shen, Zhewei Yao, Chunyuan Li, Trevor Darrell, Kurt Keutzer, Yuxiong He. (2023) Scaling Vision-Language Models with Sparse Mixture of Experts. [arXiv:2303.07226](https://arxiv.org/abs/2303.07226) and [Finding at EMNLP2023](https://2023.emnlp.org/). 20. Quentin Anthony, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He, Aamir Shafi, Mustafa Abduljabbar, Hari Subramoni, Dhabaleswar Panda. (2023) MCR-DL: Mix-and-Match Communication Runtime for Deep Learning [arXiv:2303.08374](https://arxiv.org/abs/2303.08374) and will appear at IPDPS 2023. - +21. Siddharth Singh, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He, Abhinav Bhatele. (2023) A Hybrid Tensor-Expert-Data Parallelism Approach to Optimize Mixture-of-Experts Training [arXiv:2303.06318](https://arxiv.org/abs/2303.06318) and [ICS 2023](https://dl.acm.org/doi/10.1145/3577193.3593704). +22. Guanhua Wang, Heyang Qin, Sam Ade Jacobs, Xiaoxia Wu, Connor Holmes, Zhewei Yao, Samyam Rajbhandari, Olatunji Ruwase, Feng Yan, Lei Yang, Yuxiong He. (2023) ZeRO++: Extremely Efficient Collective Communication for Giant Model Training [arXiv:2306.10209](https://arxiv.org/abs/2306.10209) and [ML for Sys Workshop at NeurIPS2023](http://mlforsystems.org/) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/) +23. Zhewei Yao, Xiaoxia Wu, Cheng Li, Stephen Youn, Yuxiong He. (2023) ZeroQuant-V2: Exploring Post-training Quantization in LLMs from Comprehensive Study to Low Rank Compensation [arXiv:2303.08302](https://arxiv.org/abs/2303.08302) and [ENLSP2023 Workshop at NeurIPS2023](https://neurips2023-enlsp.github.io/) [[slides]](docs/assets/files/zeroquant_series.pdf) +24. Pareesa Ameneh Golnari, Zhewei Yao, Yuxiong He. (2023) Selective Guidance: Are All the Denoising Steps of Guided Diffusion Important? [arXiv:2305.09847](https://arxiv.org/abs/2305.09847) +25. Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, Zhongzhu Zhou, Michael Wyatt, Molly Smith, Lev Kurilenko, Heyang Qin, Masahiro Tanaka, Shuai Che, Shuaiwen Leon Song, Yuxiong He. (2023) DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales [arXiv:2308.01320](https://arxiv.org/abs/2308.01320). +26. Xiaoxia Wu, Zhewei Yao, Yuxiong He. (2023) ZeroQuant-FP: A Leap Forward in LLMs Post-Training W4A8 Quantization Using Floating-Point Formats [arXiv:2307.09782](https://arxiv.org/abs/2307.09782) and [ENLSP2023 Workshop at NeurIPS2023](https://neurips2023-enlsp.github.io/) [[slides]](docs/assets/files/zeroquant_series.pdf) +27. Zhewei Yao, Xiaoxia Wu, Conglong Li, Minjia Zhang, Heyang Qin, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He. (2023) DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention [arXiv:2309.14327](https://arxiv.org/pdf/2309.14327.pdf) +28. Shuaiwen Leon Song, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, Xiaoxia Wu, Jeff Rasley, Ammar Ahmad Awan, Connor Holmes, Martin Cai, Adam Ghanem, Zhongzhu Zhou, Yuxiong He, et al. (2023) DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies [arXiv:2310.04610](https://arxiv.org/abs/2310.04610) [[blog]](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/) +29. Zhewei Yao, Reza Yazdani Aminabadi, Stephen Youn, Xiaoxia Wu, Elton Zheng, Yuxiong He. (2023) ZeroQuant-HERO: Hardware-Enhanced Robust Optimized Post-Training Quantization Framework for W8A8 Transformers [arXiv:2310.17723](https://arxiv.org/abs/2310.17723) + +30. Xiaoxia Wu, Haojun Xia, Stephen Youn, Zhen Zheng, Shiyang Chen, Arash Bakhtiari, Michael Wyatt, Reza Yazdani Aminabadi, Yuxiong He, Olatunji Ruwase, Leon Song, Zhewei Yao (2023) ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks [arXiv:2312.08583](https://arxiv.org/abs/2312.08583) + +31. Haojun Xia, Zhen Zheng, Xiaoxia Wu, Shiyang Chen, Zhewei Yao, Stephen Youn, Arash Bakhtiari, Michael Wyatt, Donglin Zhuang, Zhongzhu Zhou, Olatunji Ruwase, Yuxiong He, Shuaiwen Leon Song. (2024) FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design [arXiv:2401.14112](https://arxiv.org/abs/2401.14112) +32. Sam Ade Jacobs, Masahiro Tanaka, Chengming Zhang, Minjia Zhang, Reza Yazdani Aminadabi, Shuaiwen Leon Song, Samyam Rajbhandari, Yuxiong He. (2024) [System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://dl.acm.org/doi/10.1145/3662158.3662806) +33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820) +34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996) +35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242) +36. Xinyu Lian, Masahiro Tanaka, Olatunji Ruwase, Minjia Zhang. (2026) SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips [arxiv](https://arxiv.org/abs/2509.21271), [ASPLOS 2026](https://www.asplos-conference.org/asplos2026) # Videos 1. DeepSpeed KDD 2020 Tutorial @@ -231,7 +267,8 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information * Registration is free and all videos are available on-demand. * [ZeRO & Fastest BERT: Increasing the scale and speed of deep learning training in DeepSpeed](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html). 3. [DeepSpeed on AzureML](https://youtu.be/yBVXR8G8Bg8) -4. Community Tutorials +4. [Large Model Training and Inference with DeepSpeed // Samyam Rajbhandari // LLMs in Prod Conference](https://www.youtube.com/watch?v=cntxC3g22oU) [[slides]](docs/assets/files/presentation-mlops.pdf) +5. Community Tutorials * [DeepSpeed: All the tricks to scale to gigantic models (Mark Saroufim)](https://www.youtube.com/watch?v=pDGI668pNg0) * [Turing-NLG, DeepSpeed and the ZeRO optimizer (Yannic Kilcher)](https://www.youtube.com/watch?v=tC01FRB0M7w) * [Ultimate Guide To Scaling ML Models (The AI Epiphany)](https://www.youtube.com/watch?v=hc0u4avAkuM) diff --git a/SECURITY.md b/SECURITY.md index e0dfff56a956..74d0d866a145 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,41 +1,33 @@ - +# Security Policy -## Security +## Reporting security issues -Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). +Please report security issues privately using [the vulnerability submission form](https://github.com/deepspeedai/deepspeed/security/advisories/new). -If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. +## Issue triage -## Reporting Security Issues +Reports will then be triaged by the maintainers. -**Please do not report security vulnerabilities through public GitHub issues.** +## Threat model -Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). +Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. -If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). +## Issue severity -You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). +We will determine the risk of each issue, taking into account our experience dealing with past issues, versions affected, common defaults, and use cases. We use the following severity categories: -Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: +### CRITICAL Severity - * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) - * Full paths of source file(s) related to the manifestation of the issue - * The location of the affected source code (tag/branch/commit or direct URL) - * Any special configuration required to reproduce the issue - * Step-by-step instructions to reproduce the issue - * Proof-of-concept or exploit code (if possible) - * Impact of the issue, including how an attacker might exploit the issue +Vulnerabilities that allow remote attackers to execute arbitrary code, take full control of the system, or significantly compromise confidentiality, integrity, or availability without any interaction or privileges needed, examples include remote code execution via network, deserialization issues that allow exploit chains. Generally those issues which are rated as CVSS ≥ 9.0. -This information will help us triage your report more quickly. +### HIGH Severity -If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. +Serious security flaws that allow elevated impact—like RCE in specific, limited contexts or significant data loss—but require advanced conditions or some trust, examples include RCE in advanced deployment modes (e.g. multi-node), or high impact issues where some sort of privileged network access is required. These issues typically have CVSS scores between 7.0 and 8.9 -## Preferred Languages +### MODERATE Severity -We prefer all communications to be in English. +Vulnerabilities that cause denial of service or partial disruption, but do not allow arbitrary code execution or data breach and have limited impact. These issues have a CVSS rating between 4.0 and 6.9 -## Policy +### LOW Severity -Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). - - +Minor issues such as informational disclosures, logging errors, non-exploitable flaws, or weaknesses that require local or high-privilege access and offer negligible impact. Examples include side channel attacks or hash collisions. These issues often have CVSS scores less than 4.0 diff --git a/accelerator/__init__.py b/accelerator/__init__.py index c0d9a7bf36ef..efed1ef84aca 100644 --- a/accelerator/__init__.py +++ b/accelerator/__init__.py @@ -4,4 +4,4 @@ # DeepSpeed Team from .abstract_accelerator import DeepSpeedAccelerator -from .real_accelerator import get_accelerator, set_accelerator +from .real_accelerator import get_accelerator, set_accelerator, is_current_accelerator_supported diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index dc26edf26faf..c764760b962c 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -8,10 +8,28 @@ class DeepSpeedAccelerator(ABC): + supports_nvtx_domain = False def __init__(self): self._name = None self._communication_backend_name = None + self._compile_backend = None + + @abc.abstractmethod + def is_synchronized_device(self): + ... + + @abc.abstractmethod + def use_host_timers(self): + ... + + @abc.abstractmethod + def resolves_data_dependency(self): + ... + + @abc.abstractmethod + def handles_memory_backpressure(self): + ... # Device APIs @abc.abstractmethod @@ -64,7 +82,7 @@ def manual_seed_all(self, seed): ... @abc.abstractmethod - def initial_seed(self, seed): + def initial_seed(self): ... @abc.abstractmethod @@ -143,6 +161,10 @@ def max_memory_reserved(self, device_index=None): def total_memory(self, device_index=None): ... + @abc.abstractmethod + def available_memory(self, device_index=None): + ... + # Data types @abc.abstractmethod def is_bf16_supported(self): @@ -152,21 +174,21 @@ def is_bf16_supported(self): def is_fp16_supported(self): ... - # Misc @abc.abstractmethod - def amp(self): + def supported_dtypes(self): ... + # Misc @abc.abstractmethod def is_available(self): ... @abc.abstractmethod - def range_push(self, msg): + def range_push(self, msg, domain=None, category=None): ... @abc.abstractmethod - def range_pop(self): + def range_pop(self, domain=None): ... @abc.abstractmethod @@ -177,6 +199,23 @@ def lazy_call(self, callback): def communication_backend_name(self): ... + @abc.abstractmethod + def is_triton_supported(self): + ... + + # Graph operations + @abc.abstractmethod + def create_graph(self): + ... + + @abc.abstractmethod + def capture_to_graph(self, graph, pool=None, stream=None): + ... + + @abc.abstractmethod + def replay_graph(self, graph): + ... + # Tensor operations @property @abc.abstractmethod @@ -214,7 +253,11 @@ def LongTensor(self): ... @abc.abstractmethod - def pin_memory(self, tensor): + def pin_memory(self, tensor, align_bytes=1): + ... + + @abc.abstractmethod + def is_pinned(self, tensor): ... @abc.abstractmethod @@ -238,3 +281,23 @@ def get_op_builder(self, class_name): @abc.abstractmethod def build_extension(self): ... + + @abc.abstractmethod + def export_envs(self): + ... + + @abc.abstractmethod + def visible_devices_envs(self): + ... + + @abc.abstractmethod + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + ... + + @abc.abstractmethod + def get_compile_backend(self): + ... + + @abc.abstractmethod + def set_compile_backend(self, backend): + ... diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py new file mode 100644 index 000000000000..4ff0f4dd7527 --- /dev/null +++ b/accelerator/cpu_accelerator.py @@ -0,0 +1,358 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .abstract_accelerator import DeepSpeedAccelerator + +# During setup stage torch may not be installed, pass on no torch will +# allow op builder related API to be executed. +try: + import torch +except ImportError as e: + pass + +try: + import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore + oneccl_imported_p = True +except ImportError as e: + oneccl_imported_p = False + +import os + + +# accelerator for Intel CPU +class CPU_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = 'cpu' + self._compile_backend = "inductor" + if oneccl_imported_p: + self._communication_backend_name = 'ccl' + else: + # fallback to gloo if oneccl_binding_for_pytorch is not installed + self._communication_backend_name = 'gloo' + try: + import psutil + mem = psutil.Process().memory_info().rss + self.max_mem = mem + except ImportError as e: + self.max_mem = 0 + + def is_synchronized_device(self): + return True + + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + + # Device APIs + def device_name(self, device_index=None): + return 'cpu' + + def device(self, device_index=None): + return None + + def set_device(self, device_index): + return + + def current_device(self): + return os.environ.get('LOCAL_RANK', 0) + + def current_device_name(self): + return 'cpu' + + def device_count(self): + device_count = int(os.environ.get('LOCAL_SIZE', 0)) + if device_count > 0: + return device_count + else: + from deepspeed.utils.numa import get_numa_cores + # Count NUMA node for number of cpu accelerators. On machine with HBM + # In flat mode, HBM is in separate NUMA node with no cores on this node. + # Ignore these NUMA nodes with no cores. + numa_core_lists = get_numa_cores() + if not numa_core_lists: + return 1 + numa_count = 0 + prev_core_list = [] + for core_list in numa_core_lists: + if len(core_list) > 0 and core_list != prev_core_list: + numa_count += 1 + prev_core_list = core_list + return numa_count + + def synchronize(self, device_index=None): + return + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + if device_index is None: + return torch.set_rng_state(new_state) + return torch.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + return torch.get_rng_state() + + def manual_seed(self, seed): + return torch.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.manual_seed(seed) + + def initial_seed(self): + return torch.initial_seed() + + def default_generator(self, device_index): + return torch.default_generator + + # Streams/Events + @property + def Stream(self): + return None + + def stream(self, stream): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def current_stream(self, device_index=None): + return None + + def default_stream(self, device_index=None): + return None + + @property + def Event(self): + return None + + # Memory management + def empty_cache(self): + return + + def get_rss(self): + import psutil + mem = psutil.Process().memory_info().rss + if mem > self.max_mem: + self.max_mem = mem + return mem + + def reset_rss(self): + import psutil + mem = psutil.Process().memory_info().rss + self.max_mem = mem + return mem + + def memory_allocated(self, device_index=None): + return self.get_rss() + + def max_memory_allocated(self, device_index=None): + self.get_rss() + return self.max_mem + + def reset_max_memory_allocated(self, device_index=None): + self.reset_rss() + return + + def memory_cached(self, device_index=None): + return self.get_rss() + + def max_memory_cached(self, device_index=None): + self.get_rss() + return self.max_mem + + def reset_max_memory_cached(self, device_index=None): + self.reset_rss() + return + + def memory_stats(self, device_index=None): + mem = self.get_rss() + mem_stat = {} + mem_stat['allocated_bytes.all.current'] = mem + mem_stat['allocated_bytes.all.peak'] = self.max_mem + return mem_stat + + def reset_peak_memory_stats(self, device_index=None): + self.reset_rss() + return + + def memory_reserved(self, device_index=None): + return self.get_rss() + + def max_memory_reserved(self, device_index=None): + self.get_rss() + return self.max_mem + + def total_memory(self, device_index=None): + import psutil + return psutil.virtual_memory().total + + def available_memory(self, device_index=None): + import psutil + return psutil.virtual_memory().available + + # Misc + def is_available(self): + return True + + def range_push(self, msg, domain=None, category=None): + # TODO itt is currently not supported yet + # return torch.profiler.itt.range_push(msg) + return + + def range_pop(self, domain=None): + # TODO itt is currently not supported yet + # return torch.profiler.itt.range_pop() + return + + def lazy_call(self, callback): + return callback() + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Data types + def is_bf16_supported(self): + return True + + def is_fp16_supported(self): + try: + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + return True + except Exception: + return False + + def supported_dtypes(self): + supported_dtypes = [torch.float, torch.bfloat16] + if self.is_fp16_supported(): + supported_dtypes.append(torch.float16) + return supported_dtypes + + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + + # Tensor operations + @property + def BFloat16Tensor(self): + return torch.BFloat16Tensor + + @property + def ByteTensor(self): + return torch.ByteTensor + + @property + def DoubleTensor(self): + return torch.DoubleTensor + + @property + def FloatTensor(self): + return torch.FloatTensor + + @property + def HalfTensor(self): + return torch.HalfTensor + + @property + def IntTensor(self): + return torch.IntTensor + + @property + def LongTensor(self): + return torch.LongTensor + + def pin_memory(self, tensor, align_bytes=1): + return tensor + + def is_pinned(self, tensor): + return tensor.is_pinned() + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + return "op_builder.cpu" + except ImportError: + return "deepspeed.ops.op_builder.cpu" + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('cpu'): + return True + else: + return False + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, op_name): + builder_class = self.get_op_builder(op_name) + if builder_class is not None: + return builder_class() + return None + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + except ImportError: + from deepspeed.ops.op_builder.cpu import AsyncIOBuilder, CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + + if class_name == "CCLCommBuilder": + return CCLCommBuilder + elif class_name == "ShareMemCommBuilder": + return ShareMemCommBuilder + elif class_name == "FusedAdamBuilder": + return FusedAdamBuilder + elif class_name == "CPUAdamBuilder": + return CPUAdamBuilder + elif class_name == "AsyncIOBuilder": + return AsyncIOBuilder + else: + # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests + return NotImplementedBuilder + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return [] + + # TODO: cpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES + def visible_devices_envs(self): + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 36341a3c19b3..24766f8c0a81 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -3,9 +3,11 @@ # DeepSpeed Team +import functools import os import pkgutil import importlib +import sys from .abstract_accelerator import DeepSpeedAccelerator # During setup stage torch may not be installed, pass on no torch will @@ -15,37 +17,61 @@ except ImportError: pass +try: + import nvtx +except ImportError: + nvtx = None + +# Delay import pynvml to avoid import error when CUDA is not available +pynvml = None + class CUDA_Accelerator(DeepSpeedAccelerator): + supports_nvtx_domain = True def __init__(self): self._name = 'cuda' - self._communication_backend_name = 'nccl' - - # begin initialize for create_op_builder() - # put all valid class name <--> class type mapping into class_dict - op_builder_dir = self.op_builder_dir() - op_builder_module = importlib.import_module(op_builder_dir) - for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]): - # avoid self references - if module_name != 'all_ops' and module_name != 'builder': - module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) - for member_name in module.__dir__(): - if member_name.endswith( - 'Builder' - ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes - if not member_name in self.class_dict: - self.class_dict[member_name] = getattr(module, member_name) - # end initialize for create_op_builder() + self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo' + self._compile_backend = "inductor" + self._nvtx_domains = {} + if pynvml is None: + self._init_pynvml() + + def _init_pynvml(self): + global pynvml + try: + import pynvml + except ImportError: + return + try: + pynvml.nvmlInit() + except pynvml.NVMLError: + pynvml = None + return + + def is_synchronized_device(self): + return False + + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() # Device APIs def device_name(self, device_index=None): - if device_index == None: + if device_index is None: return 'cuda' return 'cuda:{}'.format(device_index) + def communication_backend_version(self): + return torch.cuda.nccl.version() + def device(self, device_index=None): - return torch.cuda.device(device_index) + return torch.device('cuda', device_index) def set_device(self, device_index): torch.cuda.set_device(device_index) @@ -84,8 +110,8 @@ def manual_seed(self, seed): def manual_seed_all(self, seed): return torch.cuda.manual_seed_all(seed) - def initial_seed(self, seed): - return torch.cuda.initial_seed(seed) + def initial_seed(self): + return torch.cuda.initial_seed() def default_generator(self, device_index): return torch.cuda.default_generators[device_index] @@ -149,33 +175,85 @@ def max_memory_reserved(self, device_index=None): def total_memory(self, device_index=None): return torch.cuda.get_device_properties(device_index).total_memory + def _get_nvml_gpu_id(self, torch_gpu_id): + """ + credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020 + + Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES. + + If the latter isn't set return the same id + """ + # if CUDA_VISIBLE_DEVICES is used automagically remap the id since pynvml ignores this env var + if "CUDA_VISIBLE_DEVICES" in os.environ: + ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","))) + return ids[torch_gpu_id] # remap + else: + return torch_gpu_id + + def available_memory(self, device_index=None): + if pynvml: + if device_index is None: + device_index = self.current_device() + handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index)) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return info.free + else: + return self.total_memory(device_index) - self.memory_allocated(device_index) + # Data types def is_bf16_supported(self): + if not torch.cuda.is_available(): + return True return torch.cuda.is_bf16_supported() def is_fp16_supported(self): + if not torch.cuda.is_available(): + return True + # See https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix + # FP16 on compute capability 6.x is deprecated + allow_deprecated_fp16 = os.environ.get('DS_ALLOW_DEPRECATED_FP16', '0') == '1' major, _ = torch.cuda.get_device_capability() if major >= 7: return True + elif major == 6 and allow_deprecated_fp16: + return True else: return False - # Misc - def amp(self): - if hasattr(torch.cuda, 'amp'): - return torch.cuda.amp - return None + def supported_dtypes(self): + supported_dtypes = [torch.float] + if self.is_fp16_supported(): + supported_dtypes.append(torch.half) + if self.is_bf16_supported(): + supported_dtypes.append(torch.bfloat16) + return supported_dtypes + # Misc def is_available(self): return torch.cuda.is_available() - def range_push(self, msg): - if hasattr(torch.cuda.nvtx, 'range_push'): - return torch.cuda.nvtx.range_push(msg) - - def range_pop(self): - if hasattr(torch.cuda.nvtx, 'range_pop'): - return torch.cuda.nvtx.range_pop() + def _get_nvtx_domain(self, domain): + if nvtx is None or domain is None: + return None + if domain not in self._nvtx_domains: + self._nvtx_domains[domain] = nvtx.get_domain(domain) + return self._nvtx_domains[domain] + + def range_push(self, msg, domain=None, category=None): + nvtx_domain = self._get_nvtx_domain(domain) + if nvtx_domain is not None: + return nvtx_domain.push_range(message=msg, category=category) + torch_nvtx = getattr(torch.cuda, 'nvtx', None) + if torch_nvtx is not None and hasattr(torch_nvtx, 'range_push'): + return torch_nvtx.range_push(msg) + + def range_pop(self, domain=None): + nvtx_domain = self._get_nvtx_domain(domain) + if nvtx_domain is not None: + return nvtx_domain.pop_range() + torch_nvtx = getattr(torch.cuda, 'nvtx', None) + if torch_nvtx is not None and hasattr(torch_nvtx, 'range_pop'): + return torch_nvtx.range_pop() def lazy_call(self, callback): return torch.cuda._lazy_call(callback) @@ -183,39 +261,62 @@ def lazy_call(self, callback): def communication_backend_name(self): return self._communication_backend_name + def is_triton_supported(self): + if not self.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + if major >= 8: + return True + else: + return False + + # Graph operations + def create_graph(self): + return torch.cuda.CUDAGraph() + + def capture_to_graph(self, graph, pool=None, stream=None): + return torch.cuda.graph(graph, pool, stream) + + def replay_graph(self, graph): + graph.replay() + return + # Tensor operations @property def BFloat16Tensor(self): - return torch.cuda.BFloat16Tensor + return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda') @property def ByteTensor(self): - return torch.cuda.ByteTensor + return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda') @property def DoubleTensor(self): - return torch.cuda.DoubleTensor + return functools.partial(torch.tensor, dtype=torch.double, device='cuda') @property def FloatTensor(self): - return torch.cuda.FloatTensor + return functools.partial(torch.tensor, dtype=torch.float, device='cuda') @property def HalfTensor(self): - return torch.cuda.HalfTensor + return functools.partial(torch.tensor, dtype=torch.half, device='cuda') @property def IntTensor(self): - return torch.cuda.IntTensor + return functools.partial(torch.tensor, dtype=torch.int, device='cuda') @property def LongTensor(self): - return torch.cuda.LongTensor + return functools.partial(torch.tensor, dtype=torch.long, device='cuda') - def pin_memory(self, tensor): + def pin_memory(self, tensor, align_bytes=1): return tensor.pin_memory() + def is_pinned(self, tensor): + return tensor.is_pinned() + def on_accelerator(self, tensor): device_str = str(tensor.device) if device_str.startswith('cuda:'): @@ -227,7 +328,7 @@ def op_builder_dir(self): try: # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed # if successful this also means we're doing a local install and not JIT compile path - from op_builder import __deepspeed__ # noqa: F401 + from op_builder import __deepspeed__ # noqa: F401 # type: ignore return "op_builder" except ImportError: return "deepspeed.ops.op_builder" @@ -235,10 +336,35 @@ def op_builder_dir(self): # dict that holds class name <--> class type mapping i.e. # 'AsyncIOBuilder': # this dict will be filled at init stage - class_dict = {} + class_dict = None + + def _lazy_init_class_dict(self): + if self.class_dict is not None: + return + else: + self.class_dict = {} + # begin initialize for create_op_builder() + # put all valid class name <--> class type mapping into class_dict + op_builder_dir = self.op_builder_dir() + op_builder_module = importlib.import_module(op_builder_dir) + op_builder_absolute_path = os.path.dirname(op_builder_module.__file__) + for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]): + # avoid self references, + # skip sub_directories which contains ops for other backend(cpu, npu, etc.). + if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir( + os.path.join(op_builder_absolute_path, module_name)): + module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) + for member_name in module.__dir__(): + if member_name.endswith( + 'Builder' + ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes + if not member_name in self.class_dict: + self.class_dict[member_name] = getattr(module, member_name) + # end initialize for create_op_builder() # create an instance of op builder and return, name specified by class_name def create_op_builder(self, class_name): + self._lazy_init_class_dict() if class_name in self.class_dict: return self.class_dict[class_name]() else: @@ -246,6 +372,7 @@ def create_op_builder(self, class_name): # return an op builder class, name specified by class_name def get_op_builder(self, class_name): + self._lazy_init_class_dict() if class_name in self.class_dict: return self.class_dict[class_name] else: @@ -254,3 +381,24 @@ def get_op_builder(self, class_name): def build_extension(self): from torch.utils.cpp_extension import BuildExtension return BuildExtension + + def export_envs(self): + return ['NCCL'] + + def visible_devices_envs(self): + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py new file mode 100644 index 000000000000..e809ae3e1dab --- /dev/null +++ b/accelerator/hpu_accelerator.py @@ -0,0 +1,328 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import functools +import os +import pkgutil +import importlib +import torch + +from .abstract_accelerator import DeepSpeedAccelerator + + +class HPU_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = 'hpu' + self._communication_backend_name = 'hccl' + self._compile_backend = "hpu_backend" + self.apply_hpu_workarounds() + try: + import habana_frameworks.torch.hpu as hpu + self.hpu = hpu + torch.use_deterministic_algorithms(True) + # TODO: remove this WA when memory mapping break is resolved. + torch.utils.deterministic.fill_uninitialized_memory = False + except ImportError as e: + raise ValueError( + "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") + + self.fp16_supported = None + + def apply_hpu_workarounds(self): + + def update_wa_env_var(key, value): + if key not in os.environ.keys(): + os.environ[key] = value + + update_wa_env_var("PT_HPU_LAZY_ACC_PAR_MODE", "0") + update_wa_env_var("PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES", "0") + + # Device APIs + def is_synchronized_device(self): + return False + + def use_host_timers(self): + return False + + def resolves_data_dependency(self): + return True + + def handles_memory_backpressure(self): + return True + + def device_name(self, device_index=None): + # ignoring device_index. + return 'hpu' + + def device(self, device_index=None): + return torch.device(self.device_name(device_index)) + + def set_device(self, device_index): + self.hpu.set_device(device_index) + + def current_device(self): + return (self.hpu.current_device()) + + def current_device_name(self): + return 'hpu:{}'.format(self.current_device()) + + def device_count(self): + return self.hpu.device_count() + + def synchronize(self, device_index=None): + return self.hpu.synchronize() + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + self.hpu.random.set_rng_state(new_state) + + def get_rng_state(self, device_index=None): + return self.hpu.random.get_rng_state() + + def manual_seed(self, seed): + return self.hpu.random.manual_seed(seed) + + def manual_seed_all(self, seed): + self.hpu.random.manual_seed_all(seed) + + def initial_seed(self): + return self.hpu.random.initial_seed() + + def default_generator(self, device_index): + return self.hpu.random.default_generators[device_index] + + # Streams/Events + @property + def Stream(self): + return self.hpu.Stream + + def stream(self, stream): + return self.hpu.stream(stream) + + def current_stream(self, device_index=None): + return self.hpu.current_stream() + + def default_stream(self, device_index=None): + return self.hpu.default_stream() + + @property + def Event(self): + import habana_frameworks.torch.core as htcore + return htcore.hpu.Event + + # Memory management + def empty_cache(self): + return + + def memory_allocated(self, device_index=None): + return self.hpu.memory_allocated() + + def max_memory_allocated(self, device_index=None): + return self.hpu.max_memory_allocated() + + def reset_max_memory_allocated(self, device_index=None): + return self.hpu.reset_max_memory_allocated() + + def memory_cached(self, device_index=None): + return self.hpu.memory_cached(device_index) + + def max_memory_cached(self, device_index=None): + return self.hpu.max_memory_cached(device_index) + + def reset_max_memory_cached(self, device_index=None): + return None + + def memory_stats(self, device_index=None): + return self.hpu.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + self.hpu.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + return self.hpu.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + return self.hpu.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return self.memory_stats(device_index)['Limit'] + + def available_memory(self, device_index=None): + return self.total_memory(device_index) - self.memory_allocated(device_index) + + # Data types + def is_bf16_supported(self): + return True + + def is_fp16_supported(self): + if self.fp16_supported is None: + import habana_frameworks.torch.utils.experimental as htexp + self.fp16_supported = htexp._is_fp16_supported() + return self.fp16_supported + + def supported_dtypes(self): + supported_dtypes = [torch.float, torch.bfloat16] + if self.is_fp16_supported(): + supported_dtypes.append(torch.half) + return supported_dtypes + + # Misc + def is_available(self): + return self.hpu.is_available() + + def range_push(self, msg, domain=None, category=None): + return + + def range_pop(self, domain=None): + return + + def lazy_call(self, callback): + callback() + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Graph operations + def create_graph(self): + return self.hpu.HPUGraph() + + def capture_to_graph(self, graph, pool=None, stream=None): + return self.hpu.graph(graph, stream=stream) + + def replay_graph(self, graph): + graph.replay() + return + + # Tensor operations + @property + def BFloat16Tensor(self): + return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu') + + @property + def ByteTensor(self): + return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu') + + @property + def DoubleTensor(self): + return functools.partial(torch.tensor, dtype=torch.double, device='hpu') + + @property + def FloatTensor(self): + return functools.partial(torch.tensor, dtype=torch.float, device='hpu') + + @property + def HalfTensor(self): + return functools.partial(torch.tensor, dtype=torch.half, device='hpu') + + @property + def IntTensor(self): + return functools.partial(torch.tensor, dtype=torch.int, device='hpu') + + @property + def LongTensor(self): + return functools.partial(torch.tensor, dtype=torch.long, device='hpu') + + def pin_memory(self, tensor, align_bytes=1): + return tensor.pin_memory(self.device()) + + def is_pinned(self, tensor): + return tensor.is_pinned() + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('hpu:'): + return True + else: + return False + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + return "op_builder.hpu" + except ImportError: + return "deepspeed.ops.op_builder.hpu" + + # dict that holds class name <--> class type mapping i.e. + # 'AsyncIOBuilder': + # this dict will be filled at init stage + class_dict = None + + def _lazy_init_class_dict(self): + if self.class_dict is not None: + return + else: + self.class_dict = {} + # begin initialize for create_op_builder() + # put all valid class name <--> class type mapping into class_dict + op_builder_dir = self.op_builder_dir() + op_builder_module = importlib.import_module(op_builder_dir) + op_builder_absolute_path = os.path.dirname(op_builder_module.__file__) + for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]): + # avoid self references, + # skip sub_directories which contains ops for other backend(cpu, npu, etc.). + if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir( + os.path.join(op_builder_absolute_path, module_name)): + module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) + for member_name in module.__dir__(): + if member_name.endswith( + 'Builder' + ) and member_name != "OpBuilder" and member_name != "CPUOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes + if not member_name in self.class_dict: + self.class_dict[member_name] = getattr(module, member_name) + # end initialize for create_op_builder() + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name]() + else: + return None + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name] + else: + return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return [] + + def visible_devices_envs(self): + # Current way deepspeed set this env var is not applicable with all HPU instances + # User has to follow instructions in: + # https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html + # keeping CUDA_VISIBLE_DEVICES + return ['CUDA_VISIBLE_DEVICES'] #['HABANA_VISIBLE_MODULES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/mlu_accelerator.py b/accelerator/mlu_accelerator.py new file mode 100644 index 000000000000..55ae8dee3d0d --- /dev/null +++ b/accelerator/mlu_accelerator.py @@ -0,0 +1,295 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import importlib +import inspect +import functools + +from .abstract_accelerator import DeepSpeedAccelerator +import torch +# During setup stage torch may not be installed, pass on no torch will +# allow op builder related API to be executed. + + +class MLU_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = 'mlu' + self._communication_backend_name = 'cncl' + self._compile_backend = "inductor" + self.class_dict = None + + def is_synchronized_device(self): + return False + + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + + # Device APIs + def device_name(self, device_index=None): + if device_index == None: + return 'mlu' + return 'mlu:{}'.format(device_index) + + def device(self, device_index=None): + return torch.mlu.device(device_index) + + def set_device(self, device_index): + torch.mlu.set_device(device_index) + + def current_device(self): + return torch.mlu.current_device() + + def current_device_name(self): + return 'mlu:{}'.format(torch.mlu.current_device()) + + def device_count(self): + return torch.mlu.device_count() + + def synchronize(self, device_index=None): + return torch.mlu.synchronize(device_index) + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + if device_index is None: + return torch.mlu.set_rng_state(new_state) + + return torch.mlu.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + if device_index is None: + return torch.mlu.get_rng_state() + + return torch.mlu.get_rng_state(device_index) + + def manual_seed(self, seed): + return torch.mlu.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.mlu.manual_seed_all(seed) + + def initial_seed(self, seed): + return torch.mlu.initial_seed(seed) + + def default_generator(self, device_index): + return torch.mlu.default_generators[device_index] + + # Streams/Events + @property + def Stream(self): + return torch.mlu.Stream + + def stream(self, stream): + return torch.mlu.stream(stream) + + def current_stream(self, device_index=None): + return torch.mlu.current_stream(device_index) + + def default_stream(self, device_index=None): + return torch.mlu.default_stream(device_index) + + @property + def Event(self): + return torch.mlu.Event + + # Memory management + def empty_cache(self): + return torch.mlu.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.mlu.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + return torch.mlu.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + return torch.mlu.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + return torch.mlu.memory_cached(device_index) + + def max_memory_cached(self, device_index=None): + return torch.mlu.max_memory_cached(device_index) + + def reset_max_memory_cached(self, device_index=None): + return torch.mlu.reset_max_memory_cached(device_index) + + def memory_stats(self, device_index=None): + if hasattr(torch.mlu, 'memory_stats'): + return torch.mlu.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + if hasattr(torch.mlu, 'reset_peak_memory_stats'): + return torch.mlu.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + if hasattr(torch.mlu, 'memory_reserved'): + return torch.mlu.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + if hasattr(torch.mlu, 'max_memory_reserved'): + return torch.mlu.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return torch.mlu.get_device_properties(device_index).total_memory + + def available_memory(self, device_index=None): + return self.total_memory(device_index) - self.memory_allocated(device_index) + + # Data types + def is_bf16_supported(self): + return torch.mlu.is_bf16_supported() + + def is_fp16_supported(self): + return True + + def supported_dtypes(self): + supported_dtypes = [torch.float] + if self.is_fp16_supported(): + supported_dtypes.append(torch.half) + if self.is_bf16_supported(): + supported_dtypes.append(torch.bfloat16) + return supported_dtypes + + # Misc + def is_available(self): + return torch.mlu.is_available() + + def range_push(self, msg, domain=None, category=None): + if hasattr(torch.mlu.cnpx, 'range_push'): + return torch.mlu.cnpx.range_push(msg) + + def range_pop(self, domain=None): + if hasattr(torch.mlu.cnpx, 'range_pop'): + return torch.mlu.cnpx.range_pop() + + def lazy_call(self, callback): + return torch.mlu._lazy_call(callback) + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return True + + # Graph operations + def create_graph(self): + torch.mlu.MLUGraph() + + def capture_to_graph(self, graph, pool=None, stream=None): + return torch.mlu.graph(graph, pool, stream) + + def replay_graph(self, graph): + graph.replay() + return + + # Tensor operations + + @property + def BFloat16Tensor(self): + return functools.partial(torch.tensor, dtype=torch.bfloat16, device='mlu') + + @property + def ByteTensor(self): + return functools.partial(torch.tensor, dtype=torch.uint8, device='mlu') + + @property + def DoubleTensor(self): + return functools.partial(torch.tensor, dtype=torch.double, device='mlu') + + @property + def FloatTensor(self): + return functools.partial(torch.tensor, dtype=torch.float, device='mlu') + + @property + def HalfTensor(self): + return functools.partial(torch.tensor, dtype=torch.half, device='mlu') + + @property + def IntTensor(self): + return functools.partial(torch.tensor, dtype=torch.int, device='mlu') + + @property + def LongTensor(self): + return functools.partial(torch.tensor, dtype=torch.long, device='mlu') + + def pin_memory(self, tensor): + return tensor.pin_memory() + + def is_pinned(self, tensor): + return tensor.is_pinned() + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('mlu:'): + return True + else: + return False + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + return "op_builder.mlu" + except ImportError: + return "deepspeed.ops.op_builder.mlu" + + def _lazy_init_class_dict(self): + if self.class_dict: + return + + op_builder_module = importlib.import_module(self.op_builder_dir()) + + # get op builder class from op_builder/mlu/__init__.py + self.class_dict = {} + for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass): + self.class_dict[class_name] = class_obj + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, class_name): + builder_class = self.get_op_builder(class_name) + return builder_class() + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name] + else: + return self.class_dict['NotImplementedBuilder'] + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return ['NEUWARE_HOME', 'CNCL', 'LD_LIBRARY', 'PATH'] + + def visible_devices_envs(self): + return ['MLU_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }") diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py new file mode 100644 index 000000000000..f6600beb779c --- /dev/null +++ b/accelerator/mps_accelerator.py @@ -0,0 +1,279 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .abstract_accelerator import DeepSpeedAccelerator + +# During setup stage torch may not be installed, pass on no torch will +# allow op builder related API to be executed. +try: + import torch.mps +except ImportError: + pass + + +class MPS_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = "mps" + self._communication_backend_name = None + self._compile_backend = "inductor" + + def is_synchronized_device(self): + return False + + def use_host_timers(self): + # Event timers are not supported on MPS + return True + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + + # Device APIs + def device_name(self, device_index=None): + if device_index is None: + return "mps" + return "mps:{}".format(device_index) + + def device(self, device_index): + return torch.device("mps", index=0) + + def set_device(self, device_index): + return + + def current_device(self): + return torch.device("mps", index=0) + + def current_device_name(self): + return "mps:0" + + def device_count(self): + return 1 + + def synchronize(self, device_index=None): + return torch.mps.synchronize() + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + return torch.mps.set_rng_state(new_state) + + def get_rng_state(self, device_index=None): + return torch.mps.get_rng_state() + + def manual_seed(self, seed): + return torch.mps.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.mps.manual_seed(seed) + + def seed(self): + return torch.mps.seed() + + def initial_seed(self): + return + + def default_generator(self, device_index): + return + + # Streams/Events + @property + def Stream(self): + return None + + def stream(self, stream): + return None + + def current_stream(self, device_index=None): + return None + + def default_stream(self, device_index=None): + return None + + @property + def Event(self): + return None + + # Memory management + def empty_cache(self): + return torch.mps.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.mps.current_allocated_memory() + + def max_memory_allocated(self, device_index=None): + return torch.mps.driver_allocated_memory() + + def set_per_process_memory_fraction(self, fraction): + return torch.mps.set_per_process_memory_fraction(fraction) + + def reset_max_memory_allocated(self, device_index=None): + return + + def memory_cached(self, device_index=None): + return + + def max_memory_cached(self, device_index=None): + return + + def reset_max_memory_cached(self, device_index=None): + return + + def memory_stats(self, device_index=None): + return + + def reset_peak_memory_stats(self, device_index=None): + return + + def memory_reserved(self, device_index=None): + return + + def max_memory_reserved(self, device_index=None): + return + + def total_memory(self, device_index=None): + return + + def available_memory(self, device_index=None): + return + + # Data types + def is_bf16_supported(self): + return False + + def is_fp16_supported(self): + return False + + def supported_dtypes(self): + return [torch.float] + + # Misc + def is_available(self): + return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + def range_push(self, msg, domain=None, category=None): + return + + def range_pop(self, domain=None): + return + + def lazy_call(self, callback): + return + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + + # Tensor operations + @property + def BFloat16Tensor(self): + return + + @property + def ByteTensor(self): + return + + @property + def DoubleTensor(self): + return + + @property + def FloatTensor(self): + return + + @property + def HalfTensor(self): + return + + @property + def IntTensor(self): + return + + @property + def LongTensor(self): + return + + def pin_memory(self, tensor, align_bytes=1): + return tensor.pin_memory() + + def is_pinned(self, tensor): + return tensor.is_pinned() + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith("mps"): + return True + else: + return False + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + + return "op_builder" + except ImportError: + return "deepspeed.ops.op_builder" + + # create an instance of op builder, specified by class_name + def create_op_builder(self, op_name): + builder_class = self.get_op_builder(op_name) + if builder_class is not None: + return builder_class() + return None + + # return an op builder class, specified by class_name + def get_op_builder(self, class_name): + from deepspeed.ops.op_builder.cpu import NotImplementedBuilder + + return NotImplementedBuilder + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + + return BuildExtension + + def export_envs(self): + return [] + + # TODO: mpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES + def visible_devices_envs(self): + # TODO: could not find visible devices env for mps + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py new file mode 100644 index 000000000000..91515e8112ee --- /dev/null +++ b/accelerator/npu_accelerator.py @@ -0,0 +1,294 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import importlib +import inspect + +from .abstract_accelerator import DeepSpeedAccelerator +# During setup stage torch may not be installed, pass on no torch will +# allow op builder related API to be executed. +try: + import torch.npu +except ImportError: + pass + + +class NPU_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + super().__init__() + self._name = 'npu' + self._communication_backend_name = 'hccl' + self._compile_backend = "inductor" + # dict that holds class name <--> class type mapping i.e. + # 'AsyncIOBuilder': + # this dict will be filled at init stage + self.class_dict = None + + def is_synchronized_device(self): + return False + + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + + # Device APIs + def device_name(self, device_index=None): + if device_index is None: + return 'npu' + return 'npu:{}'.format(device_index) + + def device(self, device_index=None): + return torch.device('npu', device_index) + + def set_device(self, device_index): + torch.npu.set_device(device_index) + + def current_device(self): + return torch.npu.current_device() + + def current_device_name(self): + return 'npu:{}'.format(torch.npu.current_device()) + + def device_count(self): + return torch.npu.device_count() + + def synchronize(self, device_index=None): + return torch.npu.synchronize(device_index) + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + if device_index is None: + return torch.npu.set_rng_state(new_state) + + return torch.npu.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + if device_index is None: + return torch.npu.get_rng_state() + + return torch.npu.get_rng_state(device_index) + + def manual_seed(self, seed): + return torch.npu.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.npu.manual_seed_all(seed) + + def initial_seed(self): + return torch.npu.initial_seed() + + def default_generator(self, device_index): + return torch.npu.default_generators[device_index] + + # Streams/Events + @property + def Stream(self): + return torch.npu.Stream + + def stream(self, stream): + return torch.npu.stream(stream) + + def current_stream(self, device_index=None): + return torch.npu.current_stream(device_index) + + def default_stream(self, device_index=None): + return torch.npu.default_stream(device_index) + + @property + def Event(self): + return torch.npu.Event + + # Memory management + def empty_cache(self): + return torch.npu.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.npu.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + return torch.npu.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + return torch.npu.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + return torch.npu.memory_cached(device_index) + + def max_memory_cached(self, device_index=None): + return torch.npu.max_memory_cached(device_index) + + def reset_max_memory_cached(self, device_index=None): + return torch.npu.reset_max_memory_cached(device_index) + + def memory_stats(self, device_index=None): + if hasattr(torch.npu, 'memory_stats'): + return torch.npu.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + if hasattr(torch.npu, 'reset_peak_memory_stats'): + return torch.npu.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + if hasattr(torch.npu, 'memory_reserved'): + return torch.npu.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + if hasattr(torch.npu, 'max_memory_reserved'): + return torch.npu.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return torch.npu.get_device_properties(device_index).total_memory + + def available_memory(self, device_index=None): + return self.total_memory(device_index) - self.memory_allocated(device_index) + + # Data types + def is_bf16_supported(self): + return torch.npu.is_bf16_supported() + + def is_fp16_supported(self): + return True + + def supported_dtypes(self): + return [torch.float, torch.half, torch.bfloat16] + + # Misc + def is_available(self): + return torch.npu.is_available() + + def range_push(self, msg, domain=None, category=None): + return + + def range_pop(self, domain=None): + return + + def lazy_call(self, callback): + return torch.npu._lazy_call(callback) + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + + # Tensor operations + + @property + def BFloat16Tensor(self): + return torch.npu.BFloat16Tensor + + @property + def ByteTensor(self): + return torch.npu.ByteTensor + + @property + def DoubleTensor(self): + return torch.npu.DoubleTensor + + @property + def FloatTensor(self): + return torch.npu.FloatTensor + + @property + def HalfTensor(self): + return torch.npu.HalfTensor + + @property + def IntTensor(self): + return torch.npu.IntTensor + + @property + def LongTensor(self): + return torch.npu.LongTensor + + def pin_memory(self, tensor, align_bytes=1): + return tensor.pin_memory() + + def is_pinned(self, tensor): + return tensor.is_pinned() + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('npu:'): + return True + else: + return False + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + return "op_builder.npu" + except ImportError: + return "deepspeed.ops.op_builder.npu" + + def _lazy_init_class_dict(self): + if self.class_dict: + return + + op_builder_module = importlib.import_module(self.op_builder_dir()) + + # get op builder class from op_builder/npu/__init__.py + self.class_dict = {} + for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass): + self.class_dict[class_name] = class_obj + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, class_name): + builder_class = self.get_op_builder(class_name) + return None if builder_class is None else builder_class() + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name] + else: + return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH'] + + def visible_devices_envs(self): + return ['ASCEND_RT_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }") diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index a31846f53aa9..35cac4b94b70 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -2,6 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team +import os + +try: + # Importing logger currently requires that torch is installed, hence the try...except + # TODO: Remove logger dependency on torch. + from deepspeed.utils import logger as accel_logger +except ImportError as e: + accel_logger = None try: from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1 @@ -12,6 +20,8 @@ except ImportError as e: dsa2 = None +SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa'] + ds_accelerator = None @@ -26,46 +36,212 @@ def _validate_accelerator(accel_obj): # accelerator.abstractor_accelerator # or deepspeed.accelerator.abstract_accelerator, consider accel_obj # is a conforming object - if not ((dsa1 != None and isinstance(accel_obj, dsa1)) or (dsa2 != None and isinstance(accel_obj, dsa2))): - raise AssertionError(f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator') + if not ((dsa1 is not None and isinstance(accel_obj, dsa1)) or (dsa2 is not None and isinstance(accel_obj, dsa2))): + raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator") # TODO: turn off is_available test since this breaks tests - #assert accel_obj.is_available(), \ + # assert accel_obj.is_available(), \ # f'{accel_obj.__class__.__name__} accelerator fails is_available() test' +def is_current_accelerator_supported(): + return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST + + def get_accelerator(): global ds_accelerator - if ds_accelerator is None: + if ds_accelerator is not None: + return ds_accelerator + + accelerator_name = None + ds_set_method = None + # 1. Detect whether there is override of DeepSpeed accelerators from environment variable. + if "DS_ACCELERATOR" in os.environ.keys(): + accelerator_name = os.environ["DS_ACCELERATOR"] + if accelerator_name == "xpu": + try: + import torch + assert hasattr(torch, 'xpu') and torch.xpu.is_available(), \ + "XPU_Accelerator requires PyTorch with XPU support (torch.xpu)." + except (ImportError, AssertionError) as e: + raise ValueError(f"XPU_Accelerator requires PyTorch with XPU support: {e}") + elif accelerator_name == "cpu": + pass + elif accelerator_name == "npu": + try: + import torch_npu # noqa: F401 # type: ignore + except ImportError as e: + raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.") + pass + elif accelerator_name == "sdaa": + try: + import torch_sdaa # noqa: F401 # type: ignore + except ImportError as e: + raise ValueError("SDAA_Accelerator requires torch_sdaa, which is not installed on this system.") + pass + elif accelerator_name == "mps": + try: + import torch.mps + + # should use torch.mps.is_available() if it exists someday but this is used as proxy + torch.mps.current_allocated_memory() + except (RuntimeError, ImportError) as e: + raise ValueError("MPS_Accelerator requires torch.mps, which is not installed on this system.") + elif accelerator_name == "hpu": + try: + import habana_frameworks.torch.hpu # noqa: F401 + except ImportError as e: + raise ValueError( + "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") + elif accelerator_name == "mlu": + try: + import torch_mlu # noqa: F401 + except ImportError as e: + raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.") + elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST: + raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. ' + f'Value "{accelerator_name}" is not supported') + ds_set_method = "override" + + # 2. If no override, detect which accelerator to use automatically + if accelerator_name is None: + # We need a way to choose among different accelerator types. + # Currently we detect which accelerator extension is installed + # in the environment and use it if the installing answer is True. + # An alternative might be detect whether CUDA device is installed on + # the system but this comes with two pitfalls: + # 1. the system may not have torch pre-installed, so + # get_accelerator().is_available() may not work. + # 2. Some scenario like install on login node (without CUDA device) + # and run on compute node (with CUDA device) may cause mismatch + # between installation time and runtime. + try: - from intel_extension_for_deepspeed import XPU_Accelerator + import torch + + # Detect XPU via PyTorch + if hasattr(torch, 'xpu'): + if torch.xpu.is_available(): + accelerator_name = "xpu" except ImportError as e: pass - else: - ds_accelerator = XPU_Accelerator() - _validate_accelerator(ds_accelerator) - return ds_accelerator + if accelerator_name is None: + try: + import torch_npu # noqa: F401,F811 # type: ignore + + accelerator_name = "npu" + except ImportError as e: + pass + if accelerator_name is None: + try: + import torch_sdaa # noqa: F401,F811 # type: ignore + accelerator_name = "sdaa" + except ImportError as e: + pass + if accelerator_name is None: + try: + import torch.mps + + # should use torch.mps.is_available() if it exists someday but this is used as proxy + torch.mps.current_allocated_memory() + accelerator_name = "mps" + except (RuntimeError, ImportError) as e: + pass + if accelerator_name is None: + try: + import habana_frameworks.torch.hpu # noqa: F401,F811 + + accelerator_name = "hpu" + except ImportError as e: + pass + if accelerator_name is None: + try: + import torch_mlu # noqa: F401,F811 + + accelerator_name = "mlu" + except ImportError as e: + pass + if accelerator_name is None: + try: + import torch + + # Determine if we are on a GPU or x86 CPU with torch. + # "torch.cuda.is_available()" provides a stronger guarantee, #ignore-cuda + # ensuring that we are free from CUDA initialization errors. + # While "torch.cuda.device_count() > 0" check ensures that #ignore-cuda + # we won't try to do any CUDA calls when no device is available + # For reference: https://github.com/deepspeedai/DeepSpeed/pull/6810 + if torch.cuda.device_count() > 0 and torch.cuda.is_available(): #ignore-cuda + accelerator_name = "cuda" + except (RuntimeError, ImportError) as e: + # TODO need a more decent way to detect which accelerator to use, consider using nvidia-smi command for detection + pass + if accelerator_name is None: + # borrow this log from PR#5084 + if accel_logger is not None: + accel_logger.warning( + "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.") + # cpu added as catch-all when accelerator detection fails + accelerator_name = "cpu" + + ds_set_method = "auto detect" + + # 3. Set ds_accelerator accordingly + if accelerator_name == "cuda": from .cuda_accelerator import CUDA_Accelerator + ds_accelerator = CUDA_Accelerator() - _validate_accelerator(ds_accelerator) + elif accelerator_name == "cpu": + from .cpu_accelerator import CPU_Accelerator + + ds_accelerator = CPU_Accelerator() + elif accelerator_name == "xpu": + from .xpu_accelerator import XPU_Accelerator + + ds_accelerator = XPU_Accelerator() + elif accelerator_name == "npu": + from .npu_accelerator import NPU_Accelerator + + ds_accelerator = NPU_Accelerator() + elif accelerator_name == "sdaa": + from .sdaa_accelerator import SDAA_Accelerator + + ds_accelerator = SDAA_Accelerator() + elif accelerator_name == "mps": + from .mps_accelerator import MPS_Accelerator + + ds_accelerator = MPS_Accelerator() + elif accelerator_name == 'hpu': + from .hpu_accelerator import HPU_Accelerator + + ds_accelerator = HPU_Accelerator() + elif accelerator_name == 'mlu': + from .mlu_accelerator import MLU_Accelerator + + ds_accelerator = MLU_Accelerator() + _validate_accelerator(ds_accelerator) + if accel_logger is not None: + accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})") return ds_accelerator def set_accelerator(accel_obj): global ds_accelerator _validate_accelerator(accel_obj) + if accel_logger is not None and accel_obj is not None: + accel_logger.info(f"Setting ds_accelerator to {accel_obj._name} (model specified)") ds_accelerator = accel_obj -''' +""" -----------[code] test_get.py ----------- from deepspeed.accelerator import get_accelerator my_accelerator = get_accelerator() -print(f'{my_accelerator._name=}') -print(f'{my_accelerator._communication_backend=}') -print(f'{my_accelerator.HalfTensor().device=}') -print(f'{my_accelerator.total_memory()=}') +logger.info(f'{my_accelerator._name=}') +logger.info(f'{my_accelerator._communication_backend=}') +logger.info(f'{my_accelerator.HalfTensor().device=}') +logger.info(f'{my_accelerator.total_memory()=}') -----------[code] test_get.py ----------- ---[output] python test_get.py--------- @@ -79,16 +255,16 @@ def set_accelerator(accel_obj): -----------[code] test_set.py ----------- from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator cu_accel = CUDA_Accelerator() -print(f'{id(cu_accel)=}') +logger.info(f'{id(cu_accel)=}') from deepspeed.accelerator import set_accelerator, get_accelerator set_accelerator(cu_accel) my_accelerator = get_accelerator() -print(f'{id(my_accelerator)=}') -print(f'{my_accelerator._name=}') -print(f'{my_accelerator._communication_backend=}') -print(f'{my_accelerator.HalfTensor().device=}') -print(f'{my_accelerator.total_memory()=}') +logger.info(f'{id(my_accelerator)=}') +logger.info(f'{my_accelerator._name=}') +logger.info(f'{my_accelerator._communication_backend=}') +logger.info(f'{my_accelerator.HalfTensor().device=}') +logger.info(f'{my_accelerator.total_memory()=}') -----------[code] test_set.py ----------- @@ -100,4 +276,4 @@ def set_accelerator(accel_obj): my_accelerator.HalfTensor().device=device(type='cuda', index=0) my_accelerator.total_memory()=34089730048 ---[output] python test_set.py--------- -''' +""" diff --git a/accelerator/sdaa_accelerator.py b/accelerator/sdaa_accelerator.py new file mode 100755 index 000000000000..9940d4a97d17 --- /dev/null +++ b/accelerator/sdaa_accelerator.py @@ -0,0 +1,323 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# BSD 3- Clause License Copyright (c) 2023, Tecorigin Co., Ltd. All rights +# reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY,OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY +# WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. + +# DeepSpeed Team + +import importlib +import inspect +import functools + +from .abstract_accelerator import DeepSpeedAccelerator +# During setup stage torch may not be installed, pass on no torch will +# allow op builder related API to be executed. +try: + import torch.sdaa +except ImportError: + pass + + +class SDAA_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = 'sdaa' + self._communication_backend_name = 'tccl' + self._compile_backend = "inductor" + self.class_dict = None + + def is_synchronized_device(self): + return False + + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + + # Device APIs + def device_name(self, device_index=None): + if device_index is None: + return 'sdaa' + return 'sdaa:{}'.format(device_index) + + def device(self, device_index=None): + return torch.sdaa.device(device_index) + + def set_device(self, device_index): + torch.sdaa.set_device(device_index) + + def current_device(self): + return torch.sdaa.current_device() + + def current_device_name(self): + return 'sdaa:{}'.format(torch.sdaa.current_device()) + + def device_count(self): + return torch.sdaa.device_count() + + def synchronize(self, device_index=None): + return torch.sdaa.synchronize(device_index) + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + if device_index is None: + return torch.sdaa.set_rng_state(new_state) + + return torch.sdaa.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + if device_index is None: + return torch.sdaa.get_rng_state() + + return torch.sdaa.get_rng_state(device_index) + + def manual_seed(self, seed): + return torch.sdaa.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.sdaa.manual_seed_all(seed) + + def initial_seed(self): + return torch.sdaa.initial_seed() + + def default_generator(self, device_index): + return torch.sdaa.default_generators[device_index] + + # Streams/Events + @property + def Stream(self): + return torch.sdaa.Stream + + def stream(self, stream): + return torch.sdaa.stream(stream) + + def current_stream(self, device_index=None): + return torch.sdaa.current_stream(device_index) + + def default_stream(self, device_index=None): + return torch.sdaa.default_stream(device_index) + + @property + def Event(self): + return torch.sdaa.Event + + # Memory management + def empty_cache(self): + return torch.sdaa.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.sdaa.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + return torch.sdaa.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + return torch.sdaa.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + return torch.sdaa.memory_cached(device_index) + + def max_memory_cached(self, device_index=None): + return torch.sdaa.max_memory_cached(device_index) + + def reset_max_memory_cached(self, device_index=None): + return torch.sdaa.reset_max_memory_cached(device_index) + + def memory_stats(self, device_index=None): + if hasattr(torch.sdaa, 'memory_stats'): + return torch.sdaa.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + if hasattr(torch.sdaa, 'reset_peak_memory_stats'): + return torch.sdaa.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + if hasattr(torch.sdaa, 'memory_reserved'): + return torch.sdaa.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + if hasattr(torch.sdaa, 'max_memory_reserved'): + return torch.sdaa.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return torch.sdaa.get_device_properties(device_index).total_memory + + def available_memory(self, device_index=None): + return self.total_memory(device_index) - self.memory_allocated(device_index) + + # Data types + def is_bf16_supported(self): + return torch.sdaa.is_bf16_supported() + + def is_fp16_supported(self): + return True + + def supported_dtypes(self): + supported_dtypes = [torch.float] + if self.is_fp16_supported(): + supported_dtypes.append(torch.half) + if self.is_bf16_supported(): + supported_dtypes.append(torch.bfloat16) + return supported_dtypes + + # Misc + def is_available(self): + return torch.sdaa.is_available() + + def range_push(self, msg, domain=None, category=None): + return + + def range_pop(self, domain=None): + return + + def lazy_call(self, callback): + return torch.sdaa._lazy_call(callback) + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + + # Tensor operations + + @property + def BFloat16Tensor(self): + return functools.partial(torch.tensor, dtype=torch.bfloat16, device='sdaa') + + @property + def ByteTensor(self): + return functools.partial(torch.tensor, dtype=torch.uint8, device='sdaa') + + @property + def DoubleTensor(self): + return functools.partial(torch.tensor, dtype=torch.double, device='sdaa') + + @property + def FloatTensor(self): + return functools.partial(torch.tensor, dtype=torch.float, device='sdaa') + + @property + def HalfTensor(self): + return functools.partial(torch.tensor, dtype=torch.half, device='sdaa') + + @property + def IntTensor(self): + return functools.partial(torch.tensor, dtype=torch.int, device='sdaa') + + @property + def LongTensor(self): + return functools.partial(torch.tensor, dtype=torch.long, device='sdaa') + + def pin_memory(self, tensor, align_bytes=1): + return tensor.pin_memory() + + def is_pinned(self, tensor): + return tensor.is_pinned() + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('sdaa:'): + return True + else: + return False + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + return "op_builder.sdaa" + except ImportError: + return "deepspeed.ops.op_builder.sdaa" + + def _lazy_init_class_dict(self): + if self.class_dict: + return + + op_builder_module = importlib.import_module(self.op_builder_dir()) + + # get op builder class from op_builder/sdaa/__init__.py + self.class_dict = {} + for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass): + self.class_dict[class_name] = class_obj + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, class_name): + builder_class = self.get_op_builder(class_name) + return builder_class() + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name] + else: + return self.class_dict['NotImplementedBuilder'] + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return ['NCCL', 'LD_LIBRARY', 'PATH'] + + def visible_devices_envs(self): + return ['SDAA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py new file mode 100644 index 000000000000..9f6b21af54ea --- /dev/null +++ b/accelerator/xpu_accelerator.py @@ -0,0 +1,315 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator +import functools +import importlib +import inspect + +try: + import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore + oneccl_imported_p = True +except ImportError as e: + oneccl_imported_p = False + + +class XPU_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = 'xpu' + if oneccl_imported_p: + self._communication_backend_name = 'ccl' + else: + # changed to xccl if not using torch-CCL on XPU device + self._communication_backend_name = 'xccl' + self._compile_backend = "inductor" + self.aligned_tensors = [] + self.class_dict = None + + def is_synchronized_device(self): + return False + + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + + # Device APIs + def device_name(self, device_index=None): + if device_index == None: + return 'xpu' + return 'xpu:{}'.format(device_index) + + def device(self, device_index=None): + return torch.device('xpu', device_index) + + def set_device(self, device_index): + torch.xpu.set_device(device_index) + + def current_device(self): + return torch.xpu.current_device() + + def current_device_name(self): + return 'xpu:{}'.format(torch.xpu.current_device()) + + def device_count(self): + return torch.xpu.device_count() + + def synchronize(self, device_index=None): + return torch.xpu.synchronize(device_index) + + # RNG APIs + def random(self): + return torch.xpu.random + + def set_rng_state(self, new_state, device_index=None): + if device_index == None: + return torch.xpu.set_rng_state(new_state) + return torch.xpu.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + if device_index == None: + return torch.xpu.get_rng_state() + return torch.xpu.get_rng_state(device_index) + + def manual_seed(self, seed): + return torch.xpu.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.xpu.manual_seed_all(seed) + + def initial_seed(self): + return torch.xpu.initial_seed() + + def default_generator(self, device_index): + return torch.xpu.default_generators[device_index] + + # Streams/Events + @property + def Stream(self): + return torch.xpu.Stream + + def stream(self, stream): + return torch.xpu.stream(stream) + + def current_stream(self, device_index=None): + return torch.xpu.current_stream(device_index) + + def default_stream(self, device_index=None): + # torch.xpu does not support the sync behavior of default stream as cuda + # use current_stream as workaround + # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams + return torch.xpu.current_stream(device_index) + + @property + def Event(self): + return torch.xpu.Event + + # Memory management + def empty_cache(self): + return torch.xpu.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.xpu.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + return torch.xpu.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + return torch.xpu.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + return torch.xpu.memory_reserved(device_index) + + def max_memory_cached(self, device_index=None): + return torch.xpu.max_memory_reserved(device_index) + + def reset_max_memory_cached(self, device_index=None): + return torch.xpu.reset_max_memory_reserved(device_index) + + def memory_stats(self, device_index=None): + return torch.xpu.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + return torch.xpu.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + return torch.xpu.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + return torch.xpu.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return torch.xpu.get_device_properties(device_index).total_memory + + def available_memory(self, device_index=None): + return self.total_memory(device_index) - self.memory_allocated(device_index) + + # Misc + def is_available(self): + return torch.xpu.is_available() + + def range_push(self, msg, domain=None, category=None): + # TODO itt is currently not supported yet + # return torch.profiler.itt.range_push(msg) + return + + def range_pop(self, domain=None): + # TODO itt is currently not supported yet + # return torch.profiler.itt.range_pop() + return + + def lazy_call(self, callback): + if hasattr(torch.xpu, "_lazy_call"): + return torch.xpu._lazy_call(callback) + else: + return torch.xpu.lazy_init._lazy_call(callback) + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + + # Data types + def is_bf16_supported(self): + return True + + def is_fp16_supported(self): + return True + + def supported_dtypes(self): + return [torch.float, torch.half, torch.bfloat16] + + # Tensor operations + + @property + def BFloat16Tensor(self): + return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name) + + @property + def ByteTensor(self): + return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name) + + @property + def DoubleTensor(self): + return functools.partial(torch.tensor, dtype=torch.double, device=self._name) + + @property + def FloatTensor(self): + return functools.partial(torch.tensor, dtype=torch.float, device=self._name) + + @property + def HalfTensor(self): + return functools.partial(torch.tensor, dtype=torch.half, device=self._name) + + @property + def IntTensor(self): + return functools.partial(torch.tensor, dtype=torch.int, device=self._name) + + @property + def LongTensor(self): + return functools.partial(torch.tensor, dtype=torch.long, device=self._name) + + def pin_memory(self, tensor, align_bytes=1): + if align_bytes == 1: + return tensor.pin_memory(device=self.current_device_name()) + elif align_bytes == 0: + from deepspeed.ops.op_builder.xpu import AsyncIOBuilder + self.aio_handle = AsyncIOBuilder().load().aio_handle(128 * 1024, 8, False, False, False) + aligned_t = self.aio_handle.new_cpu_locked_tensor(tensor.numel(), tensor) + aligned_t = aligned_t[:tensor.numel()].copy_(tensor) + self.aligned_tensors.append([aligned_t.data_ptr(), aligned_t[-1].data_ptr()]) + return aligned_t + + def is_pinned(self, tensor): + if tensor.is_pinned(device=self.current_device_name()): + return True + else: + for begin, end in self.aligned_tensors: + if begin <= tensor.data_ptr() and tensor.data_ptr() <= end: + return True + return False + + def op_builder_dir(self): + try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + return "op_builder.xpu" + except ImportError: + return "deepspeed.ops.op_builder.xpu" + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('xpu:'): + return True + else: + return False + + def _lazy_init_class_dict(self): + if self.class_dict: + return + + op_builder_module = importlib.import_module(self.op_builder_dir()) + + # get op builder class from op_builder/xpu/__init__.py + self.class_dict = {} + for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass): + self.class_dict[class_name] = class_obj + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, class_name): + builder_class = self.get_op_builder(class_name) + return builder_class() + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name] + else: + return self.class_dict['NotImplementedBuilder'] + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return [] + + def visible_devices_envs(self): + return ['ZE_AFFINITY_MASK'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/benchmarks/README.md b/benchmarks/README.md index 4c88b2dd091c..a2b332732042 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -2,5 +2,5 @@ If you are looking for DeepSpeed benchmarks, please see the following resources: -1. [Communication Benchmarking Suite](https://github.com/microsoft/DeepSpeedExamples/tree/master/benchmarks/communication) -2. [Inference Benchmarks](https://github.com/microsoft/DeepSpeedExamples/tree/master/benchmarks/inference) +1. [Communication Benchmarking Suite](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/communication) +2. [Inference Benchmarks](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/inference) diff --git a/benchmarks/autosp/bench_multimodal_sp.py b/benchmarks/autosp/bench_multimodal_sp.py new file mode 100644 index 000000000000..e6ad8faf86ce --- /dev/null +++ b/benchmarks/autosp/bench_multimodal_sp.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Benchmark: AutoSP multimodal sequence parallelism (ViT SP + fusion adapter). + +Measures per-iteration latency, throughput, and peak GPU memory for the +ViT-SP + fusion-adapter pipeline at a given SP degree. + +Launch (from repo root): + + # SP degree 2 — two GPUs: + NCCL_P2P_DISABLE=1 torchrun --nproc_per_node=2 \\ + benchmarks/autosp/bench_multimodal_sp.py [args] + + # Baseline — single GPU (all-gather/scatter are no-ops): + torchrun --nproc_per_node=1 \\ + benchmarks/autosp/bench_multimodal_sp.py [args] + +Compare the two output tables to quantify memory savings and throughput scaling. + +Arguments: + --arch {internvl, qwen2vl} architecture to simulate (default: internvl) + --batch-size N samples per batch (default: 2) + --seq-len N text sequence length (default: 512) + --visual-tokens N total visual tokens per sample (default: 256) + --hidden N hidden dimension (default: 1024) + --num-layers N ViT and LLM layers each (default: 2) + --iters N measured iterations (default: 50) + --warmup N warmup iterations (default: 10) +""" + +import argparse +import logging +import statistics + +import torch +import torch.nn as nn + +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter + +# --------------------------------------------------------------------------- +# Token IDs +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_ID = 92546 +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + +# --------------------------------------------------------------------------- +# Mock attention classes — names match autosp_detector registries exactly +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class InternLM2Attention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2VLVisionAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2Attention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# Model building blocks +# --------------------------------------------------------------------------- + + +class _ViTBlock(nn.Module): + """One ViT transformer block: attention (to be SP-wrapped) + linear FFN.""" + + def __init__(self, attn_cls, hidden: int) -> None: + super().__init__() + self.attn = attn_cls() + self.ffn = nn.Linear(hidden, hidden, bias=False) + + def forward(self, x, **kwargs): + out = self.attn(x, **kwargs) + if isinstance(out, (tuple, list)): + out = out[0] + return self.ffn(out) + + +class _MinimalInternVLModel(nn.Module): + """InternVL-like benchmark model. + + Module paths detected by autosp_detector: + - ``vision_encoder.*.attn`` -> InternVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS + + ``language_model`` uses plain nn.Linear layers so it is NOT wrapped by + DistributedAttention (avoids the Q/K/V interface requirement) yet still + contributes realistic compute on the scattered fused sequence. + """ + + def __init__(self, hidden: int, num_layers: int) -> None: + super().__init__() + self.vision_encoder = nn.Sequential(*[_ViTBlock(InternVisionAttention, hidden) for _ in range(num_layers)]) + self.mm_projector = nn.Identity() + self.language_model = nn.Sequential(*[nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)]) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.vision_encoder(local_patches) + local_fused = self.fusion(local_visual, text_embeds, input_ids) + return self.language_model(local_fused) + + +class _MinimalQwen2VLModel(nn.Module): + """Qwen2VL-like benchmark model.""" + + def __init__(self, hidden: int, num_layers: int) -> None: + super().__init__() + self.visual = nn.Sequential(*[_ViTBlock(Qwen2VLVisionAttention, hidden) for _ in range(num_layers)]) + self.multi_modal_projector = nn.Identity() + self.model = nn.Sequential(*[nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)]) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.visual(local_patches) + local_fused = self.fusion(local_visual, text_embeds, input_ids) + return self.model(local_fused) + + +# --------------------------------------------------------------------------- +# Setup helpers +# --------------------------------------------------------------------------- + + +def _build_model_and_inputs(arch: str, args, sp_group, device): + rank = dist.get_rank(sp_group) + world_size = dist.get_world_size(sp_group) + + local_v = args.visual_tokens // world_size + bs, text_len, hidden = args.batch_size, args.seq_len, args.hidden + + torch.manual_seed(0) + local_patches = torch.randn(bs, local_v, hidden, device=device) + text_embeds = torch.randn(bs, text_len, hidden, device=device) + input_ids = torch.zeros(bs, text_len, dtype=torch.long, device=device) + + if arch == "internvl": + num_ctx = min(local_v * world_size, text_len - 2) + input_ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID + + model = _MinimalInternVLModel(hidden, args.num_layers).to(device) + # Suppress the Phase 2 projection-layer warning: we wrap manually below. + _auto_sp_logger = logging.getLogger("deepspeed.sequence.auto_sp") + _prev_level = _auto_sp_logger.level + _auto_sp_logger.setLevel(logging.ERROR) + auto_wrap_model_for_sp(model, sp_group) + _auto_sp_logger.setLevel(_prev_level) + model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(device) + else: # qwen2vl + num_inner = min(local_v * world_size, text_len - 3) + input_ids[:, 1] = _QWEN2VL_START_ID + input_ids[:, 2 + num_inner] = _QWEN2VL_END_ID + + model = _MinimalQwen2VLModel(hidden, args.num_layers).to(device) + _auto_sp_logger = logging.getLogger("deepspeed.sequence.auto_sp") + _prev_level = _auto_sp_logger.level + _auto_sp_logger.setLevel(logging.ERROR) + auto_wrap_model_for_sp(model, sp_group) + _auto_sp_logger.setLevel(_prev_level) + model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(device) + + return model, local_patches, text_embeds, input_ids + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + + +def _run(arch: str, args) -> None: + deepspeed.init_distributed(dist_backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(get_accelerator().device_name(), rank % get_accelerator().device_count()) + get_accelerator().set_device(rank % get_accelerator().device_count()) + + sp_group = dist.new_group(ranks=list(range(world_size))) + model, local_patches, text_embeds, input_ids = _build_model_and_inputs(arch, args, sp_group, device) + model.eval() + + # Warmup + with torch.no_grad(): + for _ in range(args.warmup): + model(local_patches, text_embeds, input_ids) + get_accelerator().synchronize() + get_accelerator().reset_peak_memory_stats() + + # Timed iterations using CUDA events for accurate GPU-side measurement. + latencies_ms = [] + with torch.no_grad(): + for _ in range(args.iters): + t_start = get_accelerator().Event(enable_timing=True) + t_end = get_accelerator().Event(enable_timing=True) + t_start.record() + model(local_patches, text_embeds, input_ids) + t_end.record() + get_accelerator().synchronize() + latencies_ms.append(t_start.elapsed_time(t_end)) + + peak_mem_mb = get_accelerator().max_memory_allocated() / 1024**2 + mean_ms = statistics.mean(latencies_ms) + std_ms = statistics.stdev(latencies_ms) if len(latencies_ms) > 1 else 0.0 + # tokens/s: fused sequence length approximated by seq_len (length-preserving adapters). + throughput = (args.batch_size * args.seq_len) / (mean_ms / 1000.0) + + if rank == 0: + sep = "=" * 62 + print(f"\n{sep}") + print(f" AutoSP Benchmark arch={arch} sp_degree={world_size}") + print(sep) + print(f" batch_size : {args.batch_size}") + print(f" seq_len : {args.seq_len}") + print(f" visual_tokens : {args.visual_tokens} (local={args.visual_tokens // world_size}/rank)") + print(f" hidden : {args.hidden}") + print(f" num_layers : {args.num_layers}") + print(f" warmup / iters : {args.warmup} / {args.iters}") + print(f" {'─' * 58}") + print(f" Latency : {mean_ms:.2f} ± {std_ms:.2f} ms/iter") + print(f" Throughput : {throughput:,.0f} tokens/s") + print(f" Peak GPU memory : {peak_mem_mb:.1f} MB") + print(f"{sep}\n") + + dist.destroy_process_group() + + +def main() -> None: + parser = argparse.ArgumentParser(description="AutoSP multimodal SP benchmark") + parser.add_argument("--arch", + choices=["internvl", "qwen2vl"], + default="internvl", + help="Model architecture to simulate") + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--visual-tokens", + type=int, + default=256, + help="Total visual tokens (must be divisible by --nproc_per_node)") + parser.add_argument("--hidden", type=int, default=1024) + parser.add_argument("--num-layers", type=int, default=2, help="Number of ViT blocks and LLM linear layers each") + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--warmup", type=int, default=10) + args = parser.parse_args() + + _run(args.arch, args) + + +if __name__ == "__main__": + main() diff --git a/bin/deepspeed.bat b/bin/deepspeed.bat new file mode 100644 index 000000000000..8e488bde380c --- /dev/null +++ b/bin/deepspeed.bat @@ -0,0 +1,2 @@ +@echo off +python "%~dp0\ds" %* diff --git a/bin/ds_bench b/bin/ds_bench index bfacbc8e25c8..80bf4029604e 100755 --- a/bin/ds_bench +++ b/bin/ds_bench @@ -10,7 +10,10 @@ import sys required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] if not all(map(lambda v: v in os.environ, required_env)): import subprocess - subprocess.run("deepspeed $(which ds_bench) " + " ".join(sys.argv[1:]), shell=True) + r = subprocess.check_output(["which", "ds_bench"]) + ds_bench_bin = r.decode('utf-8').strip() + safe_cmd = ["deepspeed", ds_bench_bin] + sys.argv[1:] + subprocess.run(safe_cmd) else: args = benchmark_parser().parse_args() rank = args.local_rank diff --git a/bin/ds_io b/bin/ds_io new file mode 100644 index 000000000000..681fd634764c --- /dev/null +++ b/bin/ds_io @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from deepspeed.nvme import ds_io_main + +if __name__ == '__main__': + ds_io_main() diff --git a/bin/ds_nvme_tune b/bin/ds_nvme_tune new file mode 100644 index 000000000000..117adfba22c0 --- /dev/null +++ b/bin/ds_nvme_tune @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +from deepspeed.nvme import sweep_main, generate_main, parse_sweep_arguments + +if __name__ == '__main__': + args = parse_sweep_arguments() + print(f"Running DeepNVMe performance tuning on {args.nvme_dir}") + sweep_main(args) + generate_main(args.log_dir) diff --git a/bin/ds_report.bat b/bin/ds_report.bat new file mode 100644 index 000000000000..78c7225f147c --- /dev/null +++ b/bin/ds_report.bat @@ -0,0 +1,2 @@ +@echo off +python "%~dp0\ds_report" %* diff --git a/blogs/assets/images/triton-bert-base-latency.png b/blogs/assets/images/triton-bert-base-latency.png new file mode 100644 index 000000000000..4f733f4d1afe Binary files /dev/null and b/blogs/assets/images/triton-bert-base-latency.png differ diff --git a/blogs/assets/images/triton-bert-large-latency.png b/blogs/assets/images/triton-bert-large-latency.png new file mode 100644 index 000000000000..d82dc0ccac51 Binary files /dev/null and b/blogs/assets/images/triton-bert-large-latency.png differ diff --git a/blogs/comm-opt/README.md b/blogs/comm-opt/README.md new file mode 100644 index 000000000000..7049e3b9f162 --- /dev/null +++ b/blogs/comm-opt/README.md @@ -0,0 +1,82 @@ +
+ +# Communication Optimizations for Large-Scale Training + +
+ + +## Table of Contents +1. [Introduction](#introduction) +2. [Gradient AllReduce Optimization for ZeRO stages 1 and 2](#ar-opt) +3. [Optimizing Parameter All-Gather for ZeRO2 Training](#ag-opt) +4. [Optimizing AlltoAll for Sequence-Parallel Training](#sp-opt) + + +## 1. Introduction +Training LLMs on large datasets can be extremely costly both in terms of hardware resources and time. An important step to minimize such costs is to carefully combine an appropriate number of resources together with a scalable library that guarantees training completion within a time limit. In this post, we discuss a key aspect of the scalability features of DeepSpeed, the communication optimization. Communication collectives (e.g., all-reduce, all-gather, etc.) are critical pieces of many popular DeepSpeed technologies (e.g., ZeRO, MoE, AutoTP, etc.), and in the following sections we discuss our new optimizations of some of these collectives. These optimizations are available in DeepSpeed versions >= 0.x.x. + +## 2. Gradient AllReduce Optimization for ZeRO stages 1 and 2 + +Before diving into this optimization, let's take a step back and show some of the case studies that demonstrate the need. + +AllReduce operation is an important part of the training process. In ZeRO, we handle this in buckets, which can be configured to get good communication throughput. As the number of GPUs increases, we encounter smaller-partition AllReduces. In this case, the current bucketing scheme cannot help with the communication overhead. This mostly becomes an issue when training smaller-scale models (like Llama-7B) with large number of GPUs. + +For instance, when training a dense-7B architecture with Zero stages 1 or 2, we encounter a 1 and 2 second increase for the AllReduce time by increasing from 256 to 512 and 1024 A100 GPUs. This issue mostly arises from the fact that, the gradient-averaging happens with smaller partitions (#parameters / #GPUs) per-GPU rank. This issue gets more serious when training MoE architectures (3 - 12 second) for which the expert's parameters can be farther away due to the current parallelism layout of data and expert parallelism. + +In this section, we introduce two main optimization techniques for alleviating these communication bottleneck. + +First, Multi-rank bucketing for the same process group: for this optimization, we simply pack all data that requires to be reduced from different ranks into one big flattened tensor and call AllReduce instead of reduce operations. After the reduction, we scatter the right portion of data to the corresponding ranks. + +Second, add new layout for the expert-data parallelism: the default parallelism layout for MoE architecture (as shown in Fig 1) is planned in a way that the experts are placed first on E parallel GPUs and replicated D times (data-parallel). With this layout, we encounter slower AllReduce as data-parallel ranks are placed farther away especially when we have cross-rank communication. We call this layout E + D. + +
+
+ + *Fig 1: Different MoE parallel layout. left) E + D, which places the GPUs in EP dimension first before adding DP, right) D + E, that replicates each expert by DP size, before constructing EP. We get faster AllReduce for the second layout while increasing the AlltoAll time. It potentially results in faster e2e training time, as the communication volume for AllReduce (total parameter size) is normally much more than AlltoAll (MLP activation memory).*
+
+By changing this layout from E + D to D + E (shown in Fig 1), where we first replicate each expert by D times and then add them across expert-parallel dimension, we can reduce the AllReduce time substantially. On an A100-DGX cluster, where each node has 8 GPUs, we see about 8x reduction in cross-node infiniband communication-volume for the parameter update process, which are now processed faster using the intra-node NVLinks. Note that by adding this optimization, we increase the cost of AlltoAll happening for the MoE part of the model, however, we have seen that the performance benefit of AllReduce overweighs this cost. + +Table 1 summarizes the saving observed for training a 7B dense and a MoE architecture by using the optimized AllReduce scheme. After applying the multi-rank bucketing technique, we reduce the AllReduce time by 4x for dense architecture and 5x - 8x for the MoE one. In addition, we obtain an extra 3x saving using the new D + E layout for the MoE architecture. Therefore, we see higher performance gain on MoE architectures when using large number of GPUs. For instance, when training a 7B-base MoE architecture, we reduce iteration-time from 13 sec to 9.5 sec on 512 GPUs (37%) and from 16.1 sec to 5.1 sec on 1k-GPU setup (3.2x). +
+ +| | GPUs | AllReduce time | Iteration time | +|----------|:------:|:------:|:------:| +baseline (dense) | 1024| 1.2 | 5.4 +optimized (dense) | 1024| 0.36 | 4.5 +baseline (MoE) | 1024 | 11.5 | 16.1 +optimized (MoE) | 1024 | 0.45 | 5.1 + +Table 1. AllReduce saving observed for both dense and MoE architectures. + +
+ +## 3. Optimizing Parameter All-Gather for ZeRO2 Training + +The same as with AllReduce, all-gather takes longer as we have more partitions. As the parameters are stored in a flattened buffer for ZeRO stage-2, we can simply have a one call to all-gather the parameters into this tensor. + +When all-gathering the updated parameters at Zero-Stage2, the bucketing scheme uses several narrow operations and creates a list of tensors with the bucket size from each partition. We needed this scheme to align with the `all_gather` operation from PyTorch. +However, by adding the support for the `all_gather_into_tensor`, operation that has been added to the newer versions of PyTorch, we can simply have a kernel call to do the full-parameter all-gather. With this optimization, we see about 2x reduction in the step time for large-scale training. + +## 4. Optimizing AlltoAll for Sequence-Parallel Training + +For this part of the optimization, we add some fusion for the communication that is required for the DeepSpeed-Ulysses to provide a more scalable approach for when we increase the SP from 2 to 8 (for this study, we consider A100-DGX hardware, which has 8 GPUs per-node and by increasing the parallelism more than 8, we encounter performance-hit by the cross-node communication). + +These fusions are done at two levels: +1. Fuse the sequence AlltoAll for q,k, and v: we Scatter the heads using the mixed tensor rather than splitting them beforehand. For this part, we need to get some more information from the modeling side (such as the number of q and kv heads), to split the heads before calling AlltoAll. We have added some new changes on the Megatron-DeepSpeed repo that incorporate these changes for the sequence-parallelism. +2. Fuse the AlltoAll tensors and call the PyTorch's AlltoAll-single API: we reshape the tensors for the scatter dimension and use a single tensor for AlltoAll which alleviates the overhead of using a list of tensors which requires a contiguous call for each element of the list. + +By adding these optimizations, we see about 10 to 15% speedup compared to the previous design, and obtain good scalability across different SP-degree and context-lengths. In the following table, we show the improvement achieved by using SP, when doubling the GPU-count and increasing the SP-degree. We obtain over 80% of efficiency when increasing from 256 to 512 GPUs using SP-2. Furthermore, by increasing the sequence-length and SP, while keeping the processed tokens similar, we achieve over 75% of efficiency for 2x more resources. On the other hand, if we can double the number of tokens (shown on the last row of table 2), we can improve the performance to 1.81x. + +
+ +| GPUs | bsz | seq | Tokens (M) | SP | Sample (4K)-per-second | Speedup (x) | +|----------|:------:|:------:|:------:|:------:|:------:|:------:| +256 | 256| 8192 |2|1 | 60.71 |1 +512 | 256| 8192 |2|2 | 111.18 | 1.83 +512 | 128| 16384 |2|4 | 108.81 | 1.79 +512 | 64 |32768 |2|8 | 106.54 | 1.75 +512 | 64 |65536 |4|8 | 110.05 | 1.81 + +Table 2. Sequence-Parallelism scalability using DeepSpeed-Ulysses. + +
diff --git a/blogs/comm-opt/assets/images/e+d.png b/blogs/comm-opt/assets/images/e+d.png new file mode 100644 index 000000000000..72ad0f583857 Binary files /dev/null and b/blogs/comm-opt/assets/images/e+d.png differ diff --git a/blogs/comm-opt/assets/images/sp+fp.png b/blogs/comm-opt/assets/images/sp+fp.png new file mode 100644 index 000000000000..0b2940418f7a Binary files /dev/null and b/blogs/comm-opt/assets/images/sp+fp.png differ diff --git a/blogs/comm-opt/assets/images/sp-conv.png b/blogs/comm-opt/assets/images/sp-conv.png new file mode 100644 index 000000000000..e1e36b4436a0 Binary files /dev/null and b/blogs/comm-opt/assets/images/sp-conv.png differ diff --git a/blogs/core_api_update/README.md b/blogs/core_api_update/README.md new file mode 100644 index 000000000000..e3ee05d7f9fb --- /dev/null +++ b/blogs/core_api_update/README.md @@ -0,0 +1,161 @@ +# DeepSpeed Core API updates: PyTorch-style backward and low-precision master states + +DeepSpeed is continuously evolving its core APIs to feel more natural to PyTorch users while giving them more control over performance and memory. + +In this short blog, we highlight two recent core improvements: + + * **PyTorch-compatible backward API** – You can now use standard `tensor.backward(...)` patterns with DeepSpeed engines, including non-scalar outputs. ([\#7665](https://github.com/deepspeedai/DeepSpeed/pull/7665)) + * **Low-precision master params / grads / optimizer states** – You can keep more state in bf16/fp16 to reduce memory usage and work better with `torch.autocast`. ([\#7700](https://github.com/deepspeedai/DeepSpeed/pull/7700)) + +These changes enable more flexible training pipelines, such as [disaggregated hybrid parallelism](https://www.anyscale.com/blog/30-faster-multimodal-ai-training-with-ray-and-disaggregated-hybrid) ([code](https://github.com/ray-project/multimodal-training)) for multimodal models using [Ray](https://github.com/ray-project/ray), and make DeepSpeed feel closer to “vanilla PyTorch”. + +## 1\. PyTorch-compatible backward API + +Traditionally, DeepSpeed’s training loop relied on the engine’s backward API: + +```python +loss = model_engine(batch) +model_engine.backward(loss) +model_engine.step() +``` + +This API was sufficient for traditional pretraining and fine-tuning pipelines. However, recent complex training pipelines require more flexibility. There were two major constraints: + +1. It only accepted a **scalar loss**. +2. You had to call **`model_engine.backward(loss)`**, rather than using the usual PyTorch `loss.backward()` style. + +Due to these constraints, users could not simply implement patterns that plain PyTorch allows. Here are some examples: + +```python +# 1. Combine multiple models and losses +output1 = model1(batch1) +output2 = model2(batch2) +loss = criterion(output1, output2) +loss.backward() + +# 2. Define a loss function separately from the main model +output = model(batch) +loss = loss_fn(output) +loss.backward() + +# 3. Call backward through non-scalar tensors with custom gradients +output = model(batch) +output.backward(grad) +``` + +The DeepSpeed Engine was able to handle these use cases using internal APIs; however, this required code changes. Additionally, if a user employed these patterns, the DeepSpeed engine might skip internal preprocessing/postprocessing (such as loss scaling and ZeRO-related logic), potentially leading to incorrect behavior. + +With this API update, we can now use the same code as native PyTorch while keeping DeepSpeed's unique features, including ZeRO. + +One example use case for this new API is [disaggregated hybrid parallelism](https://www.anyscale.com/blog/30-faster-multimodal-ai-training-with-ray-and-disaggregated-hybrid) for multimodal models using [Ray](https://github.com/ray-project/ray). In this training pipeline, two Ray Actor groups handle the vision encoder and the LLM separately. + +On a backward pass, the LLM passes a gradient to the vision encoder, and the vision encoder calls the backward function with that gradient. However, because the gradient is a non-scalar tensor, such a use case wasn't officially supported by DeepSpeed APIs. + +Below is the pseudo-code for the two models running on different actors. Since they run in different processes, we pass gradients via Ray actor communication. As seen here, the gradient of the vision embedding is a non-scalar tensor. With this update, we can now simply call `self.vision_output.backward` while utilizing other DeepSpeed features, including ZeRO and highly efficient sequence parallelism (DeepSpeed-Ulysses). + +```python +# Runs on LLM actors +def text_backward_step(self): + # ... + self.loss.backward() + return self.vision_embeddings.grad.detach().clone() + +# Runs on Vision actors +def vision_backward_step(self, vision_embedding_grad): + self.vision_output.backward(gradient=vision_embedding_grad) +``` + +Check out the [repository](https://github.com/ray-project/multimodal-training) for the complete training pipeline. + + +## 2\. Low-precision master params, grads, and optimizer states + +DeepSpeed supports mixed precision, which computes in bfloat16 or float16 while its optimizer maintains **FP32 master parameters, gradients, and optimizer states**. + +On the other hand, PyTorch now offers `torch.autocast`, a different approach for mixed precision that casts data types for precision-sensitive operators on the fly. As this often requires less peak memory, many recent training pipelines use this approach. + +DeepSpeed supports `torch.autocast` via configuration (see the [API documentation](https://deepspeed.readthedocs.io/en/rtd-staging/training.html#pytorch-automatic-mixed-precision-amp)). While it is technically safer to keep FP32 model states (master parameters/gradients and optimizer states) even with `torch.autocast`, there are many cases where training converges stably without them. Previously, the lack of an option to bypass creating FP32 states limited the trainablity of large models with constrained hardware resources. + +To reduce memory usage in such cases, DeepSpeed now allows users to avoid creating FP32 states entirely. + +### Enabling pure BF16/FP16 model states + +For BF16 training, you can use the following settings under `bf16`: + + * `bf16_master_weights_and_grads`: Keep master parameters and gradients in bf16. + * `bf16_optimizer_states`: Keep optimizer states (e.g., Adam moments) in bf16. + +These configurations are compatible with ZeRO stages 1, 2, and 3. Note that there is also a supported mixed configuration where `bf16_master_weights_and_grads == true` and `bf16_optimizer_states == false`, but **only when using CPU offload**. + +We offer similar support for FP16 training. You can use this setting under `fp16`: + + * `fp16_master_weights_and_gradients`: Keep master parameters and gradients in fp16. + +We actually offered this option in previous versions, but it was undocumented and worked only for ZeRO 1 and 2. We now officially support it, and it works for all ZeRO stages. We intentionally excluded `fp16_optimizer_states` as it is generally impractical due to convergence instability. + +A notable improvement is that we can combine these settings with `torch.autocast` support (via the [`torch_autocast` section](https://www.google.com/search?q=%5Bhttps://deepspeed.readthedocs.io/en/rtd-staging/training.html%23pytorch-automatic-mixed-precision-amp%5D\(https://deepspeed.readthedocs.io/en/rtd-staging/training.html%23pytorch-automatic-mixed-precision-amp\))). This combination drastically improves both memory efficiency and convergence. + +### Example: Pure bf16 config with low-precision master state + +Below is a simplified DeepSpeed config that keeps bf16 master weights, grads, and optimizer states, and uses `torch.autocast`: + +```json +{ +... + "zero_optimization": { + "stage": 3, + ... + }, + "bf16": { + "enabled": true, + "bf16_master_weights_and_grads": true, + "bf16_optimizer_states": true + }, + "torch_autocast": { + "enabled": true, + "dtype": "bfloat16" + } +} +``` + +Our [example script](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/bf16_master_weight) demonstrates the significant memory savings: + +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s | +| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s | + +To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset: + +
+ + + +*Loss curve comparison* + +
+ +| Configuration | Final Loss | Mean Loss | +|---------------|------------|-----------| +| Baseline (fp32 master) | 3.09 | 2.78 | +| BF16 Low-Precision | 3.12 | 2.90 | + +Please check out our [example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/bf16_master_weight) for more details. + +## Closing thoughts + +These core API improvements are incremental but important steps toward making DeepSpeed: + + * **More PyTorch-native** – Training loops can increasingly look like standard PyTorch code. + * **More memory-efficient** – Especially when combined with bf16/fp16 and ZeRO on large models. + * **Easier to compose** – Enabling multi-model and custom-gradient workflows without relying on DeepSpeed internal APIs. + +We're excited to see how you use these APIs in your own training setups, and we welcome feedback and issues on GitHub as you try them out. + +## Related Tests + +For more usage examples, see the unit tests in the repository: + +- [PyTorch-compatible backward API](https://github.com/deepspeedai/DeepSpeed/tree/master/tests/unit/v1/zero/test_zero_user_backward.py) +- [Low-precision master params/grads/optimizer states](https://github.com/deepspeedai/DeepSpeed/tree/master/tests/unit/v1/half_precision/test_bf16.py) +- [Combination with torch.autocast](https://github.com/deepspeedai/DeepSpeed/tree/master/tests/unit/v1/half_precision/test_with_autocast.py) diff --git a/blogs/core_api_update/assets/loss_comparison.png b/blogs/core_api_update/assets/loss_comparison.png new file mode 100644 index 000000000000..8ae2e82d615a Binary files /dev/null and b/blogs/core_api_update/assets/loss_comparison.png differ diff --git a/blogs/deepcompile/README.md b/blogs/deepcompile/README.md new file mode 100644 index 000000000000..7fca2cba1fa1 --- /dev/null +++ b/blogs/deepcompile/README.md @@ -0,0 +1,174 @@ +
+ +# DeepCompile: Unlocking Compiler Optimization for Distributed Training + +
+ +# Introduction + +
+ + + +
+ +Distributed training has become essential for scaling today’s massive deep learning models. While deep learning compilers like PyTorch compiler dramatically improved single-GPU training performance through optimizations like kernel fusion and operator scheduling, they fall short when it comes to distributed workloads. +Existing distributed training frameworks such as DeepSpeed and FSDP have made large-scale model training feasible through advanced parallelization strategies. While powerful, their optimizations are implemented at the PyTorch framework level, which limits the ability to apply compiler-style techniques like dependency analysis or operator scheduling. + +DeepCompile addresses this gap by enabling compiler-level optimizations for distributed training. It takes a standard single-GPU model implementation and transforms it into an optimized multi-GPU training graph without requiring changes to the model code. Unlike existing approaches, DeepCompile automatically applies parameter sharding, communication scheduling, and memory-aware execution at the compiler IR level, enabling global analysis and optimization that are difficult to express in traditional frameworks. Furthermore, during training, DeepCompile employs profile-guided optimization techniques to dynamically tune these parallelization strategies and improve training performance. + +Our evaluation demonstrates that DeepCompile improves training performance over ZeRO-3 baselines, achieving up to 1.5x speedup when sufficient GPU resources are available, and up to 7x speedup in GPU-constrained settings that require offloading. DeepCompile is available in DeepSpeed versions >= [0.16.6](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.16.6). As it is under active development, we recommend using the latest version of DeepSpeed or installing from source to access the most recent updates and bug fixes. + +# Design Overview + +DeepCompile extends the capabilities of deep learning compilers to support distributed training. It starts from a standard single-GPU model implementation, such as those available on the Hugging Face model hub, and automatically transforms it by inserting necessary distributed training operations such as parameter sharding and communication primitives. Users are not required to embed any distributed logic into the model code. + +The process begins by compiling the model into an intermediate representation (IR), which forms a computation graph. DeepCompile then applies a sequence of *optimization passes*, each responsible for a specific transformation of the computation graph or a targeted performance improvement, to incrementally introduce distributed behavior and optimize the graph. These include operations such as all-gather for sharded parameters or offloading of optimizer states, all while preserving the original computation semantics (Fig. 1). + +
+ + + +*Figure 1: Workflow of compilation and optimization with DeepCompile.* + +
+ +At its core, DeepCompile builds on two key capabilities: + +- **Automatic parallelization**: DeepCompile allows optimization passes to rewrite the single-GPU computation graph into a distributed multi-GPU version, incorporating strategies such as ZeRO, FSDP, and more. This eliminates the need for manual implementation of distributed training logic, drastically reducing engineering effort. +- **Profile-guided performance tuning**: At runtime, DeepCompile collects profiling data such as operator-level memory usage and execution latency. It uses this information to dynamically schedule computation and communication operators. This enables effects such as an improved overlap between communication and computation, and an avoidance of memory bottlenecks. Fine-grained tuning through these optimization passes often leads to better performance than even manually engineered implementations. + +Figure 2 illustrates the optimization cycle employed by DeepCompile. After the initial computation graph is generated by the compiler, DeepCompile profiles its behavior by measuring operator execution time, communication overhead, and memory usage throughout the forward and backward passes. + +
+ + + +*Figure 2. Optimization cycle.* + +
+ +Based on the collected profiling data, DeepCompile applies a sequence of optimization passes. These passes modify the computation graph by inserting, removing, or reordering operators to improve overall efficiency. The modified graph is then re-profiled, and this cycle of profiling and optimization is repeated. + +Once a stable set of optimizations has been applied, the graph is deployed for the remaining training iterations. During execution, memory usage and other runtime characteristics may change. In such cases, DeepCompile can resume the profiling and optimization cycle according to the predefined schedule of passes, allowing the graph to adapt and maintain high performance. + +# Optimizations + +DeepCompile is designed as a general compiler framework for applying and optimizing a wide range of parallelization strategies. In the following, we describe several optimizations that have been implemented as optimization passes within DeepCompile. + +## ZeRO3 + +As an initial step, we have used DeepCompile to implement and enhance ZeRO-3-style optimizations at the compiler level. ZeRO-3 partitions model parameters, gradients, and optimizer states across devices, reducing memory usage and enabling large-scale training. + +In conventional ZeRO-3 implementations, operations such as all-gather, reduce-scatter, and buffer release are typically inserted using Python hooks at runtime. DeepCompile replaces this approach by injecting these operations directly into the computation graph during compilation. This allows the compiler to determine their placement precisely, guided by both the static structure of the graph and runtime profiling information. + +One of the key optimizations is **proactive prefetching**, which launches all-gather operations earlier in the computation based on memory usage profiling. This reordering increases the overlap between communication and computation thereby improving throughput, while avoiding OOMs. In addition, small communication operations are often fused to reduce launch latency and improve efficiency. + +Another optimization is **selective unsharding**, which keeps certain parameters in an unsharded form during the forward and backward passes when memory conditions permit. This reduces the frequency of all-gather operations and avoids redundant communication, particularly in scenarios where gradient accumulation is enabled. + +## Offloading + +DeepCompile also supports **adaptive offloading**, which offloads optimizer states to reduce GPU memory pressure. Unlike approaches that offload all the optimizer states, adaptive offloading identifies only the portions that exceed the memory limit—such as momentum and variance used by the Adam optimizer—and schedules data transfers to overlap with computation. This selective and asynchronous strategy minimizes overhead and enables efficient training even in memory-constrained environments. + +## ZeRO1 + +ZeRO-1 differs from ZeRO-3 in that it shards only the optimizer states across devices, while keeping parameters and gradients fully replicated. This approach reduces memory usage with minimal changes to computation flow, making it a lightweight alternative for certain training scenarios. +DeepCompile implements ZeRO-1-style optimization by inserting reduce-scatter operations directly into the computation graph. By avoiding Python-level hooks, this graph-level integration reduces overhead and improves execution efficiency. + +# Performance Improvements + +## ZeRO-3 + +We evaluated DeepCompile on Llama-3-70B and Mixtral 8x7B using parameter sharding on top of Hugging Face model implementations. +Figure 3 shows training throughput (TFLOPs/GPU) across different gradient accumulation steps, using 32 H100 GPUs with a sequence length of 1024. +We compare DeepCompile against two DeepSpeed ZeRO-3 baselines: (i) an eager-mode version without compiler support (labelled ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (labelled ZeRO3+Compile). For DeepCompile, we enabled both proactive prefetching and selective unsharding to demonstrate the combined effect of these optimization passes. + +
+ +*Figure 3. Achieved throughputs for ZeRO3 training of Llama-3 70B and Mixtral 8x7B models.* + +
+Across both models, DeepCompile consistently delivers higher throughput. The benefit becomes more pronounced at higher accumulation steps, where the reduced frequency of parameter updates makes selective unsharding more effective. DeepCompile with proactive prefetching and selective unsharding achieves up to 1.28× speedup over ZeRO-3 on Llama-3-70B and 1.54× on Mixtral 8x7B. + +Meanwhile, enabling the PyTorch compiler with ZeRO-3, i.e., ZeRO3+Compile introduces minor overheads in some settings. This is because ZeRO-3 includes many conditional branches for runtime features such as prefetching. When the compiler encounters branches that cannot be statically resolved, it splits the computation into multiple graph segments. These fragmented segments can reduce optimization opportunities and introduce additional overheads during execution. + +## Offloading + +Training models as large as Llama-3 70B with ZeRO-3 typically requires 32 GPUs with 80GB of memory. +DeepSpeed addresses this challenge by offering offloading capabilities, which transfer optimizer states and optionally model parameters to CPU memory to reduce GPU memory usage. DeepCompile also supports offloading through a dedicated optimization pass, but with a few key differences in design. + +Unlike the traditional approach of offloading both optimizer computation and memory, DeepCompile offloads only optimizer memory (e.g., momentum, variance, and master weights of Adam optimizer) while the optimizer computation remains on GPU. DeepCompile profiles memory usage during both forward and backward passes to identify when offloading is necessary, and transfers only the required data. This fine-grained approach avoids unnecessary overhead and helps maintain high computational throughput. +Furthermore, DeepCompile overlaps data transfers with computation whenever possible, dynamically adjusting the timing based on observed memory usage patterns. This asynchronous behavior is a crucial aspect of DeepCompile’s offloading strategy, allowing it to reduce GPU memory pressure without stalling execution. + +We evaluated DeepCompile's offloading using Llama-3 70B on 16xH100-80GB (half the required GPU counts) and present the results in Figure 4. + +
+ + + +*Figure 4. Achieved throughput of optimizer offloading for Llama-3 70B on 16x80GB GPUs* + +
+ +We compare against two ZeRO-3 offloading baselines: (i) an eager-mode version without compiler support (ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO3+Compile). As shown by the results, DeepCompile significantly improves offloading efficiency and provides up to 7× speedup over ZeRO3+Eager. In contrast, we see that ZeRO3+Compile achieves similar performance as ZeRO3+Eager. + + +## ZeRO-1 + +We also evaluated DeepCompile with ZeRO-1 using the Llama-3-8B model. We compare DeepCompile against two ZeRO-1 baselines: (i) an eager-mode version without compiler support (ZeRO1+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO1+Compile). In our experiment with 8 GPUs and a batch size of 2, DeepCompile achieved consistent throughput improvements across different sequence lengths, as shown in Figure 5. + +
+ + + +*Figure 5. Achieved throughput of ZeRO-1 training of Llama-3 8B* + +
+ +The most significant speedup was observed with batch size 1 and sequence length 512, where DeepCompile outperformed ZeRO1+Eager by up to 1.9×, and ZeRO1+Compile by up to 2.5×. + +While compiler-based approaches can be effective for large batch sizes and long sequences by replacing suboptimal operations with more efficient kernels, they may also introduce overheads in ZeRO-1-style training in the form of *graph breaks* around the communication operations. These overheads become more pronounced with smaller batch sizes and sequence lengths, thus hurting performance compared to the non-compiled execution. In contrast, DeepCompile inserts communication operators directly into the computation graph during compilation, avoiding graph fragmentation and minimizing associated overhead. This makes DeepCompile more robust to small-scale workloads, while still benefiting from compiler-level optimizations. + +## Additional Results and Analysis + +Please refer to our [arXiv paper](https://arxiv.org/abs/2504.09983) for additional results, such as detailed comparisons across different batch sizes, sequence lengths, and memory usage. + +# Looking Ahead + +DeepCompile brings the power of compiler-based optimizations to distributed deep learning. By transforming computation graphs and applying profile-guided optimization passes, it enables more efficient training without requiring changes to model code. + +This release is just the beginning. We’re actively working on expanding the set of optimization passes and improving integration with a broader range of distributed training strategies. Future directions include automated parallelization (sequence/tensor parallelisms), smarter memory management, and dynamic adaptation to runtime behavior. + +We invite the community to try DeepCompile, explore its capabilities, and contribute to its evolution. Let’s build the next generation of scalable deep learning together. + +# Acknowledgments + +We would like to thank everyone who supported this project. + +This project would not have been possible without the PyTorch Compiler—a platform that is not only powerful and flexible, but also a pleasure to work with. We are especially grateful to the developers and researchers behind PyTorch Compiler for making such an excellent foundation available to the community. + +# Contributors + +This project is the result of a close collaboration between Microsoft and the University of Virginia. The contributors are: Masahiro Tanaka, Du Li, and Umesh Chand, Olatunji Ruwase (Microsoft); and Ali Zafar and Haiying Shen (University of Virginia). + +# Appendix + +## Examples and Benchmarks + +Our DeepSpeedExamples repository provides [example code](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/deepcompile) to enable DeepCompile. + +## Optimization Passes + +The following optimization passes are currently available in DeepCompile: + +- All-gather & reduce-scatter insertion (ZeRO3) +- Proactive prefetching (ZeRO3) +- Selective unsharding (ZeRO3) +- Reduce-scatter insertion (ZeRO1) +- Adaptive offloading + +We used the following combinations of passes in the experiments presented above: + +- Improved communication scheduling for ZeRO-3: All-gather & reduce-scatter → Proactive prefetching → Selective unsharding +- Offloading optimizer states for ZeRO3: Adding all-gather & reduce-scatter → Adaptive offloading +- Reduced overhead and improved overlap for ZeRO-1: Adding reduce-scatter diff --git a/blogs/deepcompile/media/opt_loop.png b/blogs/deepcompile/media/opt_loop.png new file mode 100644 index 000000000000..a3a4ca33a684 Binary files /dev/null and b/blogs/deepcompile/media/opt_loop.png differ diff --git a/blogs/deepcompile/media/perf_offload.png b/blogs/deepcompile/media/perf_offload.png new file mode 100644 index 000000000000..1506f20bc133 Binary files /dev/null and b/blogs/deepcompile/media/perf_offload.png differ diff --git a/blogs/deepcompile/media/perf_summary.png b/blogs/deepcompile/media/perf_summary.png new file mode 100644 index 000000000000..798ff54acb7d Binary files /dev/null and b/blogs/deepcompile/media/perf_summary.png differ diff --git a/blogs/deepcompile/media/perf_zero1.png b/blogs/deepcompile/media/perf_zero1.png new file mode 100644 index 000000000000..a7256919f9a5 Binary files /dev/null and b/blogs/deepcompile/media/perf_zero1.png differ diff --git a/blogs/deepcompile/media/perf_zero3.png b/blogs/deepcompile/media/perf_zero3.png new file mode 100644 index 000000000000..a93e929312a3 Binary files /dev/null and b/blogs/deepcompile/media/perf_zero3.png differ diff --git a/blogs/deepcompile/media/workflow.png b/blogs/deepcompile/media/workflow.png new file mode 100644 index 000000000000..72a358408099 Binary files /dev/null and b/blogs/deepcompile/media/workflow.png differ diff --git a/blogs/deepnvme/06-2025/README.md b/blogs/deepnvme/06-2025/README.md new file mode 100644 index 000000000000..0d225d5de6b8 --- /dev/null +++ b/blogs/deepnvme/06-2025/README.md @@ -0,0 +1,137 @@ +
+ +# DeepNVMe: Affordable I/O scaling for Deep Learning Applications. + +
+ +# Introduction +We introduced [DeepNVMe](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepnvme/08-2024/README.md) in summer 2024 as a suite of optimizations for tackling I/O bottlenecks in Deep Learning (DL). DeepNVMe delivers significant speedups for I/O bound DL workloads by leveraging storage innovations including local NVMe SSDs, NVIDIA Magnum IOTM GPUDirect® Storage (GDS), and Linux Asynchronous I/O (AIO). +In this update, we are delighted to announce DeepNVMe improvements on multiple fronts: (i) expanding application coverage to FastPersist model checkpointing and SGLang inference, (ii) I/O performance scaling by upgrading from PCIe Gen4 to Gen5 NVMe SSDs, and (iii) expanding usability to CPU-only environments, offset-based I/O operations, and tensor data type casting. The results reported in this blog are available in DeepSpeed versions >= [0.17.1](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.17.1). + +# Evaluation environments +Our experiments are conducted on Azure [ND-H200-v5](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-h200-v5-series?tabs=sizebasic) VM. The key software configurations are summarized in the following table. + +|Software | Version +|---|--| +|Ubuntu | 24.04.2| +|PyTorch | 2.6.0| +|CUDA | 12.6 | +SGLang | 0.4.4.post4 | + +# Addressing I/O Bottlenecks of Deep Learning +We used DeepNVMe to develop FastPersist and ZeRO-Inference to target I/O bottlenecks in DL training and inference respectively. Our experiments are conducted using a single VM, in which we combine the available NVMe SSDs into a single RAID-0 (i.e., disk striping) volume to leverage aggregate read and write bandwidths. Since DeepNVMe can offload tensors using CPU bounce buffers (a.k.a., AIO), or NVIDIA GPUDirect Storage (a.k.a., GDS), we report results for both modes. + +## FastPersist: Faster Model Checkpoint Creation +Although saving model checkpoints to persistent storage is critical in model training, it is also a major bottleneck due to the inefficiencies of existing approaches. We developed [FastPersist](https://arxiv.org/abs/2406.13768) to address the performance challenges of checkpointing. FastPersist makes checkpointing overheads negligible during training through three key techniques: (i) DeepNVMe, (ii) data parallelism, and (iii) overlapping I/O and computation. + +Our goal here is to demonstrate the impact of DeepNVMe in FastPersist using single-process micro-benchmarks (available [here](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint)) which serialize a model checkpoint state from HBM to local NVMe. We use the popular PyTorch `torch.save()` as the baseline in our experiments, and integrate FastPersist into `torch.save()` to simplify adoption and performance comparisons. + +### Faster Saving of PyTorch Models to local NVMe Storage +We measure the throughput of serializing Phi-3-Mini checkpoint state from HBM to local NVMe storage. The results are summarized in the Figure below. We observe significantly faster checkpointing with FastPersist compared to the baseline. We see speedups of over 20X in the 8xGen5 NVMe settings. We also observe FastPersist scaling with increased NVMe bandwidth of 8xGen5 compared with 4xGen5. + + +
+ FastPersist provides significantly faster model checkpointing to local NVMe. +
+ +## ZeRO-Inference: Democratizing Generative AI +[ZeRO-Inference](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) is a technology that democratizes access to state-of-the-art models by reducing the GPU costs of model inference. ZeRO-Inference enables inference computations of massive models (hundreds-of-billions of parameters) on as few as one GPU by offloading the model weights to DRAM and NVMe storage. ZeRO-Inference is designed for offline or throughput-oriented inference scenarios. In this blog, we share two updates on ZeRO-Inference. First, we have integrated ZeRO-Inference into SGLang, a state-of-the-art model serving framework. Second, we observed ZeRO-Inference performance scales with the faster NVMe SSDs in the latest Azure SKUs. + +### Democratizing SGLang through ZeRO-Inference integration +[SGLang](https://docs.sglang.ai/) is a state-of-the-art serving framework for large language models (LLMs) and vision language models (VLMs). Our integration of ZeRO-Inference into SGLang makes SGLang available to budget-constrained users, and offers a cost-reduction option to existing SGLang users. We used SGLang's [offline benchmarking tool](https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_offline_throughput.py) to measure the generation throughput of LLAMA3-70B on a single H200 with NVMe offloading (LLAMA3-70B cannot fit in the 141GB VRAM without offloading). The experiment is configured with prompt length of 512, generation length of 32, and batch size of 128. We summarize the results in the figure below for both AIO and GDS offloading. + + +
+ ZeRO-Inference improves SGLang inference with NVMe offloading to reduce hardware costs. +
+ + +### Scaling HF Transformer Generation with Faster NVMe SSDs +ZeRO-Inference enhances HF Transformer inference with efficient model offloading to DRAM or NVMe. We previously [evaluated](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepnvme/08-2024/README.md#high-performance-offloading-via-nvme-scaling) LLAMA-3-70B generation performance with NVMe offloading on a single GPU and four Gen4 NVMes in an Azure [NC_A100_v4](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizebasic) VM. We measured the generation speed for a prompt of 512 tokens, output of 32 tokens, and batch size 96. Since NVMe bandwidth was the main bottleneck, we repeat the experiments on Azure ND-H200-v5 offering Gen5 NVMes. The results summarized in the Figure below show that ZeRO-Inference uses the increased NVMe bandwidths to improve generation speeds. For example, with GDS, generation speed improves from 7 tokens/sec with four Gen4 NVMes to 17 tokens/sec with four Gen5 NVMes, and further to 26 tokens/sec with eight Gen5 NVMes. We observe similar improvements without GDS. These results show that ZeRO-Inference performance can be improved in cost-effective manner by increasing NVMe bandwidths. + + +
+ ZeRO-Inference leverages available NVMe bandwidth to scale LLAMA-3-70B generation. +
+ + +# I/O performance scaling +We used our `ds_io` benchmarking tool to demonstrate DeepNVMe proportionally scaling I/O performance with available NVMe bandwidths. This empowers users to accelerate I/O bound DL applications at modest cost using more or faster NVMe SSDs. In our experiments, we measure the achieved read and write bandwidths of 1GB data transfers between HBM and NVMes. We evaluate scaling up NVMes from PCIe Gen4 to Gen5, and scaling out from 4 to 8 SSDs. The SSDs are combined into a single RAID-0 (disk striping) volume. We summarize the results in the Figure below which show that DeepNVMe scales I/O performance on both dimensions. Scaling up from 4xGen4 SSDs to 4xGen5 SSDs improves reads from 10GB/sec to 27GB/sec, and writes from 5GB/sec to 11GB/sec. Scaling out from 4xGen5 to 8xGen5 further improves reads to 48GB/sec, and writes to 26GB/sec. + + +
+ Microbenchmark shows DeepNVMe scales I/O performance with available NVMe bandwidth +
+ + +# Broadening usability +We have increased the usage scenarios of DeepNVMe by removing restrictions regarding hardware environments and I/O operations, as explained below. + +## CPU-Only environments +Although GPUs (and similar accelerators) dominate DL, CPUs are still used in important machine learning (ML) workloads such as recommendation systems. However, DeepNVMe was previously unusable in CPU-only environments. This was because DeepNVMe relied on `torch.pin_memory()` for page-locked CPU tensors, whereas `torch.pin_memory()` does not work in the CPU versions of `torch` as illustrated below. + +```bash +>>> import torch +>>> torch.__version__ +'2.6.0+cpu' +>>> x = torch.empty(1024).pin_memory() +Traceback (most recent call last): + File "", line 1, in +RuntimeError: Cannot access accelerator device when none is available. +>>> +``` + +We have made DeepNVMe usable in CPU environments by adding mechanisms for allocating (`new_cpu_locked_tensor()`) and releasing (`free_cpu_locked_tensor()`) page-locked CPU tensors. The snippet below illustrates allocating a pinned CPU tensor (`x`). + +```bash +>> import torch +>>> torch.__version__ +'2.6.0+cpu' +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle() +>>> x = h.new_cpu_locked_tensor(1024, torch.Tensor()) +>>> x.shape +torch.Size([1024]) +>>> x.dtype +torch.float32 +``` + +## Offset-based I/O operations +Previously, DeepNVMe functionality was restricted to reading or writing the entire contents of a file. We have now improved DeepNVMe to read or write a user-specified portion of file content from a given offset. In particular, we have extended the existing read/write APIs to accept a user-specified `file offset` argument (with default value 0) such as below: + +```bash +>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> help(AsyncIOBuilder().load().aio_handle().pread) +Help on method pread in module async_io: + +pread(...) method of async_io.aio_handle instance + pread(self: async_io.aio_handle, buffer: torch.Tensor, filename: str, validate: bool, async: bool, file_offset: int = 0) -> int +``` + + +## Tensor data type casting +While developing FastPersist, we needed to manipulate model tensors, typically of floating point data types, in byte format for both performance and convenience of I/O operations. However, we could not find a zero-copy mechanism for casting tensors from arbitrary data types to a byte data type (i.e., torch.uint8), so we decided to create one. This functionality is available via the `UtilsBuilder` op as demonstrated in the example below. In the example, we cast a `torch.bfloat16` tensor into `torch.uint8`. Note that due to the zero-copy nature of the functionality, `bf16_tensor` and `byte_tensor` are aliases. + +``` +>>> import torch +>>> from deepspeed.ops.op_builder import UtilsBuilder +>>> util_ops = UtilsBuilder().load() +>>> bf16_tensor = torch.zeros(1024, dtype=torch.bfloat16, device='cuda') +>>> bf16_tensor +tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0', dtype=torch.bfloat16) +>>> byte_tensor = util_ops.cast_to_byte_tensor(bf16_tensor) +>>> byte_tensor +tensor([0, 0, 0, ..., 0, 0, 0], device='cuda:0', dtype=torch.uint8) +>>> bf16_tensor += 1.0 +>>> bf16_tensor +tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16) +>>> byte_tensor +tensor([128, 63, 128, ..., 63, 128, 63], device='cuda:0', + dtype=torch.uint8) +``` + +# Summary +This blog post has provided updates on our continued development of DeepNVMe, an I/O optimization technology for accelerating DL applications. We have announced DeepNVMe improvements on multiple aspects, including application coverage, I/O performance scaling, and usability. + +# Acknowledgements +This blog describes work done by Joe Mayer, Logan Adams, and Olatunji Ruwase of the DeepSpeed team at Microsoft. diff --git a/blogs/deepnvme/06-2025/media/dnvme_file_access.png b/blogs/deepnvme/06-2025/media/dnvme_file_access.png new file mode 100644 index 000000000000..83cde84d6559 Binary files /dev/null and b/blogs/deepnvme/06-2025/media/dnvme_file_access.png differ diff --git a/blogs/deepnvme/06-2025/media/dnvme_scaling.png b/blogs/deepnvme/06-2025/media/dnvme_scaling.png new file mode 100644 index 000000000000..a921a8a323cd Binary files /dev/null and b/blogs/deepnvme/06-2025/media/dnvme_scaling.png differ diff --git a/blogs/deepnvme/06-2025/media/fastpersist_phi3_mini.png b/blogs/deepnvme/06-2025/media/fastpersist_phi3_mini.png new file mode 100644 index 000000000000..81a5a925f621 Binary files /dev/null and b/blogs/deepnvme/06-2025/media/fastpersist_phi3_mini.png differ diff --git a/blogs/deepnvme/06-2025/media/fastpersist_tensor.png b/blogs/deepnvme/06-2025/media/fastpersist_tensor.png new file mode 100644 index 000000000000..ea2f9f427c3c Binary files /dev/null and b/blogs/deepnvme/06-2025/media/fastpersist_tensor.png differ diff --git a/blogs/deepnvme/06-2025/media/hf_zinf_llama_70b.png b/blogs/deepnvme/06-2025/media/hf_zinf_llama_70b.png new file mode 100644 index 000000000000..f6de3dfecbe6 Binary files /dev/null and b/blogs/deepnvme/06-2025/media/hf_zinf_llama_70b.png differ diff --git a/blogs/deepnvme/06-2025/media/sg_zinf_llama_70b.png b/blogs/deepnvme/06-2025/media/sg_zinf_llama_70b.png new file mode 100644 index 000000000000..53781992a8c7 Binary files /dev/null and b/blogs/deepnvme/06-2025/media/sg_zinf_llama_70b.png differ diff --git a/blogs/deepnvme/08-2024/README.md b/blogs/deepnvme/08-2024/README.md new file mode 100644 index 000000000000..29bfdd842ee5 --- /dev/null +++ b/blogs/deepnvme/08-2024/README.md @@ -0,0 +1,88 @@ +
+ +# DeepNVMe: Improving DL Applications through I/O Optimizations + +
+ +# Introduction + +Deep Learning (DL) continues to drive unprecedented advancements across important +Artificial Intelligence domains including language, speech, video, and multimodal applications. +A key factor to these advancements is dramatic scalability on multiple dimensions including model size, +sequence length, and hardware parallelism. From a system perspective, DL scalability puts significant +pressure on essential subsystems including computation, memory, communication, and storage. However, +existing DL optimization efforts have mostly neglected the storage subsystem, making I/O operations such +as data loading, model checkpointing, and offloading the main bottlenecks of large-scale DL. To address +this problem, DeepSpeed has created a suite of I/O optimizations collectively called DeepNVMe. + +DeepNVMe improves the performance and efficiency of I/O-bound DL applications by accelerating I/O operations +and reducing hardware requirements. It achieves this by leveraging storage innovations such as Non-Volatile +Memory Express (NVMe) Solid State Drives (SSDs) and NVIDIA Magnum IOTM GPUDirect® Storage (GDS). In this +blog we show the benefits of DeepNVMe using microbenchmarks and an inference application. In experiments +conducted on an Azure NC96ads\_A100\_v4 VM, we observed that DeepNVMe saturates available NVMe bandwidth for +data transfers with GPU or CPU memory, achieving up to 10GB/sec reads and 5 GB/secs writes. + +# Background +High-performance access to persistent storage is a common challenge in many computing domains, including DL. Thus, a significant number of hardware and software solutions have been proposed. DeepNVMe builds on three such solutions: (1) NVMe SSDs, (2) NVIDIA GDS, and (3) Linux Asynchronous I/O (libaio). We will briefly describe each of these technologies. + +NVMe SSDs are Flash-based storage devices that are replacing much slower hard disk drives (HDD) as primary persistent storage in modern servers. For example, an Azure NC96ads\_A100\_v4 VM is equipped with four NVMe SSDs which are individually capable of 3.25 GB/sec reads and can be combined in a RAID-0 configuration for a theoretical aggregate read bandwidth of 13 GB/sec. NVIDIA GDS enables direct transfers between NVMe and GPU memory thus avoiding the inefficiencies of the traditional approach of using intermediate CPU memory (bounce buffer). NVIDIA GDS is generally available in CUDA versions 11.4 and above. Finally, libaio is an asynchronous I/O stack introduced in Linux to better extract raw performance of fast storage devices like NVMe SSDs compared to the traditional I/O stack. + +# DeepNVMe: an Optimization Module for Deep Learning I/O + +DeepNVMe is a Python module that we developed with two key design principles. First, it leverages the above discussed storage technologies to implement powerful optimizations such as non-blocking I/O operations, bulk submission of I/O operations, parallelization of an individual I/O operation, and a lightweight runtime. Second, it exposes these I/O optimizations through a simple POSIX-like interface to foster easy integration into DL applications while avoiding the complexities of the underlying technologies. + +# Evaluation + +Our experiments are conducted on an Azure NC96ads\_A100\_v4 VM with setup details summarized in Table 1. For multi-device experiments, the SSDs are combined in a RAID-0 configuration. + + + +
+Table 1: Experimental setup details +
+ +## Microbenchmark Performance + +We used three benchmarking tools for our evaluations. The first is fio, the popular I/O benchmarking tool written in C. The second is gdsio from NVIDIA for benchmarking GDS performance. The third is ds\_io, a Python tool that we created for easy integration with DeepNVMe and to be more representative of DL applications which are commonly Python-based. + +## High-Performance I/O with CPU Buffers via NVMe Scaling + +Our first set of microbenchmark evaluations used fio and ds\_io to measure the performance of transferring 1GB data between NVMe and CPU memory. We configure fio to use the libaio backend for these experiments. The results are summarized in Figure 1, from which we make two observations. First, DeepNVMe demonstrates high performance as it roughly matches fio, despite being more representative of DL applications. Second, DeepNVMe scales I/O performance almost linearly with available NVMe bandwidth, achieving rates of 10GB/sec reads and 5GB/sec writes. + + + +
+Figure 1: Using DeepNVMe to scale data transfers between NVMe and CPU buffer +
+ +## High-Performance I/O with GPU Buffers via NVMe Scaling + +Our second set of microbenchmark evaluations used gdsio and ds\_io to measure the performance of 1GB data transfer between NVMe and GPU memory. For this experiment, we configure ds\_io to use both the traditional bounce buffer approach and the more efficient GDS approach. The results are summarized in Figure 2, from which we make three observations. First, we see that GDS improves performance in DeepNVMe compared to the traditional bounce buffer approach, with up to 37% speedup. Second, DeepNVMe demonstrates high performance by matching (and sometimes surpassing) gdsio despite being more representative of DL applications. Third, we see that DeepNVMe, with and without GDS, scales I/O performance with available NVMe bandwidth. With GDS, DeepNVMe achieves a maximum of 9.6GB/sec reads and 5GB/sec writes, and without GDS achieves 7GB/sec reads and 4GB/sec writes. + + + +
+Figure 2: Using DeepNVMe to scale data transfers between NVMe and GPU memory +
+ +## ZeRO-Inference: Generative AI Performance + +ZeRO-Inference is an AI democratization technology that reduces the hardware cost of inferencing massive models by using DeepNVMe to offload model weights to CPU or NVMe memory. ZeRO-Inference is well suited for throughput-oriented applications, such as offline inferencing, and for scenarios with limited hardware budget. We use token generation workload to evaluate DeepNVMe performance for NVMe offloading. + +## High-Performance Offloading via NVMe Scaling + +We measure the generation throughput of inferencing a LLAMA3-70B model on a single NVIDIA A100-80GB with a prompt length of 512, generation length of 32, and batch size of 96. We scale the number of NVMe SSDs from 1 to 4 and present the results for ZeRO-Inference with and without GDS in Figure 3. We make two observations from these results. First, GDS consistently provides better performance compared to the bounce buffer approach, achieving 10-18% faster token generation. Second, DeepNVMe, with and without GDS, scales generation performance with available NVMe bandwidth. With four NVMe SSDs, DeepNVMe achieves generation throughput rates of 7 tokens per second with GDS and 6 tokens per second without GDS. Our profiling results suggest that DeepNVMe will continue to scale with more NVMe bandwidth, making it an economic option for boosting generative application performance. + + + +
+Figure 3: Using DeepNVMe to scale LLAMA3-70B token generation performance with NVMe offloading. +
+ +# Summary + +In this blog post, we introduced DeepNVMe, an I/O optimization technology created to tackle the emergence of I/O operations as key bottlenecks of Deep Learning scalability. DeepNVMe enables fast and efficient data transfers between persistent storage and DL application memory through optimizations built on popular storage technologies such as NVMe SSDs and NVIDIA GDS. We showed benefits of using DeepNVMe for LLAMA3-70B token generation on single A100-80GB GPU with NVMe offloading, for which it achieves up to 7 tokens per second in generation throughput on an Azure NC96ads\_A100\_v4 VM. DeepNVMe will be open-sourced and generally available in DeepSpeed versions >= [0.15.0](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.15.0). In future blogs, we will report DeepNVMe improvements for other I/O bound DL applications such as model checkpointing and data loading. + + +# Acknowlegements +This work is the result of a deep collaboration between Microsoft and NVIDIA. The contributors include Joe Mayer, Martin Cai, and Olatunji Ruwase from Microsoft; Kiran Modukuri, Vahid Noormofidi, Sourab Gupta, and Sandeep Joshi from Nvidia. diff --git a/blogs/deepnvme/08-2024/chinese/README.md b/blogs/deepnvme/08-2024/chinese/README.md new file mode 100644 index 000000000000..9fa9d7150c42 --- /dev/null +++ b/blogs/deepnvme/08-2024/chinese/README.md @@ -0,0 +1,77 @@ +
+ +# DeepNVMe: 通过I/O优化提高深度学习应用性能 + +
+ +# 引言 + +深度学习(DL)在语言、语音、视频和多模态应用等重要人工智能领域不断推动着前所未有的进展。这些进展的关键因素是模型大小、序列长度和硬件并行性等多个维度上的显著可扩展性。从系统角度来看,深度学习的可扩展性给计算、内存、通信和存储等关键子系统带来了巨大的压力。然而,现有的深度学习优化工作大多忽略了存储子系统,使得数据加载、模型检查点和卸载等I/O操作成为大规模深度学习中的主要瓶颈。为了解决这个问题,DeepSpeed开发了一整套I/O优化技术,统称为DeepNVMe。 + +DeepNVMe通过加速I/O操作和减少硬件需求,提高了I/O受限的深度学习应用的性能和效率。它通过利用存储创新,如非易失性内存快速通道(NVMe)固态硬盘(SSD)和NVIDIA Magnum IOTM GPUDirect®存储(GDS)实现这一目标。在本文中,我们通过微基准测试和推理应用来展示DeepNVMe的优势。在对Azure NC96ads\_A100\_v4虚拟机进行的实验中,我们观察到DeepNVMe能充分利用可用的NVMe带宽进行GPU或CPU内存的数据传输,读取速度达到10GB/秒,写入速度达到5GB/秒。 + +# 背景 + +高性能访问持久存储是许多计算领域(包括深度学习)中的一个常见挑战。因此,已经提出了大量的硬件和软件解决方案。DeepNVMe基于三种解决方案: (1) NVMe SSDs,(2) NVIDIA GDS,(3) Linux异步I/O(libaio)。我们将简要介绍每项技术。 + +NVMe SSDs是基于闪存的存储设备,正在取代传统的硬盘驱动器(HDD),成为现代服务器的主要持久存储。例如,Azure NC96ads\_A100\_v4虚拟机配备了四个NVMe SSD,每个SSD可提供3.25GB/秒的读取速度,并且可以组合成RAID-0配置,理论上的总读取带宽为13GB/秒。NVIDIA GDS可以实现NVMe和GPU内存之间的直接数据传输,从而避免了传统使用中间CPU内存(缓冲区)方法的低效。NVIDIA GDS在CUDA 11.4及以上版本中可用。最后,libaio是Linux引入的异步I/O栈,它比传统的I/O栈更有效地提取NVMe SSD等高速存储设备的原始性能。 + +# DeepNVMe: 深度学习I/O优化模块 + +DeepNVMe是一个Python模块,我们开发时遵循了两个关键设计原则。首先,它利用上述存储技术,实现了强大的优化,如非阻塞I/O操作、批处理I/O操作提交、单个I/O操作的并行化以及轻量级运行时。其次,它通过一个简单的POSIX-like接口让用户使用I/O优化,便于深度学习应用集成,同时避免了底层技术的复杂性。 + +# 评估 + +我们的实验在Azure NC96ads\_A100\_v4虚拟机上进行,实验设置的详细信息见表1。对于多设备实验,SSD是以RAID-0配置组合使用的。 + + + +
+表1: 实验设置详细信息 +
+ +## 微基准性能测试 + +我们使用了三种基准测试工具进行评估。第一个是fio,这是一个用C语言编写的流行I/O基准测试工具。第二个是来自NVIDIA的gdsio,用于基准测试GDS性能。第三个是ds\_io,这是我们创建的Python工具,便于与DeepNVMe集成,并且更能代表常见的基于Python的深度学习应用。 + +## 通过NVMe扩展CPU缓冲区,从而提高I/O性能 + +我们的第一组微基准评估使用fio和ds\_io,测量1GB数据在NVMe和CPU内存之间的传输性能。我们配置fio使用libaio后端进行这些实验。结果总结在图1中,我们可以得出两个结论。首先,DeepNVMe表现出高性能,尽管它更能代表深度学习应用,但其性能与fio大致相当。其次,DeepNVMe的I/O性能几乎与可用的NVMe带宽成线性扩展,达到了10GB/秒的读取速度和5GB/秒的写入速度。 + + + +
+图1: 使用DeepNVMe扩展NVMe与CPU缓冲区之间的数据传输 +
+ +## 通过NVMe扩展GPU缓冲区,从而提高I/O性能 + +我们的第二组微基准评估使用gdsio和ds\_io,测量1GB数据在NVMe和GPU内存之间的传输性能。在此实验中,我们配置ds\_io同时使用传统的缓冲区方法和更高效的GDS方法。结果总结在图2中,我们可以得出三个结论。首先,我们看到GDS提高了DeepNVMe的性能,相比传统缓冲区方法,速度提高了最多37%。其次,DeepNVMe表现出高性能,尽管它更能代表深度学习应用,但其性能与gdsio相匹配(有时甚至超过)。第三,我们看到DeepNVMe,无论是否使用GDS,都能根据可用的NVMe带宽扩展I/O性能。使用GDS时,DeepNVMe的读取速度最高达到9.6GB/秒,写入速度为5GB/秒;不使用GDS时,读取速度为7GB/秒,写入速度为4GB/秒。 + + + +
+图2: 使用DeepNVMe扩展NVMe与GPU内存之间的数据传输 +
+ +## ZeRO-Inference: 生成式AI性能 + +ZeRO-Inference是一项AI普及技术,通过使用DeepNVMe将模型权重卸载(Offload)到CPU或NVMe内存,降低了推理大规模模型的硬件成本。ZeRO-Inference非常适合于面向吞吐量的应用,如离线推理,和硬件预算有限的场景。我们使用token生成工作负载来评估DeepNVMe在NVMe卸载下的性能。 + +## 通过NVMe扩展的高性能卸载(Offload) + +我们测量了在单个NVIDIA A100-80GB上推理LLAMA3-70B模型的生成吞吐量,使用512的提示长度、32的生成长度和96的批量大小。我们将NVMe SSD的数量从1扩展到4,并呈现了ZeRO-Inference在有GDS和没有GDS的情况下的结果,如图3所示。我们从这些结果中得出两个结论。首先,GDS始终提供比传统缓冲区方法更好的性能,token生成速度提高了10-18%。其次,DeepNVMe,无论是否使用GDS,都能根据可用的NVMe带宽扩展生成性能。在四个NVMe SSD的情况下,DeepNVMe的生成吞吐量分别为每秒7个token(使用GDS)和每秒6个token(不使用GDS)。我们的分析结果表明,DeepNVMe将在更多的NVMe带宽下继续扩展,是提升生成应用性能的经济选择。 + + + +
+图3: 使用DeepNVMe通过NVMe卸载(offload)扩展LLAMA3-70B的token生成性能 +
+ +# 总结 + +在本文中,我们介绍了DeepNVMe,一项为了解决I/O操作成为深度学习可扩展性关键瓶颈而创建的I/O优化技术。DeepNVMe通过基于流行存储技术(如NVMe SSD和NVIDIA GDS)的优化,实现了持久存储与深度学习应用内存之间的快速高效数据传输。我们展示了在Azure NC96ads\_A100\_v4虚拟机上,DeepNVMe通过NVMe卸载支持LLAMA3-70B的token生成,最高达到每秒7个token的生成吞吐量。DeepNVMe将在DeepSpeed版本>= 0.15.0中开源,并广泛发布。在未来的博客中,我们将报告DeepNVMe在其他I/O受限的深度学习应用中的改进,如模型检查点和数据加载。 + +# 致谢 + +这项工作是微软和NVIDIA之间深入合作的结果。贡献者包括微软的Joe Mayer、Martin Cai和Olatunji Ruwase;NVIDIA的Kiran Modukuri、Vahid Noormofidi、Sourab Gupta和Sandeep Joshi。 diff --git a/blogs/deepnvme/08-2024/japanese/README.md b/blogs/deepnvme/08-2024/japanese/README.md new file mode 100644 index 000000000000..26320d00ab94 --- /dev/null +++ b/blogs/deepnvme/08-2024/japanese/README.md @@ -0,0 +1,77 @@ +
+ +# DeepNVMe: I/O最適化による深層学習アプリケーションの高速化 + +
+ +# はじめに + +深層学習(Deep Learning)は、言語、音声、ビデオ、マルチモーダルアプリケーションなどの重要なAIの応用領域において、かつてない進歩を続けています。この進歩の鍵となる要因は、モデルサイズ、シーケンス長、ハードウェア並列性などの複数の次元での劇的なスケーラビリティです。システムの観点から見ると、深層学習のスケーラビリティは計算、メモリ、通信、ストレージなどの重要なサブシステムに大きな負荷をかけます。しかし、既存の取り組みは、ストレージサブシステムの最適化はほとんど扱われておらず、データロード、モデルチェックポイント、オフロードなどのI/O操作が大規模な深層学習の主要なボトルネックとなっています。この問題に対処するために、DeepSpeedは一連のI/O最適化機能を「DeepNVMe」と呼ばれる形で提供します。 + +DeepNVMeは、I/O操作の高速化とハードウェア要件の緩和によって、I/Oがボトルネックとなる深層学習アプリケーションのパフォーマンスと効率を向上させます。これを実現するために、Non-Volatile Memory Express(NVMe)やSSD、NVIDIA Magnum IO ``TM `` GPUDirect® Storage(GDS)などのストレージ技術を活用しています。このブログでは、マイクロベンチマークと推論アプリケーションの性能評価結果に基づいて、DeepNVMeの利点を示します。Azure NC96ads_A100_v4 VMで実施された実験では、DeepNVMeがGPUまたはCPUメモリへのデータ転送で利用可能なNVMe帯域幅を最大限に活用し、最大10GB/秒の読み取りと5GB/秒の書き込みを達成しました。 + +# 背景 + +永続ストレージへの高性能アクセスは、深層学習を含む多くのコンピューティングドメインで共通の課題です。これに対して、多くのハードウェアおよびソフトウェアソリューションが提案されています。DeepNVMeは、以下の3つのソリューションを基に構築されています。(1) NVMe SSD、(2) NVIDIA GDS、(3) Linux非同期I/O(libaio)。これらの技術について簡単に説明します。 + +NVMe SSDは、現代のサーバーで主要な永続ストレージとして、従来の遅いハードディスクドライブ(HDD)に取って代わるフラッシュベースのストレージデバイスです。たとえば、Azure NC96ads_A100_v4 VMには4つのNVMe SSDが装備されており、それぞれが3.25 GB/秒の読み取り速度を持ち、RAID-0構成で組み合わせると理論上の合計読み取り帯域幅は13 GB/秒となります。NVIDIA GDSは、NVMeとGPUメモリ間の直接転送を可能にすることで、中間のCPUメモリ(バウンスバッファ)を使用する従来のアプローチの非効率を回避します。NVIDIA GDSは、CUDAバージョン11.4以上で利用可能です。最後に、libaioは、従来のI/Oスタックと比較して、NVMe SSDのような高速ストレージデバイスの性能をより引き出すためにLinuxに導入された非同期I/Oスタックです。 + +# DeepNVMe: 深層学習のためのI/O最適化モジュール + +DeepNVMeは、以下の2つの主要な設計原則に基づいて開発されたPythonモジュールです。第一に、上記のストレージ技術を活用して、ノンブロッキングI/O操作、I/O操作の一括送信、個々のI/O操作の並列化、軽量なランタイムなどの最適化を実装しています。第二に、これらのI/O最適化をシンプルなPOSIXライクなインターフェースを通じて提供し、深層学習アプリケーションへの容易な統合を促進し、基盤となっている複雑な技術を直接扱うことなく、その性能を活用することを可能にします。 + +# 評価 + +実験は、Azure NC96ads_A100_v4 VMで実施されました。設定の詳細は表1の通りです。 + + + +
+表1: 実験設定の詳細 +
+ +## マイクロベンチマーク + +評価には3つのベンチマークツールを使用しました。一つ目は、C言語で書かれた一般的なI/Oベンチマークツールであるfioです。次に、GDSパフォーマンスのベンチマークを行うためのNVIDIAのgdsioです。最後に、DeepNVMeとの容易な統合のために我々た作成したds_ioです。ds_ioは、深層学習アプリケーションで代表的に使用されるPythonで作成されています。 + +## CPUバッファを使用したNVMeスケーリングによる高性能I/O + +最初のマイクロベンチマーク評価では、fioとds_ioを使用して、NVMeとCPUメモリ間で1GBのデータを転送するパフォーマンスを測定しました。これらの実験ではfioをlibaioバックエンドに設定しました。結果は図1の通りです。ここから、2つの点が読み取れます。第一に、DeepNVMeは、深層学習アプリケーションにおける性能改善を目指したものであるにも関わらず、このマイクロベンチマークでもfioに匹敵する高性能を示しています。第二に、DeepNVMeは、利用可能なNVMe帯域幅にほぼ線形にスケールし、10GB/秒の読み取りおよび5GB/秒の書き込み速度を達成しています。 + + + +
+図1: DeepNVMeを使用したNVMeとCPUバッファ間のデータ転送のスケーリング +
+ +## GPUバッファを使用したNVMeスケーリングによる高性能I/O + +二つ目のマイクロベンチマーク評価では、gdsioとds_ioを使用して、NVMeとGPUメモリ間で1GBのデータ転送のパフォーマンスを測定しました。この実験では、ds_ioを従来のバウンスバッファアプローチとより効率的なGDSアプローチの両方で設定します。結果は図2の通りです。ここから、次の3点が観察できます。第一にGDSを用いるケースで、従来のバウンスバッファアプローチと比較して、DeepNVMeは最大で37%のスピードアップを実現しています。第二に、DeepNVMeは、深層学習アプリケーションのために作成されたものであるにも関わらず、gdsioに匹敵する(時にはそれを上回る)高性能を示します。第三に、DeepNVMeは、GDSの有無にかかわらず、NVMe帯域幅を最大限に活用できます。GDSを使用した場合、DeepNVMeは最大9.6GB/秒の読み取りおよび5GB/秒の書き込み速度を達成し、GDSを使用しない場合は7GB/秒の読み取りおよび4GB/秒の書き込み速度を達成します。 + + + +
+図2: DeepNVMeを使用したNVMeとGPUメモリ間のデータ転送のスケーリング +
+ +## ZeRO-Inference: 生成AIパフォーマンス + +ZeRO-Inferenceは、モデルの重み(パラメータ)をCPUまたはNVMeメモリにオフロードすることで、大規模モデルの推論に必要なハードウェアコストを削減し、限られたハードウェア資源しかないユーザでも大規模モデルを活用できるようにするための技術です。ZeRO-Inferenceは、オフライン推論などのスループット指向のアプリケーションや、ハードウェア予算が限られているシナリオに適しています。DeepNVMeのNVMeオフロードのパフォーマンスを評価するために、トークン生成ワークロードを使用します。 + +## NVMeスケーリングによる高性能オフロード + +LLAMA3-70Bモデルの推論を単一のNVIDIA A100-80GBで、プロンプト長512、生成長32、バッチサイズ96で実行し、生成スループットを測定します。NVMe SSDの数を1から4までスケーリングし、GDSの有無でZeRO-Inferenceの結果を図3に示します。この結果から、2つの観察ができます。第一に、GDSはバウンスバッファアプローチと比較して一貫して優れたパフォーマンスを提供し、トークン生成を10-18%高速化します。第二に、DeepNVMeは、GDSの有無にかかわらず、利用可能なNVMe帯域幅にスケールします。4つのNVMe SSDを使用する場合、DeepNVMeはGDSを使用して1秒あたり7トークン、GDSを使用しない場合は1秒あたり6トークンの生成スループットを達成します。プロファイリング結果は、DeepNVMeがより多くのNVMe帯域幅で引き続きスケールし、生成アプリケーションのパフォーマンスを低コストで向上できることを示しています。 + + + +
+図3: DeepNVMeを使用したLLAMA3-70Bトークン生成パフォーマンスのNVMeオフロードによるスケーリング +
+ +# まとめ + +このブログ記事では、深層学習のスケーラビリティにおいて主要なボトルネックとなるI/O操作を最適化する、DeepNVMeを紹介しました。DeepNVMeは、NVMe SSDやNVIDIA GDSなどのストレージ技術に基づいた最適化を通じて、永続ストレージと深層学習アプリケーションのデータ転送を高速かつ効率的に実現します。Azure NC96ads_A100_v4 VMでの単一A100-80GB GPUを使用したLLAMA3-70Bトークン生成において、DeepNVMeを使用することで、NVMeオフロードで最大7トークン/秒の生成スループットを達成しました。DeepNVMeはオープンソース化され、DeepSpeedバージョン[0.15.0](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.15.0).以上で利用可能です。今後のブログでは、モデルチェックポイントやデータロードなどの他のI/Oがボトルネックとなる深層学習アプリケーションに対するDeepNVMeの改善について報告します。 + +# 謝辞 + +この成果は、MicrosoftとNVIDIAの協力によるものです。MicrosoftからはJoe Mayer、Martin Cai、Olatunji Ruwase、NVIDIAからはKiran Modukuri、Vahid Noormofidi、Sourab Gupta、Sandeep Joshiが貢献しました。 diff --git a/blogs/deepnvme/08-2024/media/figure1.png b/blogs/deepnvme/08-2024/media/figure1.png new file mode 100755 index 000000000000..08db7d2f8afa Binary files /dev/null and b/blogs/deepnvme/08-2024/media/figure1.png differ diff --git a/blogs/deepnvme/08-2024/media/figure2.png b/blogs/deepnvme/08-2024/media/figure2.png new file mode 100755 index 000000000000..35be5d4c4015 Binary files /dev/null and b/blogs/deepnvme/08-2024/media/figure2.png differ diff --git a/blogs/deepnvme/08-2024/media/figure3.png b/blogs/deepnvme/08-2024/media/figure3.png new file mode 100755 index 000000000000..7175236f886b Binary files /dev/null and b/blogs/deepnvme/08-2024/media/figure3.png differ diff --git a/blogs/deepnvme/08-2024/media/table1.png b/blogs/deepnvme/08-2024/media/table1.png new file mode 100755 index 000000000000..bba571369932 Binary files /dev/null and b/blogs/deepnvme/08-2024/media/table1.png differ diff --git a/blogs/deepspeed-chat/README.md b/blogs/deepspeed-chat/README.md index 2e69da1dd733..edb705b3d575 100644 --- a/blogs/deepspeed-chat/README.md +++ b/blogs/deepspeed-chat/README.md @@ -10,19 +10,30 @@ +To cite DeepSpeed Chat, please cite our [arxiv report](https://arxiv.org/abs/2308.01320): + +``` +@article{yao2023dschat, + title={{DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales}}, + author={Zhewei Yao and Reza Yazdani Aminabadi and Olatunji Ruwase and Samyam Rajbhandari and Xiaoxia Wu and Ammar Ahmad Awan and Jeff Rasley and Minjia Zhang and Conglong Li and Connor Holmes and Zhongzhu Zhou and Michael Wyatt and Molly Smith and Lev Kurilenko and Heyang Qin and Masahiro Tanaka and Shuai Che and Shuaiwen Leon Song and Yuxiong He}, + journal={arXiv preprint arXiv:2308.01320}, + year={2023} +} +``` + # 1. Overview -ChatGPT like models have taken the AI world by a storm, and it would not be an overstatement to say that its impact on the digital world has been revolutionary. These models are incredibly versatile, capable of performing tasks like summarization, coding, and translation with results that are on-par or even exceeding the capabilities of human experts. Given the sheer power of these models, multiple efforts are underway in the AI open-source community to make ChatGPT-style models more accessible (e.g. ChatLLaMa, Alpaca, Vicuna, Databricks-Dolly, etc.). +ChatGPT-like models have taken the AI world by storm, and it would not be an overstatement to say that its impact on the digital world has been revolutionary. These models are incredibly versatile, capable of performing tasks like summarization, coding, and translation with results that are on-par or even exceeding the capabilities of human experts. Given the sheer power of these models, multiple efforts are underway in the AI open-source community to make ChatGPT-style models more accessible (e.g. ChatLLaMa, Alpaca, Vicuna, Databricks-Dolly, etc.). -Despite these incredible efforts, there is still a lack of an end-to-end RLHF pipeline capable of training powerful ChatGPT like model that is easily accessible to the AI community. For instance, training a modest 6.7B ChatGPT model with existing systems typically requires expensive multi-GPU setup that is beyond the reach of many data scientists. Even with access to such computing resources, training efficiency is often less than 5% of what these machines are capable of (as illustrated [shortly](#effective-throughput-and-scalability-analysis)). And finally, existing solutions simply cannot support easy, fast and affordable training state-of-the-art ChatGPT models with hundreds of billions of parameters, even given access to multi-GPU clusters. +Despite these incredible efforts, there is still a lack of an end-to-end RLHF pipeline capable of training a powerful ChatGPT-like model that is easily accessible to the AI community. For instance, training a modest 6.7B ChatGPT model with existing systems typically requires an expensive multi-GPU setup that is beyond the reach of many data scientists. Even with access to such computing resources, training efficiency is often less than 5% of what these machines are capable of (as illustrated [shortly](#effective-throughput-and-scalability-analysis)). Finally, existing solutions simply cannot support easy, fast and affordable training state-of-the-art ChatGPT models with hundreds of billions of parameters, even given access to multi-GPU clusters. -These limitations stem from a lack of a robust system design that is capable of effectively supporting the complex InstructGPT’s RLHF training pipeline that is quite different from the standard pre-training and fine-tuning pipelines that existing DL systems are designed for. Therefore, in the spirit of democratizing ChatGPT-like models, and making RLHF training truly accessible to the AI community, today we are releasing DeepSpeed-Chat with the following three capabilities: +These limitations stem from a lack of a robust system design that is capable of effectively supporting the complex InstructGPT’s RLHF training pipeline which is quite different from the standard pre-training and fine-tuning pipelines that existing DL systems are designed for. Therefore, in the spirit of democratizing ChatGPT-like models, and making RLHF training truly accessible to the AI community, today we are releasing DeepSpeed-Chat with the following three capabilities: (i) ***Easy-to-use Training and Inference Experience for ChatGPT Like Models***: A single script capable of taking a pre-trained Huggingface model, running it through all three steps of InstructGPT training using DeepSpeed-RLHF system and producing your very own ChatGPT like model. In addition, we provide an inference API for testing conversation-style interactions after the model is trained. (ii) ***DeepSpeed-RLHF Pipeline***: DeepSpeed-RLHF pipeline primarily replicates the training pipeline from the InstructGPT paper with careful attention to ensure completeness and one-to-one correspondence with the three-steps that includes a) Supervised Fine-tuning (SFT), b) Reward Model Fine-tuning and c) Reinforcement Learning with Human Feedback (RLHF). Additionally, we offer data abstraction and blending capabilities to enable training with multiple data sources. -(iii) ***DeepSpeed-RLHF System***: A robust and sophisticated RLHF system that combines the training and inference prowess of DeepSpeed into single unified Hybrid Engine (DeepSpeed-HE) for RLHF. The Hybrid-Engine is capable of seamlessly transitioning between inference and training modes within RLHF, allowing it to leverage various optimizations from DeepSpeed-Inference such as tensor-parallelism and high-performance transformer kernels for generation, while also benefiting from the multitude of ZeRO- and LoRA-based memory optimization strategies for RL training. DeepSpeed-HE is also aware of the full RLHF pipeline, allowing it to make optimal decisions in terms of memory management and data movement across different phases of RLHF. +(iii) ***DeepSpeed-RLHF System***: A robust and sophisticated RLHF system that combines the training and inference prowess of DeepSpeed into a single unified Hybrid Engine (DeepSpeed-HE) for RLHF. The Hybrid-Engine is capable of seamlessly transitioning between inference and training modes within RLHF, allowing it to leverage various optimizations from DeepSpeed-Inference such as tensor-parallelism and high-performance transformer kernels for generation, while also benefiting from the multitude of ZeRO- and LoRA-based memory optimization strategies for RL training. DeepSpeed-HE is also aware of the full RLHF pipeline, allowing it to make optimal decisions in terms of memory management and data movement across different phases of RLHF. DeepSpeed-RLHF system is capable of unparalleled efficiency at scale, making complex RLHF training fast, affordable, and easily accessible to the AI community: @@ -54,7 +65,7 @@ DeepSpeed-RLHF system is capable of unparalleled efficiency at scale, making com *Table 2. Multi-Node 64x A100-80GB: Training Time and Corresponding Approximate Cost on Azure.* -> ***Very Important Details***: The numbers in both tables above are for Step 3 of the training and based on actual measured training throughput on DeepSpeed-RLHF curated dataset and training recipe which trains for one epoch on a total of 135M tokens. We have in total 67.5M query tokens (131.9k queries with sequence length 256) and 67.5M generated tokens (131.9k answers with sequence length 256), and a maximum global batch size per step of 0.5M tokens (1024 query-answer pairs). We urge readers to pay attention to these specifications before making any cost and e2e time comparisons with DeepSpeed-RLHF. See our [benchmark settings](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md) page for more details. +> ***Very Important Details***: The numbers in both tables above are for Step 3 of the training and are based on actual measured training throughput on DeepSpeed-RLHF curated dataset and training recipe which trains for one epoch on a total of 135M tokens. We have in total 67.5M query tokens (131.9k queries with sequence length 256) and 67.5M generated tokens (131.9k answers with sequence length 256), and a maximum global batch size per step of 0.5M tokens (1024 query-answer pairs). We urge readers to pay attention to these specifications before making any cost and e2e time comparisons with DeepSpeed-RLHF. See our [benchmark settings](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md) page for more details. ***Democratizing RLHF Training***: With just a single GPU, DeepSpeed-HE supports training models with over 13 billion parameters, enabling data scientists without access to multi-GPU systems to create not just toy RLHF models but large and powerful ones that can be used in real-world scenarios. @@ -82,15 +93,13 @@ We start with the easy-to-use experience by showing how you can train OPT-13B an We use an example of pretrained OPT-13B as the actor model and OPT-350M as the reward model in the following single script to generate a final 13B ChatGPT-style model: ```python -git clone https://github.com/microsoft/DeepSpeed.git -cd DeepSpeed -pip install . +pip install deepspeed>=0.9.0 -git clone https://github.com/microsoft/DeepSpeedExamples.git +git clone https://github.com/deepspeedai/DeepSpeedExamples.git cd DeepSpeedExamples/applications/DeepSpeed-Chat/ pip install -r requirements.txt -python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --num-gpus 8 +python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --deployment-type single_node ``` In about half a day, your 13B model would be fully trained with its checkpoints ready. The following table demonstrates a breakdown of the training time for each of the three steps: @@ -123,7 +132,7 @@ Assistant:  Sure, I can try.  Microsoft is a company that makes computers, a We understand users often like to play with different model sizes and configurations to meet their training time, resources, and quality requirements. With DeepSpeed-Chat, users can easily do that. For example, if you want to train a larger and higher-quality model on your GPU cluster for your research or business, you can simply use the same script with your desired model size e.g., 66B and GPU counts e.g., 64 GPUs: ```python -python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --num-gpus 64 +python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --deployment-type multi_node ``` Within 9 hours, you can have your 66 billion parameters ChatGPT model ready to be served in your favorite front-end GUI: @@ -142,7 +151,7 @@ Table 5. E2E time breakdown for training a 66 billion parameter ChatGPT model vi If you only have around 1-2 hours for coffee or lunch break, you can also try to train a small/toy model with DeepSpeed-Chat. For example, we prepared a training example for a 1.3B model with a single dataset to test our framework on your consumer-grade GPUs. The best part is that you will have your model checkpoint ready to play with when you are back from your lunch break! ```python -python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --num-gpus 1 +python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu ```
@@ -199,16 +208,16 @@ We provide two additional features in Step 3 to help improve model quality: The two training features, EMA and Mixed Training, are often omitted by other recent efforts since they can be optional. However, according to InstructGPT, EMA checkpoints generally provide better response quality than conventional final trained model and Mixture Training can help the model retain the pre-training benchmark solving ability. As such, we provide them for users to fully get the training experience as described in InstructGPT and strike for higher model quality. -In addition to being highly consistent with InstructGPT paper, we also provide convenient features to support researchers and practitioners to train their own RLHF model with multiple data resources: +In addition to being highly consistent with InstructGPT paper, we also provide convenient features to support researchers and practitioners in training their own RLHF model with multiple data resources: -* ***Data Abstraction and Blending Capabilities:*** DeepSpeed-Chat is able to train the model with multiple datasets for better model quality. It is equipped with (1) an abstract dataset layer to unify the format of different datasets; and (2) data splitting/blending capabilities so that the multiple datasets are properly blended then split across the 3 training stages. +* ***Data Abstraction and Blending Capabilities:*** DeepSpeed-Chat is able to train the model with multiple datasets for better model quality. It is equipped with (1) an abstract dataset layer to unify the format of different datasets; and (2) data splitting/blending capabilities so that the multiple datasets are properly blended and then split across the 3 training stages. To illustrate the effectiveness of our training pipeline, we demonstrate the model quality with multi-round conversation as shown in the experience section. # 4. DeepSpeed Hybrid Engine – Unified Infrastructure to Power and Optimize RLHF Training -Step 1 and Step 2 of the instruct-guided RLHF pipeline resemble regular fine-tuning of large models, and they are powered by ZeRO-based optimizations and flexible combination of parallelism strategies in DeepSpeed training to achieve scale and speed. Step 3 of the pipeline, on the other hand, is the most complex part to handle in terms of performance implications. Each iteration requires efficient processing of two phases a) inference phase for token/experience generation, producing inputs for the training and b) training phase to update the weights of actor and reward models, as well as the interaction and scheduling between them. It introduces two major costs: (1) the memory cost, as several copies of the SFT and RW models need to be served throughout stage 3; and (2) the predominant generation phase, which if not accelerated properly, will significantly slow down the entire stage 3. Additionally, the two important features we added in Stage 3, including Exponential Moving Average (EMA) collection and Mixture Training, will incur additional memory and training costs. +Step 1 and Step 2 of the instruct-guided RLHF pipeline resemble regular fine-tuning of large models, and they are powered by ZeRO-based optimizations and a flexible combination of parallelism strategies in DeepSpeed training to achieve scale and speed. Step 3 of the pipeline, on the other hand, is the most complex part to handle in terms of performance implications. Each iteration requires efficient processing of two phases a) inference phase for token/experience generation, producing inputs for the training and b) training phase to update the weights of actor and reward models, as well as the interaction and scheduling between them. It introduces two major costs: (1) the memory cost, as several copies of the SFT and RW models need to be served throughout stage 3; and (2) the predominant generation phase, which if not accelerated properly, will significantly slow down the entire stage 3. Additionally, the two important features we added in Stage 3, including Exponential Moving Average (EMA) collection and Mixture Training, will incur additional memory and training costs. To tackle these challenges, we composed the full system capability of DeepSpeed Training and Inference into a unified infrastructure that we call **Hybrid Engine**. It leverages the original DeepSpeed engines for fast training mode while effortlessly applying DeepSpeed inference engine for generation/evaluation mode, providing a significantly faster training system for RLHF training at Stage 3. As Figure 2 shows, the transition between DeepSpeed training and inference engine is seamless: by having the typical eval and train modes enabled for the actor model, when running for inference and training pipeline, DeepSpeed selects its different optimizations to run the model faster and improve the overall system throughput. @@ -221,11 +230,11 @@ To tackle these challenges, we composed the full system capability of DeepSpeed
-During its inference execution for experience generation phase of RLHF training, DeepSpeed Hybrid Engine uses a light-weight memory management system to handle the KV-cache and intermediate results, together with highly optimized inference-adapted kernels and tensor parallelism implementation, to achieve significant boost in throughput (tokens-per-second) compared to the existing solutions. +During its inference execution for the experience generation phase of RLHF training, DeepSpeed Hybrid Engine uses a light-weight memory management system to handle the KV-cache and intermediate results, together with highly optimized inference-adapted kernels and tensor parallelism implementation, to achieve a significant boost in throughput (tokens-per-second) compared to the existing solutions. During the training execution, Hybrid Engine enables memory optimization techniques such as DeepSpeed’s ZeRO family of technologies and Low Rank Adaption (LoRA). We designed and implemented these system optimizations in a way that they are compatible with each other and can be composed together to deliver the highest training efficiency under the unified Hybrid Engine. -Hybrid Engine can seamlessly change model partitioning across training and inference to support tensor-parallelism based inferencing and ZeRO-based sharding mechanism for training. It can also reconfigure the memory system to maximize memory availability during each of these modes. This allows for improved performance by avoiding memory allocation bottlenecks and supporting large batch sizes. Packed with a spectrum of system technologies from DeepSpeed training and inference, Hybrid Engine pushes the boundary of modern RLHF training and delivers unparalleled scale and system efficiency for RLHF workloads. +Hybrid Engine can seamlessly change model partitioning across training and inference to support tensor-parallelism based inferencing and ZeRO-based sharding mechanisms for training. It can also reconfigure the memory system to maximize memory availability during each of these modes. This allows for improved performance by avoiding memory allocation bottlenecks and supporting large batch sizes. Packed with a spectrum of system technologies from DeepSpeed training and inference, Hybrid Engine pushes the boundary of modern RLHF training and delivers unparalleled scale and system efficiency for RLHF workloads. # 5. DeepSpeed RLHF: Unparalleled Scale and Efficiency via Hybrid Engine @@ -264,7 +273,7 @@ No icons represent OOM scenarios.* -This improvement in efficiency stems from DeepSpeed-HE’s ability to accelerate RLHF generation phase of the RLHF processing leveraging DeepSpeed inference optimizations. Figure 5 shows the time breakdown for a 1.3B parameter model at an RLHF training iteration: majority of the time goes to the generation phase. By leveraging high performance inference kernels from DeepSpeed, DeepSpeed-HE can achieve up to 9x throughput improvement during this phase over HuggingFace and 15x over Colossal-AI allowing it to achieve unparallel end-to-end efficiency. +This improvement in efficiency stems from DeepSpeed-HE’s ability to accelerate RLHF generation phase of the RLHF processing by leveraging DeepSpeed inference optimizations. Figure 5 shows the time breakdown for a 1.3B parameter model at an RLHF training iteration: majority of the time goes to the generation phase. By leveraging high performance inference kernels from DeepSpeed, DeepSpeed-HE can achieve up to 9x throughput improvement during this phase over HuggingFace and 15x over Colossal-AI allowing it to achieve unparalleled end-to-end efficiency.
@@ -276,7 +285,7 @@ This improvement in efficiency stems from DeepSpeed-HE’s ability to accelerate ## Effective Throughput and Scalability Analysis -***(I) Effective Throughput Analysis.*** The effective throughput of DeepSpeed-HE during Stage 3 of the RLHF training depends on the throughput that it achieves during the generation and RL training phases. In our RLHF pipeline, the generation phase comprises approximately 20% of the total computation while the RL training phase comprises of remaining 80% (see [benchmark settings](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md) page for details). However, despite having a small proportion, the former can take a large portion of the e2e time as it requires running the actor model once for each of the 256 generated tokens with initial prompt of 256 tokens, making it memory bandwidth bound and difficult to achieve high throughput. In contrast, the RL training phase is compute bound running the reference actor model with just a couple of forward and backward passes with full 512 tokens from both prompt and generation per sample and can achieve good throughput. +***(I) Effective Throughput Analysis.*** The effective throughput of DeepSpeed-HE during Stage 3 of the RLHF training depends on the throughput that it achieves during the generation and RL training phases. In our RLHF pipeline, the generation phase comprises approximately 20% of the total computation while the RL training phase comprises of remaining 80% (see [benchmark settings](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md) page for details). However, despite having a small proportion, the former can take a large portion of the e2e time as it requires running the actor model once for each of the 256 generated tokens with an initial prompt of 256 tokens, making it memory bandwidth bound and difficult to achieve high throughput. In contrast, the RL training phase is compute bound running the reference actor model with just a couple of forward and backward passes with full 512 tokens from both prompt and generation per sample and can achieve good throughput.
@@ -286,7 +295,7 @@ This improvement in efficiency stems from DeepSpeed-HE’s ability to accelerate
-To maximize the effective throughput, DeepSpeed-HE optimizes both phases. First, it uses the largest batch size possible to get higher efficiency on both phases. Second, during the generation phase, it leverages high-performance transformer kernels to maximize GPU memory bandwidth utilization when the model fits in single GPU memory, and leverage tensor-parallelism (TP) when it does not. Using TP in the generation phase instead of ZeRO to fit the model reduces the inter-GPU communication and maintains high GPU memory bandwidth utilization. +To maximize the effective throughput, DeepSpeed-HE optimizes both phases. First, it uses the largest batch size possible to get higher efficiency in both phases. Second, during the generation phase, it leverages high-performance transformer kernels to maximize GPU memory bandwidth utilization when the model fits in single GPU memory, and leverages tensor-parallelism (TP) when it does not. Using TP in the generation phase instead of ZeRO to fit the model reduces the inter-GPU communication and maintains high GPU memory bandwidth utilization. Figure 6 shows the best achievable effective throughput for DeepSpeed-HE in terms of TFlops/GPU for model sizes ranging from 1.3B to 175B. It also shows the throughput achieved by each of the generation and training phases. DeepSpeed-HE is the most efficient for models in the range 6.7B-66B. Going beyond this range to 175B, the throughput drops due to the limited memory to support larger batch sizes, while still achieving 1.2x better efficiency than the small 1.3B model. The per-GPU throughput of these gigantic models could improve further when we scale them to more GPUs with more memory available for larger batch sizes. @@ -302,7 +311,7 @@ Furthermore, we would like to point out that our effective performance is 19x hi ***(II) Scalability Analysis.*** The best effective throughput for different model sizes is achieved at different GPU count. This is in part because some of the larger model sizes require more memory to run. However, a large part of this behavior stems from DeepSpeed-HE’s scalability properties that we discuss next. -Figure 7 shows that DeepSeed-RLHF has achieved good scaling overall on up to 64 GPUs. However, if we look more closely, it shows that DeepSpeed-RLHF training achieves super-linear scaling at small scale, followed by near linear or sub-linear scaling at larger scales. This is due to interaction between memory availability and max global batch size. +Figure 7 shows that DeepSpeed-RLHF has achieved good scaling overall on up to 64 GPUs. However, if we look more closely, it shows that DeepSpeed-RLHF training achieves super-linear scaling at small scale, followed by near linear or sub-linear scaling at larger scales. This is due to the interaction between memory availability and max global batch size. As DeepSpeed-HE is powered by ZeRO-based technology for training, it allows model states to be partitioned across the available GPUs. As a result, the memory consumption per GPU reduces with the increase in the number of GPUs, allowing DeepSpeed-HE to support a larger batch per GPU resulting in super-linear scaling. However, at large scale, while the available memory continues to increase, the maximum global batch size (1024, in our case, with a sequence length of 512) limits the batch size per GPU, resulting in near-linear or sub-linear scaling. As a result, for a given max global batch size, DeepSpeed-HE achieves the best throughput and cost efficiency at the boundary of super-linear and sub-linear scalability, and the exact point is mostly determined by the largest batch size that can be run per GPU as the function of available memory and global batch size. @@ -311,13 +320,13 @@ As a result, for a given max global batch size, DeepSpeed-HE achieves the best t We are very excited to share that DeepSpeed-Chat is now open-sourced and available to the AI community. -* To get started, please visit our github page for DeepSpeed-Chat: [GitHub Landing Page](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat) +* To get started, please visit our github page for DeepSpeed-Chat: [GitHub Landing Page](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat) -* We will continue to improve DeepSpeed-Chat with your feedback and support. Our [roadmap](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-) shows currently supported features as well as ones that are planned for future. +* We will continue to improve DeepSpeed-Chat with your feedback and support. Our [roadmap](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-) shows currently supported features as well as ones that are planned for the future. -DeepSpeed-Chat is part of the bigger DeepSpeed ecosystem comprising of a multitude of Deep Learning systems and modeling technologies. To learn more, +DeepSpeed-Chat is part of the bigger DeepSpeed ecosystem comprising a multitude of Deep Learning systems and modeling technologies. To learn more, * Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation. -* You can also follow us on our [English Twitter](https://twitter.com/MSFTDeepSpeed) and [Japanese Twitter](https://twitter.com/MSFTDeepSpeedJP) for latest news on DeepSpeed. +* You can also follow us on our [English Twitter](https://twitter.com/DeepSpeedAI), [Japanese Twitter](https://twitter.com/DeepSpeedAI_JP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for latest news on DeepSpeed. -DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/) page. Please see our [contributing guide](https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to deepspeed-info@microsoft.com. +DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) page. Please see our [contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to info@deepspeed.ai. diff --git a/blogs/deepspeed-chat/assets/images/ds_chat_main.png b/blogs/deepspeed-chat/assets/images/ds_chat_main.png new file mode 100644 index 000000000000..3266a425b102 Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/ds_chat_main.png differ diff --git a/blogs/deepspeed-chat/assets/images/ds_chat_stability_sweep.png b/blogs/deepspeed-chat/assets/images/ds_chat_stability_sweep.png new file mode 100644 index 000000000000..d98cd765bc6a Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/ds_chat_stability_sweep.png differ diff --git a/blogs/deepspeed-chat/assets/images/ds_chat_zero_offload_gpu.png b/blogs/deepspeed-chat/assets/images/ds_chat_zero_offload_gpu.png new file mode 100644 index 000000000000..935cadc0cf13 Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/ds_chat_zero_offload_gpu.png differ diff --git a/blogs/deepspeed-chat/assets/images/dschat-llama-13b-HE-perf.png b/blogs/deepspeed-chat/assets/images/dschat-llama-13b-HE-perf.png new file mode 100644 index 000000000000..56cf6280d8a5 Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/dschat-llama-13b-HE-perf.png differ diff --git a/blogs/deepspeed-chat/assets/images/dschat-llama-7b-HE-perf.png b/blogs/deepspeed-chat/assets/images/dschat-llama-7b-HE-perf.png new file mode 100644 index 000000000000..93342fffbc60 Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/dschat-llama-7b-HE-perf.png differ diff --git a/blogs/deepspeed-chat/assets/images/dschat-mpzero-llama.png b/blogs/deepspeed-chat/assets/images/dschat-mpzero-llama.png new file mode 100644 index 000000000000..ae7f5f62f52e Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/dschat-mpzero-llama.png differ diff --git a/blogs/deepspeed-chat/assets/images/llama2-with-he.png b/blogs/deepspeed-chat/assets/images/llama2-with-he.png new file mode 100644 index 000000000000..de5015d277bf Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/llama2-with-he.png differ diff --git a/blogs/deepspeed-chat/assets/images/sweep_after_lora_fix.png b/blogs/deepspeed-chat/assets/images/sweep_after_lora_fix.png new file mode 100644 index 000000000000..d12dca7ac49c Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/sweep_after_lora_fix.png differ diff --git a/blogs/deepspeed-chat/assets/images/sweep_before_lora_fix.png b/blogs/deepspeed-chat/assets/images/sweep_before_lora_fix.png new file mode 100644 index 000000000000..ab01a46cddf7 Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/sweep_before_lora_fix.png differ diff --git a/blogs/deepspeed-chat/assets/images/zero_offload_after_stability.png b/blogs/deepspeed-chat/assets/images/zero_offload_after_stability.png new file mode 100644 index 000000000000..8392f8230ff7 Binary files /dev/null and b/blogs/deepspeed-chat/assets/images/zero_offload_after_stability.png differ diff --git a/blogs/deepspeed-chat/chinese/README.md b/blogs/deepspeed-chat/chinese/README.md index adcaa39c3df8..4ad48dbc8c39 100644 --- a/blogs/deepspeed-chat/chinese/README.md +++ b/blogs/deepspeed-chat/chinese/README.md @@ -10,6 +10,17 @@
+如需引用 DeepSpeed Chat,请引用我们的[arxiv report](https://arxiv.org/abs/2308.01320): + +``` +@article{yao2023dschat, + title={{DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales}}, + author={Zhewei Yao and Reza Yazdani Aminabadi and Olatunji Ruwase and Samyam Rajbhandari and Xiaoxia Wu and Ammar Ahmad Awan and Jeff Rasley and Minjia Zhang and Conglong Li and Connor Holmes and Zhongzhu Zhou and Michael Wyatt and Molly Smith and Lev Kurilenko and Heyang Qin and Masahiro Tanaka and Shuai Che and Shuaiwen Leon Song and Yuxiong He}, + journal={arXiv preprint arXiv:2308.01320}, + year={2023} +} +``` + # 1. 概述 近日来,ChatGPT及类似模型引发了人工智能(AI)领域的一场风潮。 这场风潮对数字世界产生了革命性影响。ChatGPT类模型具有惊人的泛用性,能够执行归纳、编程、翻译等任务,其结果与人类专家相当甚至更优。为了使ChatGPT等模型的训练和部署更轻松,AI 开源社区进行了各种尝试(例如 ChatLLaMa、Alpaca、Vicuna、Databricks-Dolly等)。 @@ -52,7 +63,7 @@ DeepSpeed-RLHF 系统在大规模训练中具有无与伦比的效率,使复 *表 2. 多节点 64x A100-80GB:训练时长及预估的 Azure 费用。* -> ***非常重要的细节***: 上述两个表格(即表一和表二)中的数据均针对 RLHF 训练的第 3 步,基于实际数据集和 DeepSpeed-RLHF 训练吞吐量的测试。该训练在总共 1.35 亿(135M)个字符(token)上进行一个时期(epoch)的训练。我们总共有 6750 万个查询(query)字符(131.9k 个 query,每个序列长度为 256)和 6750 万个生成/回答字符(131.9k 个答案,每个序列长度为 256),每步的最大全局字符批量大小约为 500 万个字符(1024 个查询-答案对)。在与 DeepSpeed-RLHF 进行任何成本和端到端时间比较之前,我们建议读者注意这些设定。想要了解更多详细信息,请参阅我们的页面 [benchmark setting](https://github.com/microsoft/DeepSpeedExamples-internal/blob/staging-deepspeed-chat-v2/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md)。 +> ***非常重要的细节***: 上述两个表格(即表一和表二)中的数据均针对 RLHF 训练的第 3 步,基于实际数据集和 DeepSpeed-RLHF 训练吞吐量的测试。该训练在总共 1.35 亿(135M)个字符(token)上进行一个时期(epoch)的训练。我们总共有 6750 万个查询(query)字符(131.9k 个 query,每个序列长度为 256)和 6750 万个生成/回答字符(131.9k 个答案,每个序列长度为 256),每步的最大全局字符批量大小约为 500 万个字符(1024 个查询-答案对)。在与 DeepSpeed-RLHF 进行任何成本和端到端时间比较之前,我们建议读者注意这些设定。想要了解更多详细信息,请参阅我们的页面 [benchmark setting](https://github.com/deepspeedai/DeepSpeedExamples-internal/blob/staging-deepspeed-chat-v2/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md)。 ***实现 RLHF 训练的普及化***:仅凭单个 GPU,DeepSpeed-HE 就能支持训练超过 130 亿参数的模型。这使得那些无法使用多 GPU 系统的数据科学家和研究者不仅能够轻松创建轻量级的 RLHF 模型,还能创建大型且功能强大的模型,以应对不同的使用场景。 @@ -78,15 +89,13 @@ DeepSpeed-RLHF 系统在大规模训练中具有无与伦比的效率,使复 以下示例展示了如何使用一个脚本,以预训练的 OPT-13B 作为 actor 模型,OPT-350M 作为 reward 模型,生成一个最终的 130 亿参数的 ChatGPT 类型的模型: ``` -git clone https://github.com/microsoft/DeepSpeed.git -cd DeepSpeed -pip install . +pip install deepspeed>=0.9.0 -git clone https://github.com/microsoft/DeepSpeedExamples.git +git clone https://github.com/deepspeedai/DeepSpeedExamples.git cd DeepSpeedExamples/applications/DeepSpeed-Chat/ pip install -r requirements.txt -python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --num-gpus 8 +python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --deployment-type single_node ``` @@ -109,7 +118,7 @@ python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m ``` Human: Do you know Microsoft? Assistant: Yes, I’m a big fan of Microsoft. What about them? -Human: Can you explian it to a 6-year old child? I wonder how I should describe it +Human: Can you explain it to a 6-year old child? I wonder how I should describe it Assistant: Sure, I can try. Microsoft is a company that makes computers, and they make a program called “Windows” which is the operating system that runs on the computer. It’s like the “brain” of the computer. It’s where all the programs and files are stored. It’s also where you’ll find the “start” button, which is where you’ll start to @@ -120,7 +129,7 @@ Assistant: Sure, I can try. Microsoft is a company that makes computers, and 我们了解到用户通常喜欢尝试不同的模型大小和配置,以满足他们不同的训练时间、资源和质量的需求。借助 DeepSpeed-Chat,你可以轻松实现这些目标。例如,如果你想在 GPU 集群上训练一个更大、更高质量的模型,用于你的研究或业务,你可以使用相同的脚本,只需输入你期望的模型大小(例如 660 亿参数)和 GPU 数量(例如 64 个 GPU): ```python -python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --num-gpus 64 +python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --deployment-type multi_node ``` 在 接下来的9 小时内,你将拥有一个 660 亿参数的 ChatGPT 模型,并可以在你喜欢的前端 GUI 中使用: @@ -138,7 +147,7 @@ python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m 如果你只有大约 1-2 小时的咖啡或午餐休息时间,你也可以尝试使用 DeepSpeed-Chat 训练一个小型模型。例如,我们为单个数据集准备了一个 13 亿参数模型的训练示例,以便在你的消费级 GPU 上测试我们的框架。当你从午餐休息回来时,你的模型已经准备好供你使用! ```python -python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --num-gpus 1 +python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu ```
@@ -265,7 +274,7 @@ DeepSpeed-HE可以在训练和推理之间无缝更改模型分区,以支持 ## 有效吞吐量和可扩展性分析 -***(I) 有效吞吐量分析。*** 在 RLHF 训练的第 3 阶段,DeepSpeed-HE 的有效吞吐量取决于它在生成和 RL 训练阶段所实现的吞吐量。在我们的 RLHF (详见 [benchmarking setting](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md))中,生成阶段占总计算的约 20%,而 RL 训练阶段占剩余的 80%。然而,尽管比例较小,前者可能会占用大部分的端到端时间,因为它需要为每个生成的字符运行一次 actor 模型,使其受到内存带宽限制,难以实现高吞吐量。相比之下,RL 训练阶段是计算密集型的,仅需运行参考 actor 模型进行几次前向和后向传递,每个样本都有来自提示和生成的全部 512 个字符,可以实现良好的吞吐量。 +***(I) 有效吞吐量分析。*** 在 RLHF 训练的第 3 阶段,DeepSpeed-HE 的有效吞吐量取决于它在生成和 RL 训练阶段所实现的吞吐量。在我们的 RLHF (详见 [benchmarking setting](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md))中,生成阶段占总计算的约 20%,而 RL 训练阶段占剩余的 80%。然而,尽管比例较小,前者可能会占用大部分的端到端时间,因为它需要为每个生成的字符运行一次 actor 模型,使其受到内存带宽限制,难以实现高吞吐量。相比之下,RL 训练阶段是计算密集型的,仅需运行参考 actor 模型进行几次前向和后向传递,每个样本都有来自提示和生成的全部 512 个字符,可以实现良好的吞吐量。
@@ -291,7 +300,7 @@ DeepSpeed-HE可以在训练和推理之间无缝更改模型分区,以支持 ***(II) 可扩展性分析。*** 不同模型大小的最佳有效吞吐量取决于不同的 GPU 数量。部分原因是因为一些较大的模型大小需要更多的内存来运行。基于此,我们接下来讨论 DeepSpeed-HE 的可扩展性特性。 -图 7 显示 DeepSeed-RLHF 在多达 64 个 GPU的集群 上实现了良好的整体扩展。然而,如果我们仔细观察,可以发现 DeepSpeed-RLHF 训练在小规模时实现了超线性扩展,随后在较大规模时实现了接近线性或次线性扩展。这是由于内存可用性和最大全局批量大小之间的相互作用。 +图 7 显示 DeepSpeed-RLHF 在多达 64 个 GPU的集群 上实现了良好的整体扩展。然而,如果我们仔细观察,可以发现 DeepSpeed-RLHF 训练在小规模时实现了超线性扩展,随后在较大规模时实现了接近线性或次线性扩展。这是由于内存可用性和最大全局批量大小之间的相互作用。 DeepSpeed-HE 的核心技术基于 ZeRO,用于训练过程中将模型状态分割到每个GPU上。这意味着随着 GPU 数量的增加,每个 GPU 的内存消耗会减少,使得 DeepSpeed-HE 能够在每个 GPU 上支持更大的批量,从而实现超线性扩展。然而,在大规模情况下,尽管可用内存持续增加,但最大全局批量大小仍然限制了每个 GPU 的批量大小,导致接近线性或次线性扩展。因此,在给定的最大全局批量大小(例如,我们设置为 1024 个句子,每个句子长度为 512)下,DeepSpeed-HE 在超线性和次线性可扩展性之间实现了最佳的吞吐量和成本效益。具体的平衡点主要取决于每个 GPU 上可运行的最大批量大小,而这又受到可用内存和全局批量大小的函数所决定。 @@ -299,18 +308,18 @@ DeepSpeed-HE 的核心技术基于 ZeRO,用于训练过程中将模型状态 我们非常高兴地宣布,DeepSpeed-Chat现已开源并向 AI 社区开放。 -* 如果你发现我们的成果对你有用或者喜欢我们的开源成果,请在 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 和 [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)上点⭐。 +* 如果你发现我们的成果对你有用或者喜欢我们的开源成果,请在 [DeepSpeed](https://github.com/deepspeedai/DeepSpeed) 和 [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples)上点⭐。 -* 请访问我们的DeepSpeed-Chat GitHub页面以开始使用:[GitHub 登陆页面](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat) +* 请访问我们的DeepSpeed-Chat GitHub页面以开始使用:[GitHub 登陆页面](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat) -* 我们将继续根据你的反馈和支持改进 DeepSpeed-Chat。我们的[计划图](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-)显示了当前支持的功能以及计划在未来支持的功能。 +* 我们将继续根据你的反馈和支持改进 DeepSpeed-Chat。我们的[计划图](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-)显示了当前支持的功能以及计划在未来支持的功能。 DeepSpeed-Chat 是更大的DeepSpeed生态系统的一部分,包括众多深度学习系统和建模技术。要了解更多信息, * 请访问我们的[网站](https://www.deepspeed.ai/),了解详细的博客文章、教程和有用的文档。 -* 你还可以关注我们的[英文 Twitter](https://twitter.com/MSFTDeepSpeed) 和[日文 Twitter](https://twitter.com/MSFTDeepSpeedJP),了解 DeepSpeed 的最新动态。我们还将授权开源社KAIYUANSHE微信公众号第一时间发布我们的中文博客。 +* 我们会在[知乎](https://www.zhihu.com/people/deepspeed)上发布最新中文博客及动态。你还可以关注我们的[英文 Twitter](https://twitter.com/DeepSpeedAI) 和[日文 Twitter](https://twitter.com/DeepSpeedAI_JP)。 -DeepSpeed 欢迎你的贡献!我们鼓励你在 [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/) 页面报告问题、贡献 PR 并参与讨论。请参阅我们的[贡献指南](https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md)了解更多详情。我们愿意与大学、研究实验室、公司等进行合作,共同开展深度学习研究,将 DeepSpeed 应用于赋能现实世界的 AI 模型和应用等。对于此类需求(以及其他不适合在 GitHub 上提出的需求),请直接发送电子邮件至 deepspeed-info@microsoft.com。 +DeepSpeed 欢迎你的贡献!我们鼓励你在 [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) 页面报告问题、贡献 PR 并参与讨论。请参阅我们的[贡献指南](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)了解更多详情。我们愿意与大学、研究实验室、公司等进行合作,共同开展深度学习研究,将 DeepSpeed 应用于赋能现实世界的 AI 模型和应用等。对于此类需求(以及其他不适合在 GitHub 上提出的需求),请直接发送电子邮件至 info@deepspeed.ai。 diff --git a/blogs/deepspeed-chat/ds-chat-release-8-31/README.md b/blogs/deepspeed-chat/ds-chat-release-8-31/README.md new file mode 100644 index 000000000000..f511104dc078 --- /dev/null +++ b/blogs/deepspeed-chat/ds-chat-release-8-31/README.md @@ -0,0 +1,352 @@ +
+ +# DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements + +
+ +
+ +
+ +# Table of Contents +1. [Introduction](#introduction) +2. [System Support for Llama and Llama-2 models](#system-support-llama) +3. [Improved Efficiency and Accessibility](#new-features) + - [3.3x Higher Throughput with MixZ++ for LoRA](#mixz) + - [ZeRO-Offload Support for Larger Models with 16x fewer GPUs](#zero-offload) +4. [Stability Bug Fixes](#stability-bug-fixes) +5. [Software Improvements](#software-improvements) + - [Characterization Scripts](#characterization-scripts) + - [Instrumentation](#instrumentation) + - [Testing](#testing) +6. [Try Out DeepSpeed-Chat](#try-out-deepspeed-chat) + + +# 1. Introduction + +DeepSpeed-Chat is a general system framework for RLHF training that enables easy, fast, affordable, and scalable training of ChatGPT-style models that we [publicly released on GitHub](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-chat/README.md). The detailed performance and capabilities of DeepSpeed-Chat have been published in our [blog post](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat) and [arXiv](https://arxiv.org/abs/2308.01320) paper. + +We are happy to share that today we are improving DeepSpeed-Chat along three areas: i) system support for Llama/Llama-2 family of models, ii) system features for improved efficiency and accessibility, and iii) stability and software enhancements. + +- **System support for training Llama and Llama-2 models** + + We ***introduce system support for training Llama and Llama-2 models*** in DeepSpeed-Chat enabling and leveraging various optimizations and features including the Hybrid Engine, ZeRO family of optimizations, Low-Rank Adaptation (LoRA) support, as well as full integration into the three-stage DeepSpeed-Chat RLHF pipeline. By leveraging the Hybrid-Engine, we speed up the experience generation phase for Llama-2-7B and Llama-2-13B models by **up to 7.1X**. + +- **New System Features for Improved Efficiency and Accessibility** + - ***Mixed Precision ZeRO++ ([MixZ++](https://github.com/deepspeedai/DeepSpeed/pull/3954))***. It is an extended set of optimization strategies built upon [ZeRO++](https://www.deepspeed.ai/tutorials/zeropp/) tailored to reduce memory usage and improve training/inference efficiency for RLHF training with LoRA. MixZ++ partitions model parameters across GPUs to reduce footprint and gathers them with quantized communication only when needed similar to its ZeRO and ZeRO++ siblings. Our evaluation indicates MixZ++ increases the training throughput by **up to 3.3x** for the Llama-2-70B model running on 128 V100 GPUs. + + - ***[ZeRO-Offload](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)***. It is an optimization that offloads optimizer memory and computation from the GPU to the host CPU, enabling larger models to be trained with fewer GPU resources. After training stability fixes and testing, we have enabled this feature across all three stages of the DeepSpeed-Chat RLHF training pipeline. ZeRO-Offload reduces the minimum number of GPUs required to train large models by **up to 16x**. + +- **Stability and Software Enhancements** + + - DeepSpeed-Chat contains a rich set of features for training across many different platforms and scenarios. Composing these features in a systematic way and ensuring both system stability and decent training convergence is critical for the usability of the framework. Thus, in addition to new features in DeepSpeed-Chat, many system stability and training convergence issues have been fixed both in DeepSpeed-Chat (client code) and DeepSpeed (runtime). These improvements have been thoroughly tested using the OPT model family for end-to-end training. Furthermore, end-to-end testing, characterization scripts, and several instrumentation features like TensorBoard support are now also available. *To try out these latest features and software improvements, please use DeepSpeed release [v0.10.2](https://github.com/deepspeedai/DeepSpeed/tree/v0.10.2) and the latest DeepSpeed-Chat in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples)*. + + - Finally, to ensure the long-term health of the DeepSpeed-Chat training framework, [PyTests](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/tests/test_training.py) were added for testing Step 3 of the RLHF training pipeline and are run on a nightly basis through a newly developed [GitHub Actions workflow](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml). + +We now dive into the details of our new features, training stability, and software improvements. + +# 2. System Support for Llama and Llama-2 models + +The DeepSpeed-Chat training framework now provides system support for the Llama and Llama-2 models across all three stages of training. To support this, we encountered a spectrum of issues, spanning from minor runtime errors to intricate performance-related challenges. In particular, the Llama model architecture which deviates from the standard Transformers block, was incompatible with DeepSpeed's inference kernels and the DeepSpeed container policy used by the Hybrid Engine. Addressing these hurdles necessitated extensive modifications across our DeepSpeed-Chat pipeline and the DeepSpeed runtime including code to support the ZeRO family of optimizations and their interaction with optimized inference kernels in the Hybrid Engine. We have resolved these challenges to ensure that DeepSpeed-Chat can support Llama and Llama-2 and provide our users with the best possible experience. The details can be seen from several PRs that have been merged in our codebases. + +## Key Supported Optimizations + +The following key optimizations in DeepSpeed are now fully integrated for Llama and Llama-2 models: + +- **DeepSpeed-Chat Integration**: Fully integrated into the complete, end-to-end three-stage DeepSpeed-Chat RLHF training framework, based on the OpenAI InstructGPT training strategy. +- **Hybrid Engine**: DeepSpeed Hybrid Engine allows for superior generation phase [acceleration](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-chat/README.md#throughput-and-model-size-scalability-comparisons-with-existing-rlhf-systems), now supported for all Llama-1 model variants, Llama-2-7B, and Llama-2-13B models. +- **ZeRO and ZeRO-Offload**: Fully supported by the [ZeRO](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-chat/README.md#throughput-and-model-size-scalability-comparisons-with-existing-rlhf-systems) family of optimizations including offload support leveraging full memory capacity of a system thus enabling training of even larger models. +- **Mixed Precision ZeRO++ (MixZ++)**: Enhanced support for larger models like Llama-2-70B through the new MixZ++ feature, improving efficiency and reducing memory usage when there are frozen or non-trainable parameters. +- **LoRA**: Fully supported by the [LoRA](https://github.com/deepspeedai/LoRA) feature, which vastly reduces the storage requirements for large language models by freezing original weights and learning pairs of rank-decomposition matrices. + +## Getting Started + +Users looking to try the new Llama and Llama-2 model support can get started by using the newly added Llama scripts. +| Step Number | Scripts | +| --- | --- | +| 1 | [Llama-2 Step 1 Scripts](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2) | +| 2 | [Llama-2 Step 2 Scripts](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/llama2) | +| 3 | [Llama-2 Step 3 Scripts](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2) | + +*Note*: While all the system aspects of Llama and Llama-2 support have been extensively tested, there are no guarantees about training convergence and may require hyper-parameter tuning to achieve convergence. + +## Performance Evaluation + +We highlight the performance benefits of the Hybrid Engine for Llama-2 models on NVIDIA A100 and V100 GPUs in this section. Improved performance for larger models like Llama-2-70B and reduced resource requirements via ZeRO-Offload are discussed in the [next section](#new-features). + +#### A100 Performance Evaluation +Using A100 GPUs, we achieve 7.1x faster generation for Llama-2-7B and 5.4x faster generation for Llama-2-13B with DeepSpeed-Chat Hybrid Engine compared to DeepSpeed-Chat without Hybrid Engine (baseline) as shown in *Figure 1*. + + +
+ Up to 7.1x faster Llama-2 generation with DS-Chat Hybrid Engine + + *Figure 1: Up to 7.1x faster Llama-2 generation with DS-Chat Hybrid Engine* + +
+ +#### V100 Performance Evaluation +Using V100 GPUs, we achieve 4x faster generation for Llama-2-7B and 2.1x faster generation for Llama-2-13B with DeepSpeed-Chat Hybrid Engine compared to DeepSpeed-Chat without Hybrid Engine (baseline) as shown in *Figure 2*. + +
+ 4x faster Llama-2-7B generation with DS-Chat Hybrid Engine + 2.1x faster Llama-2-13B generation with DS-Chat Hybrid Engine + + *Figure 2: [Left] 4x faster Llama-2-7B generation with DS-Chat Hybrid Engine (16 V100 GPUs) [Right] 2.1x faster Llama-2-13B generation with DS-Chat Hybrid Engine on 32 V100 GPUS vs. DS-Chat without Hybrid Engine on 16 V100 GPUs.* + +
+ + +# 3. Improved Efficiency and Accessibility + +We now dive into the details of two new features we are introducing today: 1) Mixed Precision ZeRO++ (MixZ++) and 2) ZeRO-Offload. Both these features offer unique benefits for DeepSpeed-Chat users. MixZ++ provides up to 3.3x better throughput for LoRA-enabled training and ZeRO-Offload reduces the minimum number of GPUs required to train by up to 16x. + +## 3.3x Higher Throughput with MixZ++ for LoRA + +Mixed Precision ZeRO++ ([MixZ++](https://github.com/deepspeedai/DeepSpeed/pull/3954)) is an extended set of optimization strategies built upon [ZeRO](https://www.deepspeed.ai/tutorials/zero/) and [ZeRO++](https://www.deepspeed.ai/tutorials/zeropp/) tailored to reduce memory usage and improve training/inference efficiency for RLHF training with LoRA. + +Similar to [ZeRO](https://www.deepspeed.ai/tutorials/zero/), MixZ++ partitions model parameters across GPUs to reduce footprint and gathers them only when needed. In addition, similar to ZeRO++, MixZ++ allows for hierarchical partitioning and quantized communication. The hierarchical partitioning allows all the parameters to be stored within a node when possible so that the communication happens within a node, where communication bandwidth is significantly higher than communicating across nodes. The communication overhead is further reduced by quantizing the weights before gathering them. + +Finally, unlike ZeRO++ where parameters are always stored in fp16/bf16, and quantized/dequantized before and after communication, MixZ++ can persistently store the frozen weights in [Low-Rank Adaptation (LoRA)](https://github.com/deepspeedai/LoRA) in lower-precision, significantly reducing the communication overhead, eliminating quantization overhead, and supporting larger batch sizes that enable better efficiency. + +A comprehensive exploration of technical details can be accessed through our [ZeRO++ blog](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/), [MixZ++ tutorial](https://www.deepspeed.ai/tutorials/mixed_precision_zeropp/), and [paper](https://arxiv.org/pdf/2306.10209.pdf). + +#### Highlights + +State-of-the-art approaches like [QLoRA](https://arxiv.org/abs/2305.14314) focus on combining multiple techniques like quantization of LoRA weights, relying on new datatypes such as NF4, and memory-management/offload techniques like paged optimizers to enable finetuning of large models on a single GPU. MixZ++ is our approach to enable large model training powered by quantization but is designed to scale to a large number of GPUs with simplicity and compatibility with existing technologies like ZeRO-Offload and DeepSpeed Hybrid Engine. + +MixZ++ has the following highlights: +- Simplicity: A general solution requiring no assumptions about the model and/or optimizer. Integrating it into your training script is as simple as adding a single line of code. +- Performance: Powered by a set of highly optimized CUDA kernels that enables efficient quantization/dequantization. The evaluation shows up to 3.3x higher throughput for Llama-2-70B training on 128 GPUs compared to the ZeRO-3 baseline (*Figure 3*). +- Compatibility: Compatible with DeepSpeed/ZeRO features like DeepSpeed Hybrid Engine, ZeRO-Offload, etc. +- Scalability: Designed to scale to a large number of GPUs. It is tested on up to 384 GPUs on Azure. + + +#### Performance Evaluation +To assess the effectiveness of MixZ++ for LoRA-enabled training, we carried out a series of RLHF training experiments (Step 3) using the Llama-2-70B model. These experiments were conducted on hardware configurations featuring 64 and 128 V100 GPUs. A visual representation of the experiment results is shown in the following figure: + +
+ Mixed Precision ZeRO++ Evaluation + + *Figure 3: We achieve 3.3x increased throughput for RLHF training of Llama-2-70B on 128 V100 GPUs using Mixed Precision ZeRO++ vs. ZeRO-3. We obsvered 2x improved throughout for the same experiment on 64 V100 GPUs.* + +
+ +Specifically, our results showcase a 2x increase in training throughput when utilizing 64 GPUs with MixZ++, compared to the ZeRO-3 baseline. Furthermore, when scaling up to 128 GPUs, the speedup effect becomes even more pronounced, with a substantial 3.3x improvement in training throughput. These outcomes underscore the potential of MixZ++ as a powerful tool for improving training efficiency in large-scale GPU settings. + +To try this feature, please refer to [MixZ++ tutorial](https://www.deepspeed.ai/tutorials/mixed_precision_zeropp/). + +## ZeRO-Offload Support for Larger Models with 16x fewer GPUs + +[ZeRO-Offload](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) powers unprecedented model sizes by leveraging the full memory capacity of a system, concurrently exploiting all heterogeneous memory. Modern GPU clusters have 2-3x more CPU memory than GPU memory. ZeRO-Offload capitalizes on this disparity and offloads optimizer memory and computation from the GPU to the host CPU, enabling larger models to be trained with fewer GPU resources without being bottlenecked by the CPU's lower bandwidth. ZeRO-Offload allows training of large models on up to 16x fewer GPUs as we can see in *Figure 4*. + +
+ + + *Figure 4: ZeRO-Offload enables us to train Llama-2-7B with 16x fewer GPUs. 16 V100 GPUs are required for training Llama-2-7B with DS-Chat ZeRO-3. Enabling LoRA allows for the number of GPUs to be reduced to 4 while enabling ZeRO-Offload reduces the number of needed GPUs to 1. The HuggingFace Baseline does not run due to memory limitations.* + +
+ +ZeRO-Offload was [disabled](https://github.com/deepspeedai/DeepSpeedExamples/pull/553) + with the initial release of DeepSpeed-Chat due to training instability that was observed when it was used with Hybrid Engine and LoRA. After improvements to Hybrid Engine and LoRA as well as extensive testing of all feature configurations for ZeRO Stage2 and ZeRO Stage 3, this feature can now be enabled across all three steps of the DeepSpeed-Chat training framework. Please note that configuring ZeRO-Offload with ZeRO Stage 2 and Hybrid Engine with LoRA disabled is currently unsupported due to observed training instability. + +
+ + + *Figure 5: Reward scores for all supported DeepSpeed-Chat configurations with ZeRO-Offload enabled. Run with 16 V100 GPUs, [AdamG012/chat-opt-1.3b-sft-deepspeed](https://huggingface.co/AdamG012/chat-opt-1.3b-sft-deepspeed) actor model, [AdamG012/chat-opt-350m-reward-deepspeed](https://huggingface.co/AdamG012/chat-opt-350m-reward-deepspeed) critic model, DS commit: [f036f00c](https://github.com/deepspeedai/DeepSpeed/tree/f036f00c3763694e539a9070a98130e2667e49bd), DSE commit: [81a8521f](https://github.com/deepspeedai/DeepSpeedExamples/tree/81a8521f05e2761eed34fcf65f19873df9f74403).* + +
+ +# 4. Stability Bug Fixes + +A wide range of issues have been addressed in the DeepSpeed runtime and the DeepSpeed-Chat pipeline. These fixes enable advanced features such as Hybrid Engine, LoRA, and ZeRO-Offload to run across all training steps of the DeepSpeed-Chat pipeline and improve training stability and convergence. + +
+ + + *Figure 6: Step 3 Reward Scores for all supported DeepSpeed-Chat configurations. Run with 16 V100 GPUs, [AdamG012/chat-opt-1.3b-sft-deepspeed](https://huggingface.co/AdamG012/chat-opt-1.3b-sft-deepspeed) actor model, [AdamG012/chat-opt-350m-reward-deepspeed](https://huggingface.co/AdamG012/chat-opt-350m-reward-deepspeed) critic model, DS commit: [f036f00c](https://github.com/deepspeedai/DeepSpeed/tree/f036f00c3763694e539a9070a98130e2667e49bd), DSE commit: [81a8521f](https://github.com/deepspeedai/DeepSpeedExamples/tree/81a8521f05e2761eed34fcf65f19873df9f74403).* + +
+ +*Figure 6* above shows the training convergence across all supported DeepSpeed-Chat configurations. This data was collected using 16 V100 NVIDIA GPUs, the [AdamG012/chat-opt-1.3b-sft-deepspeed](https://huggingface.co/AdamG012/chat-opt-1.3b-sft-deepspeed) OPT model as the actor, the [AdamG012/chat-opt-350m-reward-deepspeed](https://huggingface.co/AdamG012/chat-opt-350m-reward-deepspeed) OPT model as the critic, and the following DeepSpeed and DeepSpeedExamples repository commits: DS commit: [f036f00c](https://github.com/deepspeedai/DeepSpeed/tree/f036f00c3763694e539a9070a98130e2667e49bd), DSE commit: [81a8521f](https://github.com/deepspeedai/DeepSpeedExamples/tree/81a8521f05e2761eed34fcf65f19873df9f74403). + +We now dive into the details of all the fixes across different areas. + +## DeepSpeed-Chat Pipeline Fixes + +In this section we discuss the functionality and training stability fixes in the DeepSpeed-Chat pipeline. + +- **Training Stability:** + + - [PR #620 - Make training more stable](https://github.com/deepspeedai/DeepSpeedExamples/pull/620) + + - To improve the training stability in Step 3, several different areas of training were tuned and changed. To start, the Kullback-Liebler (KL) divergence used in the Proximal Policy Optimization (PPO) trainer was slightly tuned to reduce divergence between the new and reference policies and improve the reward score. Next, the sequence generation function in the PPO trainer (`_generate_sequence()`) removed the specification of a `min_length` in the Actor model's `generate()` call, which means generated sequences won't be artificially enlarged, allowing for the possibility of sequence generation to collapse i.e. when training convergence is extremely poor. A minor off-by-one error was also fixed in the PPO trainer's reward computation function (`compute_rewards()`). Finally, the PPO trainer's RLHF training function was updated to zero out the reward and value after the end of a conversation to prevent incorrect `advantages` and `returns`. + + - [PR #633 - DS Chat Step 3 - Add separate Lora Adam optimizer group](https://github.com/deepspeedai/DeepSpeedExamples/pull/633) + + - The [LoRA](https://github.com/deepspeedai/LoRA) feature is supported across all three training steps of the DeepSpeed-Chat framework. Prior to this stability effort, there was no distinction between the overall learning rate and the LoRA learning rate i.e. the LoRA learning rate was set to whatever the overall learning rate was. This led to instability in training convergence and can be seen in *Figure 7* below showing the reward score across training steps for various Step 3 configurations: + +
+ + + *Figure 7: Before the fix, the sweep across all ZeRO-2 cases without a separate LoRA learning rate shows training instability when LoRA is used.* + +
+ + To address this training convergence issue, when creating the optimizer grouped parameters, the LoRA `lora_right_weight` and `lora_left_weight` parameters were explicitly separated out and given their own LoRA-specific learning rate. After this change, a dramatic improvement in stability was observed, as shown in the figure below: + +
+ + + *Figure 8: After creating a separate LoRA learning rate, the sweep across all ZeRO-2 cases shows proper convergence.* + +
+ + The next fix details the addition of separate LoRA learning rate arguments. + + - [PR ##685 Add LoRA LR for DS Chat steps 1-3](https://github.com/deepspeedai/DeepSpeedExamples/pull/685) + + - A *separate* LoRA learning rate argument can now be provided in each of the three training steps, with Step 3 having individual LoRA learning rates for the Actor and Critic models. + +- **Bug Fixes:** + + - [PR #636 - DS Chat Step 3 - Fix Zero Stage 3](https://github.com/deepspeedai/DeepSpeedExamples/pull/636) + + - During DeepSpeed-Chat Step 3 training, we observed hangs when ZeRO Stage 3 was enabled for the actor model and when the `world_size > 1`. When observing the state of each rank, one rank would still be in the sequence generation phase `self._generate_sequence()`, while the other rank had already progressed to the `self.actor_model()` call. This ZeRO Stage 3 desynchronization, due to misaligned token generation between the GPUs, can normally be automatically detected and accounted for in the HuggingFace Transformers library via `synced_gpus`. However, due to the nature of the DeepSpeed-Chat pipeline and the lifetime of the corresponding model configuration objects, this automatic detection code was not triggered. To resolve this, when invoking the `generate()` function, the `synced_gpus` argument is explicitly passed and set to `True` when ZeRO Stage 3 is being used. + + - [PR #658 - Fix only optimize lora and ack-ckpting compatible](https://github.com/deepspeedai/DeepSpeedExamples/pull/658) + + - This fix allows Step 3 training to run with the combination of gradient checkpointing and *LoRA-only* parameter optimization, a previously unsupported training case. With the addition of the [enable_input_require_grads](https://github.com/huggingface/transformers/blob/f26099e7b5cf579f99a42bab6ddd371bf2c8d548/src/transformers/modeling_utils.py#L1225) model utility function in the HuggingFace Transformers library, which enables the gradients for the input embeddings, gradient checkpointing and optimization of *only* the LoRA parameters is made possible. + + - [PR #576 - Fix argparse](https://github.com/deepspeedai/DeepSpeedExamples/pull/576) + + - An external contributor helped in resolving an argument parsing issue. + + - [PR #584 - Fix unused parameter bug](https://github.com/deepspeedai/DeepSpeedExamples/pull/584) + + - An external contributor fixed the passing of an uninitialized parameter that was hardcoded earlier. + + +## Hybrid Engine Fixes +In this section we discuss several fixes in the Hybrid Engine. + +- [PR #3563 - Fix LoRA Fuse/Unfuse in Hybrid Engine](https://github.com/deepspeedai/DeepSpeed/pull/3563) + + - During Step 3 training for OPT with LoRA and Hybrid Engine enabled, an issue arose regarding a tensor size mismatch of the LoRA weights. Specifically, the LoRA QKV weights were not fused in the OPT container policy, yet they were expected to be fused by the Hybrid Engine. This challenge was effectively resolved by introducing both fused and unfused LoRA methods in the Hybrid Engine. We thank @sxjscience for providing this fix. + +- [PR #3883 - Extend HE-Lora test with Z3 support + Fix/add guard in HE for Z3](https://github.com/deepspeedai/DeepSpeed/pull/3883) + + - The Hybrid Engine was updated to properly check whether ZeRO Stage 3 was enabled when resetting the inference container parameters, along with expanding the corresponding unit tests. + + +## ZeRO Stage 3 Fixes +In this section we discuss several fixes in support of the ZeRO Stage 3 feature. + +- [PR #3819 - Fix racing condition in GatheredParameters](https://github.com/deepspeedai/DeepSpeed/pull/3819) + + - A race condition in the the ZeRO `GatheredParameters` context, which resulted in various `'status': 'INFLIGHT'` issues, was fixed by removing duplicate input parameters that were being passed from the Hybrid Engine. + +- [PR #3884 - Separate ZeRO3 InflightParamRegistry for train and eval](https://github.com/deepspeedai/DeepSpeed/pull/3884) + + - The ZeRO Stage 3 `InflightParamRegistry` was updated to use a separate `InflightParamRegistry` for training and evaluation, fixing an issue where leftover parameters in flight were causing inflight parameter errors. These fixes, along with related fixes in the Hybrid Engine, enabled the use of the ZeRO-Offload feature in the DeepSpeed-Chat training pipeline. + +- [PR #3928 - Remove the param.ds_tensor from print](https://github.com/deepspeedai/DeepSpeed/pull/3928) + + - A minor change that was necessary to address the DeepSpeed-Chat Step 3 hang issue ([PR #636](https://github.com/deepspeedai/DeepSpeedExamples/pull/636)) as it allowed us to progress further into execution and observe the desynchronization point. + + +# 5. Software Improvements + +To improve the characterization, ease of debug, and maintainability of the DeepSpeed-Chat framework, several areas of software improvements have been completed. Characterization scripts were added to enable systematic composition of features, instrumentation was added to improve insight into the behavior of training, and a testing CI workflow was added to improve the maintainability of the DeepSpeed-Chat training framework. + +## Characterization Scripts + +The DeepSpeed-Chat training framework provides a rich set of features (Hybrid Engine, ZeRO, LoRA, etc.) that can be composed in many different combinations, depending on the scenario. The interactions between the features are often complex and composing them in a systematic way for characterization is useful for understanding their behavior. To support such use cases, characterization scripts have been added to run sweeps of Steps 1, 2, and 3 training for various combinations of features. The scripts default to OPT but can be modified to run with Llama. Please see the READMEs in the following folders for more details: + +- [Step 1 Sweep Scripts](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/opt/single_node/sweep) +- [Step 2 Sweep Scripts](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/sweep) +- [Step 3 Sweep Scripts](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep) + +For example, the Step 3 characterization script sweeps across various training features: +| Feature | Values | +| --- | --- | +| ZeRO Stage | 2, 3 | +| Hybrid Engine | True, False | +| ZeRO-Offload | True, False | +| LoRA | True, False | + +And can be ran as follows: + +
+DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning$ bash training_scripts/opt/single_node/sweep/run_step3_sweep.sh
+
+ +The training log for each combination of features will be stored in a folder with the name formatted as `z${z}_he_${he}_offload_${offload}_lora_${lora}` + + +Related PRs: + +- [DS Chat Characterization Scripts (Step 1 and 3)](https://github.com/deepspeedai/DeepSpeedExamples/pull/638) +- [Add step 2 sweep script, clean up scripts](https://github.com/deepspeedai/DeepSpeedExamples/pull/664) +- [Update script location and docs for all 3 steps](https://github.com/deepspeedai/DeepSpeedExamples/pull/681) + +## Instrumentation + +To gain better insight into DeepSpeed-Chat training, new [instrumentation features](https://github.com/deepspeedai/DeepSpeedExamples/pull/624) were added across all three steps of DeepSpeed-Chat and can be enabled via arguments to each step's `main.py`. + +| Argument | Description | Step(s) | +| --- | --- | --- | +| --print_loss | Print loss during each step | 1 | +| --enable_tensorboard | Enable TensorBoard logging at the model Runtime Engine level | 1,2,3 | +| | Enable TensorBoard logging at the Training Pipeline level | 3 | +| --tensorboard_path | Path to write TensorBoard log | 1,2,3 | +| --print_answers | Print actor model prompt and answers during training across all ranks | 3 | + + +### TensorBoard +TensorBoard logging can be enabled in each of the three training steps, with some slight nuances in Step 3. To start, for each training step, the `enable_tensorboard` argument can be used to enable a TensorBoard monitor at the Runtime Engine level ([see documentation](https://www.deepspeed.ai/docs/config-json/#monitoring-module-tensorboard-wandb-csv)) and is reflected in the corresponding model training configuration: +```python +"tensorboard": { + "enabled": enable_tensorboard, + "output_path": f"{tb_path}/ds_tensorboard_logs/", + "job_name": f"{tb_name}_tensorboard" +} +``` + +- **Step 3**: + Due to Step 3 initializing both an Actor and a Critic model, _each_ of the models will have their own corresponding TensorBoard monitor at the Runtime Engine level. Beyond that, Step 3 training also contains a Pipeline-level TensorBoard monitor a level above the model runtime engines, which captures the `reward`, `actor_loss`, `actor_loss_sum`, `critic_loss`, and `critic_loss_sum`. + +## Testing + +As part of the DeepSpeed team's commitment to maintaining the DeepSpeed-Chat training framework, continuous integration [PyTest](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/tests/test_training.py) testing has been added for Step 3 RLHF training in a new [GitHub Actions workflow](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml). + +| Description | Status | +| ----------- | ------ | +| Integrations | [![nv-ds-chat](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml) | + + The workflow is run on a **nightly** basis across a **16-case** test matrix (see table below), and uses the **facebook/opt-125m** model for both the actor and critic. + +| Parameter | Values | +| --- | --- | +| ZeRO Stage | 2, 3 | +| Hybrid Engine | True, False | +| ZeRO-Offload | True, False | +| LoRA | True, False | + +Each configuration (16 total) runs through a limited number of Step 3 non-overflow training steps (i.e. steps where neither actor nor critic overflow) and saves the actor/critic models. Assertions are used to check if the training pipeline executed correctly and if the actor and critic models were saved properly. + +# 6. Try Out DeepSpeed-Chat +We are very excited to share this DeepSpeed-Chat feature and stability release. + +* To get started, please visit our GitHub page for DeepSpeed-Chat: [GitHub Landing Page](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat) + +* We will continue to improve DeepSpeed-Chat with your feedback and support. Our [roadmap](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-) shows currently supported features as well as ones that are planned for future. + +DeepSpeed-Chat is part of the bigger DeepSpeed ecosystem comprising of a multitude of Deep Learning systems and modeling technologies. To learn more, + +* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation. +* You can also follow us on our [English Twitter](https://twitter.com/DeepSpeedAI), [Japanese Twitter](https://twitter.com/DeepSpeedAI_JP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for latest news on DeepSpeed. + +DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) page. Please see our [contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to info@deepspeed.ai. + +* "Star" our [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) and [DeepSpeedExamples GitHub](https://github.com/deepspeedai/DeepSpeedExamples/) repositories if you like our work! diff --git a/blogs/deepspeed-chat/japanese/README.md b/blogs/deepspeed-chat/japanese/README.md index 66ff9d9c90cf..b6cfb2c1496b 100644 --- a/blogs/deepspeed-chat/japanese/README.md +++ b/blogs/deepspeed-chat/japanese/README.md @@ -10,6 +10,17 @@
+DeepSpeed Chat を引用するには、こちらの[arxiv report](https://arxiv.org/abs/2308.01320)を引用してください: + +``` +@article{yao2023dschat, + title={{DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales}}, + author={Zhewei Yao and Reza Yazdani Aminabadi and Olatunji Ruwase and Samyam Rajbhandari and Xiaoxia Wu and Ammar Ahmad Awan and Jeff Rasley and Minjia Zhang and Conglong Li and Connor Holmes and Zhongzhu Zhou and Michael Wyatt and Molly Smith and Lev Kurilenko and Heyang Qin and Masahiro Tanaka and Shuai Che and Shuaiwen Leon Song and Yuxiong He}, + journal={arXiv preprint arXiv:2308.01320}, + year={2023} +} +``` + # 1. 概要 ChatGPT(チャットGPT)やその類似モデルは、AIの世界に旋風を巻き起こし、デジタル業界に革命的な影響を与えています。これらのモデルは非常に汎用性が高く、要約、コーディング、翻訳などの多様なタスクを、人間の専門家と同等か、それ以上の結果で実施できます。その圧倒的な性能を受けて、AI関連のオープンソースコミュニティでは、ChatGPTスタイルのモデルをより利用しやすくするための複数の取り組みが始まっています(ChatLLaMa、Alpaca、Vicuna、Databricks-Dollyなど)。 @@ -18,7 +29,7 @@ ChatGPT(チャットGPT)やその類似モデルは、AIの世界に旋風 ChatGPTの訓練に用いられるInstructGPTにおいて提案されたRLHFでは、これまでの標準的な事前学習やファインチューニングと全く異なり、はるかに複雑なパイプラインが必要となります。従来のソフトウェアでは、そうしたパイプラインが効果的にサポートする仕組みがありませんでした。そこで、RLHFの訓練を広くAIコミュニティで利用可能とし、ChatGPTのようなモデルを誰もが作成できるにするため、以下の機能を備えたDeepSpeed-Chatをリリースすることになりました。 -(i) ***容易に実施可能なChatGPTライクなモデルの訓練と推論***: Huggingfaceレポジトリで提供されている学習済みモデルから開始して、InstructGPT学習の全3ステップを実行し、独自のChatGPTライクなモデルを生成できるスクリプトを提供します。また、学習後の会話形式のインタラクションをテストするための推論APIを提供します。 +(i) ***容易に実施可能なChatGPTライクなモデルの訓練と推論***: Hugging Faceレポジトリで提供されている学習済みモデルから開始して、InstructGPT学習の全3ステップを実行し、独自のChatGPTライクなモデルを生成できるスクリプトを提供します。また、学習後の会話形式のインタラクションをテストするための推論APIを提供します。 (ii) ***DeepSpeed-RLHF パイプライン***: DeepSpeed-RLHFパイプラインは、InstructGPTの学習パイプラインの3つのステップ a) 教師付きファインチューニング (Supervised fine-tuning, SFT), b) 報酬モデルのファインチューニング, c) RLHF (Reinforcement Learning with Human Feedback) を、包括的に、かつ1対1の対応を保って再現するものです。また、複数のデータソースからの同時学習を可能にするために、学習データの抽象化・ブレンド機能を提供します。 @@ -51,7 +62,7 @@ DeepSpeed-RLHFシステムは、大規模モデルの学習において類を見 *表2. 複数ノード(64x A100-80GB)を用いた場合の訓練時間とAzureでの概算実行コスト*
-> ***注意事項***: 上記の2つの表の数値は、訓練のステージ3のものです。DeepSpeed-RLHFが用いるデータセットと訓練の設定において、合計1.35億トークンを1エポックで訓練した際のスループットの実測値に基づいています。合計6750万のクエリートークン(配列長256の13万件のクエリー)と6750万の生成トークン(配列長256の13万件の回答)があり、ステップごとの最大グローバルバッチサイズは 50万 トークン(クエリーと回答それぞれ1024件)です。DeepSpeedRLHFを用いた場合のコストおよび実行時間の比較にあたっては、これらの詳細をよくご確認ください。さらに詳細な情報は[ベンチマーク設定](https://github.com/microsoft/DeepSpeedExamples/blob/staging-deepspeed-chat-v2/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md)を参照ください。 +> ***注意事項***: 上記の2つの表の数値は、訓練のステージ3のものです。DeepSpeed-RLHFが用いるデータセットと訓練の設定において、合計1.35億トークンを1エポックで訓練した際のスループットの実測値に基づいています。合計6750万のクエリートークン(配列長256の13万件のクエリー)と6750万の生成トークン(配列長256の13万件の回答)があり、ステップごとの最大グローバルバッチサイズは 50万 トークン(クエリーと回答それぞれ1024件)です。DeepSpeedRLHFを用いた場合のコストおよび実行時間の比較にあたっては、これらの詳細をよくご確認ください。さらに詳細な情報は[ベンチマーク設定](https://github.com/deepspeedai/DeepSpeedExamples/blob/staging-deepspeed-chat-v2/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md)を参照ください。 ***RLHFを誰もが利用できるように***: DeepSpeed-HEは、1台のGPUのみで130億以上のパラメーターを持つモデルの訓練を実行できます。複数のGPUを備えた高価な計算設備を持たないデータサイエンティストも、小規模なトイモデルではなく、実際のシナリオで使用できる大規模で強力なRLHFモデルを作成できます。 @@ -79,15 +90,13 @@ DeepSpeed-RLHFシステムは、大規模モデルの学習において類を見 以下のスクリプトを実行すると、最終的に130億パラメータのChatGPTライクなモデルが生成されます。 ```python -git clone https://github.com/microsoft/DeepSpeed.git -cd DeepSpeed -pip install . +pip install deepspeed>=0.9.0 -git clone https://github.com/microsoft/DeepSpeedExamples.git +git clone https://github.com/deepspeedai/DeepSpeedExamples.git cd DeepSpeedExamples/applications/DeepSpeed-Chat/ pip install -r requirements.txt -python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --num-gpus 8 +python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --deployment-type single_node ``` 8台のA100を備えたDGXサーバを使う場合、130億パラメータのモデルを半日で訓練できます。以下は各ステップに要する時間の内訳です。 @@ -121,7 +130,7 @@ Assistant:  Sure, I can try.  Microsoft is a company that makes computers, a 訓練の時間、計算機資源、および品質の要件を満たすために、さまざまなモデルのサイズや構成を試す必要があることがあります。DeepSpeed-Chatを使用すれば、簡単にそれが可能です。例えば、研究やビジネスのために、GPUクラスタでより大規模で高品質なモデルを訓練したい場合、希望するモデルサイズとGPU数を指定するだけです。以下は、アクターモデルのパラメータ数を66Bに、GPU数を64に指定する例です。 ```python -python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --num-gpus 64 +python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --deployment-type multi_node ``` 64台のA100 (80GBメモリ) GPUを使用する場合、9時間で660億パラメータのChatGPTモデルを訓練できます。 @@ -139,7 +148,7 @@ python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m 1~2時間のコーヒータイムや昼休みに、DeepSpeed-Chatで小規模なトイモデルをトレーニングしてみるのも良いでしょう。例えば、コンシューマグレードのGPUでの訓練を動かしてみるため、1つのデータセットで1.3Bのモデルを訓練する例を用意しました。これなら、昼休みから戻ったときに、できあがったモデルのチェックポイントを試してみることができます。 ```python -python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --num-gpus 1 +python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu ```
@@ -235,12 +244,12 @@ DeepSpeed-HEは、訓練と推論の両方で、モデルの分割をシーム ## 既存のRLHFシステムとのスループットとモデルサイズのスケーラビリティ比較 -DeepSpeed-RLHFは、Colossal-AIや、ネイティブのPyTorchを用いたHuggingFaceなどの他のRLHFを訓練可能なシステムと比較して、実行速度とスケーラビリティの両方で優れています。 +DeepSpeed-RLHFは、Colossal-AIや、ネイティブのPyTorchを用いたHugging Faceなどの他のRLHFを訓練可能なシステムと比較して、実行速度とスケーラビリティの両方で優れています。 -* スループットに関しては、DeepSpeedは単一GPUでのRLHFトレーニングで10倍以上の向上を実現しています(図3)。複数GPU環境では、Colossal-AIと比較して6~19倍、HuggingFace DDPと比較して1.4~10.5倍のスピードアップを実現しています(図4)。 +* スループットに関しては、DeepSpeedは単一GPUでのRLHFトレーニングで10倍以上の向上を実現しています(図3)。複数GPU環境では、Colossal-AIと比較して6~19倍、Hugging Face DDPと比較して1.4~10.5倍のスピードアップを実現しています(図4)。 * モデルのスケーラビリティに関しては、Colossal-AIが最大で1.3Bのモデルを単一GPUで、6.7BのモデルをA100-40Gを備えた単一のノードで訓練できますが、DeepSpeed-HEは同じハードウェアでそれぞれ6.5Bと50Bのサイズのモデルを訓練できます。これは、最大で7.5倍のモデルサイズを扱えることになります。 -したがって、DeepSpeed-HEは、Colossal-AIやHuggingFace DDPなどの既存のRLHFシステムと比較して、1桁以上高いスループットを実現しており、同じ実行時間ではるかに大きなアクターモデルを訓練したり、10倍以上低いコストで同様のサイズのモデルを訓練することができます。 +したがって、DeepSpeed-HEは、Colossal-AIやHugging Face DDPなどの既存のRLHFシステムと比較して、1桁以上高いスループットを実現しており、同じ実行時間ではるかに大きなアクターモデルを訓練したり、10倍以上低いコストで同様のサイズのモデルを訓練することができます。
@@ -258,7 +267,7 @@ DeepSpeed-RLHFは、Colossal-AIや、ネイティブのPyTorchを用いたHuggin
-この効率化は、DeepSpeed-HEが、DeepSpeedの高度に最適化された推論機能を活用して、RLHF処理の生成フェーズを高速化したことに起因しています。図5は、1.3BパラメータモデルのRLHF訓練の時間内訳を示したもので、時間の大半は生成フェーズに費やされていることが分かります。DeepSpeedの高性能な推論カーネルを活用することで、DeepSpeed-HEはこのフェーズでHuggingFaceの9倍、Colossal-AIの15倍のスループット向上を達成し、end-to-endの類を見ない効率化を実現しています。 +この効率化は、DeepSpeed-HEが、DeepSpeedの高度に最適化された推論機能を活用して、RLHF処理の生成フェーズを高速化したことに起因しています。図5は、1.3BパラメータモデルのRLHF訓練の時間内訳を示したもので、時間の大半は生成フェーズに費やされていることが分かります。DeepSpeedの高性能な推論カーネルを活用することで、DeepSpeed-HEはこのフェーズでHugging Faceの9倍、Colossal-AIの15倍のスループット向上を達成し、end-to-endの類を見ない効率化を実現しています。
@@ -270,7 +279,7 @@ DeepSpeed-RLHFは、Colossal-AIや、ネイティブのPyTorchを用いたHuggin ## 実効スループットとスケーラビリティ -***(I) 実効スループット分析*** RLHFのステージ3におけるDeepSpeed-HEの実効スループットは、生成フェーズと強化学習の訓練フェーズの両方のスループットで決まります。我々の作成したRLHFのパイプラインでは、生成フェーズが全計算量の約20%を占め、強化学習の訓練フェーズが残りの80%を占めています(詳細は[ベンチマークのページ](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md)を参照)。しかし、計算量で見た割合が少ないとはいえ、前者は生成された256個のトークンのそれぞれに対して、初期プロンプトの256個のトークンに対してアクターモデルによる推論をそれぞれ1回実行する必要があるため、end-to-endの時間で見ると、その大部分を占めることになり、メモリ帯域が制限されて高いスループットを得ることが難しくなります。一方、強化学習の訓練フェーズでは、1サンプルあたりプロンプトと生成の両方から512個のトークンをフルに使用して、参照アクターモデルについて、数回のフォワードパスとバックワードパスで実行できるため、高いスループットを達成できます。 +***(I) 実効スループット分析*** RLHFのステージ3におけるDeepSpeed-HEの実効スループットは、生成フェーズと強化学習の訓練フェーズの両方のスループットで決まります。我々の作成したRLHFのパイプラインでは、生成フェーズが全計算量の約20%を占め、強化学習の訓練フェーズが残りの80%を占めています(詳細は[ベンチマークのページ](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/BenckmarkSetting.md)を参照)。しかし、計算量で見た割合が少ないとはいえ、前者は生成された256個のトークンのそれぞれに対して、初期プロンプトの256個のトークンに対してアクターモデルによる推論をそれぞれ1回実行する必要があるため、end-to-endの時間で見ると、その大部分を占めることになり、メモリ帯域が制限されて高いスループットを得ることが難しくなります。一方、強化学習の訓練フェーズでは、1サンプルあたりプロンプトと生成の両方から512個のトークンをフルに使用して、参照アクターモデルについて、数回のフォワードパスとバックワードパスで実行できるため、高いスループットを達成できます。
@@ -296,7 +305,7 @@ DeepSpeed-RLHFは、Colossal-AIや、ネイティブのPyTorchを用いたHuggin ***(II) スケーラビリティ分析*** モデルサイズごとに、最良のスループットを得られるGPU数は異なります。これは、モデルサイズが大きくなると、実行に多くのメモリを必要とすることに加え、以下に説明する DeepSpeed-HE のスケーラビリティ特性にも起因しています。 -図7は、DeepSeed-RLHF が最大 64 GPU で全体的に良好なスケーラビリティを達成したことを示しています。しかし、より詳細に見ると、DeepSpeed-RLHFの訓練では、小規模な環境では超線形(super linear)なスケーリングを達成し、大規模では線形(linear)またはそれ以下のスケーラビリティになっていることが分かります。これは、メモリの可用性と最大グローバルバッチサイズとの間の相互作用によるものです。 +図7は、DeepSpeed-RLHF が最大 64 GPU で全体的に良好なスケーラビリティを達成したことを示しています。しかし、より詳細に見ると、DeepSpeed-RLHFの訓練では、小規模な環境では超線形(super linear)なスケーリングを達成し、大規模では線形(linear)またはそれ以下のスケーラビリティになっていることが分かります。これは、メモリの可用性と最大グローバルバッチサイズとの間の相互作用によるものです。 DeepSpeed-HEはトレーニングにZeROの技術を採用しているため、利用可能なGPU間でモデルを分割することが可能です。その結果、GPUあたりのメモリ消費量はGPU数の増加とともに減少し、DeepSpeed-HEはGPUあたりでより大きなバッチサイズをサポートできるようになり、超線形のスケーリングが実現できます。しかし、より大規模になると、利用可能なメモリが増加し続ける一方で、最大グローバルバッチサイズが制限されているため、GPUあたりのバッチサイズを小さくすることになり、線形またはそれ以下のスケーリングになります。その結果、与えられた最大グローバルバッチサイズに対して、DeepSpeed-HEは、スーパーリニアとサブリニアのスケーラビリティの境界で最高のスループットとコスト効率を達成し、正確なポイントは、利用可能なメモリとグローバルバッチサイズの関数としてGPUごとに実行できる最大バッチサイズによってほぼ決定されます。 @@ -305,8 +314,8 @@ DeepSpeed-HEはトレーニングにZeROの技術を採用しているため、 DeepSpeed-ChatをオープンソースソフトウェアとしてAIコミュニティに公開できることを嬉しく思います。 -* DeepSpeed-Chatの[GitHubページ](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat)を見て、早速使い始めましょう。 -* ユーザのみなさまからのフィードバックと協力で、これからも継続的に DeepSpeed-Chat を改善していく予定です。現在サポートされている機能や、将来的にサポートされている機能については、[ロードマップ](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-)をご覧ください。 +* DeepSpeed-Chatの[GitHubページ](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat)を見て、早速使い始めましょう。 +* ユーザのみなさまからのフィードバックと協力で、これからも継続的に DeepSpeed-Chat を改善していく予定です。現在サポートされている機能や、将来的にサポートされている機能については、[ロードマップ](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/README.md#-deepspeed-chats-roadmap-)をご覧ください。 # 7. DeepSpeedについて @@ -323,14 +332,14 @@ DeepSpeedは、以下のような機能を提供します。 DeepSpeedは、Microsoftの[AI at Scale initiative](https://www.microsoft.com/en-us/research/project/ai-at-scale/)の一部で、次世代AIの機能の大規模な実現を進めています。詳細は[こちら](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)をご覧ください。DeepSpeedは、[Megatron-Turing NLG (530B)](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/), [Jurassic-1 (178B)](https://uploads-ssl.webflow.com/60fd4503684b466578c0d307/61138924626a6981ee09caf6_jurassic_tech_paper.pdf), [BLOOM (176B)](https://huggingface.co/blog/bloom-megatron-deepspeed), [GLM (130B)](https://github.com/THUDM/GLM-130B), [YaLM (100B)](https://github.com/yandex/YaLM-100B) を含め、様々な大規模モデルを学習するのに使用されてきました。 -またDeepSpeedは、 [Hugging Face Transformers](https://huggingface.co/docs/transformers/main/main_classes/deepspeed), [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/usage_guides/deepspeed), [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html), [MosaicML Composer](https://docs.mosaicml.com/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration), [Determined AI](https://docs.determined.ai/latest/training/apis-howto/deepspeed/overview.html) など、多くの著名なオープンソースの深層学習フレームワークのバックエンドとして利用されています。 +またDeepSpeedは、 [Hugging Face Transformers](https://huggingface.co/docs/transformers/deepspeed), [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/usage_guides/deepspeed), [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html), [MosaicML Composer](https://docs.mosaicml.com/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration), [Determined AI](https://docs.determined.ai/latest/training/apis-howto/deepspeed/overview.html) など、多くの著名なオープンソースの深層学習フレームワークのバックエンドとして利用されています。 DeepSpeedについてのより詳しい情報は、以下をご覧ください。 * [DeepSpeedのWebサイト](https://www.deepspeed.ai/) には、DeepSpeedの技術に関する詳細なブログ記事、チュートリアル、ドキュメントなどが掲載されています。 -* [DeepSpeedのTwitterアカウント (英語)](https://twitter.com/MSFTDeepSpeed) では、DeepSpeedの最新情報を発信していますので、ぜひフォローください。[日本語版のTwitterアカウント](https://twitter.com/MSFTDeepSpeedJP)もあり、最新の情報を日本語で発信しています。 +* [DeepSpeedのTwitterアカウント (英語)](https://twitter.com/DeepSpeedAI) では、DeepSpeedの最新情報を発信していますので、ぜひフォローください。[日本語版のTwitterアカウント](https://twitter.com/DeepSpeedAI_JP)もあり、最新の情報を日本語で発信しています。 DeepSpeedチームは、ユーザの方々からのフィードバックやご連絡を受け付けています。 -* ユーザのみなさまからのバグ報告、Pull request、さまざまな議論への参加は、[GitHub](https://github.com/microsoft/DeepSpeed/)で受け付けています。詳細については、[ガイドライン](https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md)を確認してください。 -* DeepSpeedチームでは、DeepSpeedを用いた深層学習の研究や実世界へのAIモデルやアプリケーションに関して、大学、研究所、企業との方々とのコラボレーションを行っています(日本語でコミュニケーション可能な研究員も在籍しています)。こうしたコラボレーションについてのご要望(およびGitHubには適さないその他の話題)については、deepspeed-info@microsoft.com まで直接メールをお送りください。 +* ユーザのみなさまからのバグ報告、Pull request、さまざまな議論への参加は、[GitHub](https://github.com/deepspeedai/DeepSpeed/)で受け付けています。詳細については、[ガイドライン](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)を確認してください。 +* DeepSpeedチームでは、DeepSpeedを用いた深層学習の研究や実世界へのAIモデルやアプリケーションに関して、大学、研究所、企業との方々とのコラボレーションを行っています(日本語でコミュニケーション可能な研究員も在籍しています)。こうしたコラボレーションについてのご要望(およびGitHubには適さないその他の話題)については、info@deepspeed.ai まで直接メールをお送りください。 diff --git a/blogs/deepspeed-domino/README.md b/blogs/deepspeed-domino/README.md new file mode 100644 index 000000000000..ce190ed1e459 --- /dev/null +++ b/blogs/deepspeed-domino/README.md @@ -0,0 +1,199 @@ +

+ domino logo +

+ +
+ +# Domino: Communication-Free LLM Training Engine + +
+ +
+ + +
+ +*Figure 1: Project Domino is Microsoft DeepSpeed's Tensor Parallel (TP) Training Engine, which provides a uniform solution for both single-node and **multi-node** cases. Domino scales up traditional single-node-only TP solution to multi-node environments via **near-complete communication hiding** behind computation.* + +
+

+ +# Table of Content +1. [Introduction](#introduction) +2. [Domino Highlights](#domino-highlights) +3. [Design Motivation](#design-motivation) +4. [Domino Design](#domino-design) +5. [Implementation and Optimization](#implementation-and-optimization) +6. [Getting Started: Try out DeepSpeed-Domino](#getting-started-try-out-deepspeed-domino) +7. [Citation](#citation) +8. [Acknowledgements](#acknowledgements) + + +# Introduction + +Generative AI (GenAI) has enabled transformative applications in a wide variety of domains, including chatbot, text summarization, and high-quality image and video generation. These capabilities are built on top of large foundation models, particularly Large Language Models (LLMs). LLMs are typically based on the [Transformer](https://arxiv.org/abs/1706.03762) network architecture, and include popular model families such as GPT and Llama. LLMs have grown beyond the memory capacity of a single accelerator (e.g., GPU), and so inferencing or training them requires distributed processing using multiple GPUs or even multiple nodes. + +Tensor parallelism (TP) is a popular distributed technique for training LLMs. TP leverages the aggregate memory of multiple GPUs to fit LLMs by partitioning each model layer across the GPUs. However, TP incurs two communication collective operations for each partitioned layer, separately for the forward and backward passes. TP is appealing due to its excellent system efficiency in single-node cases, where GPUs are directly connected via high bandwidth links like NVLink and NVSwitch. However, TP falls short in multi-node cases due to the lower bandwidth of cross-node interconnects. [Prior work](https://arxiv.org/abs/2406.06858) reports that communication can take up to 75\% of end-to-end training time. Figure 2 shows that even on the latest DGX-H100 nodes interconnected with high-end Infiniband of 400GB/s bandwidth, communication overheads remains as high as 43\% of end-to-end training iteration time. Recent advances in GeMM+NCCL kernel fusion are unable to fully hide communication overheads due to their limited scope of computation-communication overlapping. The trend of faster compute in newer GPUs (e.g., DGX-B200) indicates that the communication overheads of TP will be more pronounced in both single node and multiple node scenarios. + +
+
+ + *Figure 2: TP communication overhead in GPT-3-13B training using 1,2,4 DGX-H100 nodes (i.e., 8, 16, 32 H100 GPUs).* + +
+ +# Domino Highlights + + +* Domino is TP optimization technique that achieves **Near-Complete** communication hiding behind computation by decomposing a single batch training iteration into smaller and independent pieces, allowing efficient pipelining. + +Domino is the first work that provides a **uniform** Tensor Parallelism (TP) solution for both single-node and **multi-node** cases. Traditional TP solutions (e.g., Megatron-LM) fall short in multi-node cases due to limited cross-node communication bandwidth. + +### Performance + +We tested Domino on 1 to 4 DGX-H100 boxes (8xH100 per box). Each node has intra-node NVLink bandwidth of 900GB/s and inter-node IB bandwidth of 400GB/s. We oberved the following performance results: +1. For both GPT and Llama model series, Domino outperforms Megatron-LM by up to **1.3x** and **1.2x** respectively in end-to-end training iteration throughput for different model sizes, sequence lengths and batch sizes. These results are summarized in Figure 1. +2. For several cases, Domino achieves **near-optimal** training throughput, where optimal throughput refers to the throughput achieved assuming the communication collectives of TP are disabled. + +For more detailed performance results, please refer to our [arxiv paper](https://arxiv.org/abs/2409.15241). + +# Design Motivation + +In this section, we briefly discuss three topics. First, we motivate why the time is right is for a uniform TP solution for both single node and multi-node cases. Next, we analyze the communication overhead on latest Nvidia DGX-H100 boxes with high cross-node communication interconnects. Finally, we describe TP's sequential data dependency which causing communication stands out. + +### It is time for a uniform TP for single and multi-node scenarios + +Nvidia is pushing hard on breaking communication bandwidth gap between intra-node (i.e., GPUs within a node connected with NVLink) and inter-node (i.e., cross-node connected with Infini-Band(IB)). For example, each DGX-H100 is equipped with eight ConnectX-7 network cards and gets aggregated cross-node bandwidth of 400GB/s, which is at same level of intra-node NVLink (900GB/s). Therefore, it is time for proposing a uniform solution for both single node and multi-node TP training. + +### Communication Overhead in TP + +As described in [Megatron-LM paper](https://arxiv.org/pdf/1909.08053), for TP, every transformer block (i.e.,1 Self-Attention layer + 1 MLP layer) incurs 4 AllReduce calls, two in forward pass and two in the backward pass (shown in Figure 3). Given a LLM consisting of $N$ stacked transformer blocks, the number of AllReduce calls required for TP training is $4 * N$. Even for small models like GPT-3 2.7B or 6.7B which consists of 32 layers, the total number of AllReduce calls is 128 for every training iteration. For larger models, the number of AllReduce calls grows linearly with number of layers. + +
+
+ + *Figure 3: TP communication = 4 x AllReduce x num\_transformer\_block* + +
+ +One big issue for TP is that the *communication resides on critical path of every input batch training execution* due to sequential data dependency we described in the following [TP data dependency analysis](#tp-data-dependency-analysis) section. Therefore, the communication overhead stands out and is difficult to hide behind computation. In Figure 4, we provide our communication overhead measurement using Megatron-LM training GPT-3 and Llama-2 model series with different model sizes and batch sizes across 1 to 4 DGX-H100 nodes (i.e., 8 to 32 H100 GPUs). The communication overhead is up to **47\%** despite using latest Nvidia hardware DGX-H100 with 400GB/s cross-node bandwidth. + +
+
+ + *Figure 4: TP communication and computation ratio per training iteration time over different models and batch sizes using 1 to 4 DGX-H100 nodes.* + +
+ +As Llama-3 405B model training takes 54 days on 16,000 H100 GPUs, the projected communication time can be up to around **25 days on 16,000 H100s**. This finding shows that, despite using latest high-bandwidth interconnects like NVLink/Infini-Band(IB), the communication overheads of TP remains a huge portion of end-to-end training time. + +### TP data dependency analysis + +In traditional TP, shown in Figure 5, a transformer layer (either Attn or MLP layer) computation can be abstracted into $X\*A\*B=Y$, where $X$ is input. For attention layer, $A$ is attention computation (e.g., multihead-attention) and $B$ is linear layer. For MLP layer, both $A$ and $B$ are linear layers. An AllReduce is conducted on $Y$ after computation. Due to **sequential data dependency on $Y$ between computation (i.e., $X\*A\*B=Y$) and communication (i.e., AllReduce($Y$)), AllReduce($Y$) completely stands out**, thus making TP not efficient in limited communication bandwidth scenarios. + +
+
+
+ + *Figure 5: TP Forward pass of single Self-Attention/MLP layer. (X is input, A is attention computation for Self-Attention layer and linear for MLP layer, B is linear for both Self-Attention and MLP layer. Y is X\*A\*B output)* + +
+
+ + +# Domino Design + +Compared to Figure 5, Domino breaks data dependency of $X\*A\*B$ via [*Row-wise Split on Inputs X*](#row-wise-split-on-inputs-x), [*Column-wise Split on Weights B*](#column-wise-split-on-weights-b), as well as a [hybrid solution combining these two](#2d-split-on-both-x-and-b). After breaking computation into pieces, Domino pipelines computation and communication working on different independent pieces, thus achieving near-complete communication hiding behind computation. Domino's unique benefits are listed as follows: + +1. Comparing with GeMM+NCCL kernel fusion techniques, Domino breaks data dependency thus has a much wider range of computation kernel sequences to overlap with NCCL call. For example, Domino can overlap AllReduce not only to a single GeMM, but also extend overlapping scope to multiple GeMMs, LayerNorm, DropOut and more. +2. Domino achieves near-complete communication hiding behind computation, thus also achieves near-optimal system throughput in certain cases. (Optimal throughput refers to end-to-end throughput that disables all communication in TP training.) +3. Domino works at kernel scheduler level, any kernel optimizations or new kernels can be seamlessly integrated into Domino framework. +4. Domino tensor partition scheme is simple and generic. It is easy for user side end-to-end correctness debugging when facing issues like overflow or weights/gradients errors. + +For the ease of illustration, we describe forward propagation only (since backward pass is just in reverse order), and we describe only splitting tensor into two chunks. + +## Row-wise split on Inputs X: + +Domino breaks Input X in row dimension (i.e. batch dimension). + +
+
+ + *Figure 6: Domino row-wise (batch-dim) split on inputs X.* + +
+ +**Data Dependency**: Split inputs' batch dimension has no data dependency for both intra-layer and inter-layer cases. Therefore, we achieve both *intra-layer* (AllReduce($Y1$) and $X2\*A\*B$) and *inter-layer* (AllReduce($Y2$) and next-layer's $X1\*A\*B$) computation-communication overlapping. With this batch split on inputs, Domino can hide up to **100\%** communication behind computation. + +## Column-wise split on Weights B: + +Domino breaks weight matrix B in column dimension. + + +
+
+ + *Figure 7: Domino column-wise (last-dim) split on weights B.* + +
+ +**Data Dependency**: Split Weights B column-wise have no data dependency in intra-layer case but have data dependency in inter-layer case. Therefore, we only achieve *intra-layer* + (AllReduce($Y1$) and $X2\*A\*B$) computation-communication overlapping. This column-split on weights scheme remains essential, since row-wise input split only would lead to narrow shape tensors that hinder kernel computational efficiency. In practice, Domino achieves 50\% to 70\% communication hiding behind computation with weights B column-wise split. + +## 2D Split on both X and B: + +For extremely large LLMs, Domino splits both inputs X and weights B in row and column dimension, separately. This method is beneficial for model training requiring both low memory footprints and minimizing communication overheads. + +
+
+ + *Figure 8: Domino 2D split on both inputs X and weights B.* + +
+ +**Data Dependency**: This 2D split policy inherits synchronization at the end of each transformer layer due to column-wise split on weights B. Therefore, the 2D approach only achieves *intra-layer* computation-communication overlapping. + +# Implementation and Optimization + +For brevity, we summarize key implementation of row-wise input split. For more implementation details, please refer to our [arxiv paper](https://arxiv.org/abs/2409.15241). + +**Forward:** Figure 9 shows how we position and trigger NCCL calls in order to overlap with computation kernel sequences in forward propagation. We split batch into two chunks as $\mu$-batch0 and $\mu$-batch1. $\mu$-batch0 attention output as attn0 and MLP output as MLP0. $\mu$-batch1's attention output as attn1 and MLP output as MLP1. AllReduce(attn0) is overlapped with self-attention computation on $\mu$-batch1. For AllReduce(attn1), we group multiple $\mu$-batches' Dropout, Residual, LayerNorm computation-communication overlapping. This small kernel grouping not only enable complete hiding of AllReduce(attn1), but also provides proper overlapping space for AllReduce(MLP0) in the backward pass shown in Figure 10. For AllReduce(MLP0), we hide it behind $\mu$-batch1's MLP computation kernel sequence of GeMM + GeLU + GeMM. For AllReduce(MLP1), we hide it behind next layer's attention computation. + +
+
+ + *Figure 9: Transformer block (i.e., 1 self-attn + 1 MLP) forward pass. Upper figure is vanila TP implementation, bottom is Domino implementation.* + +
+ +**Backward:** Figure 10 shows a simple example of batch split in to two $\mu$-batches as $\mu$-batch0 and $\mu$-batch1. Besides similar overlapping strategy in the forward pass, we extend the scope of overlap communication with weights' gradient computation inside same $\mu$-batch (e.g., AllReduce(MLP1) partially overlaps with its own $\mu$-batch1 computation as the 3rd orange block from left). Each *grad matmul* includes two separate GeMM computation for inputs gradient and weights gradient. Therefore, we can extend overlapping scope by overlapping AllReduce(MLP1) with $\mu$-batch1's weights gradient computation. + +Backward is a bit more challenging because backward computation graph is automatically generated by torch.autograd(). To precisely control NCCL call triggering time, we implement a *no\_operation* module, which obtains communication handle during forward pass and retains it for use during backward pass. Our *no\_operation* module works seamlessly with torch.autograd(), and enable us precisely control NCCL start/end time without rewriting customized backward computation graph. + +
+
+ + *Figure 10: Transformer block (i.e., 1 self-attn + 1 MLP) backward pass. Upper figure is vanila TP implementation, bottom is Domino implementation.* + +
+ +**General kernel optimizations:** We adopt general kernel-level optimization techniques. For example, we use cudaGraph to squeeze idle/bubble time between adjacent compute kernels to reduce end-to-end latency. We use CUDA multi-stream to increase parallel execution. We also leverage torch.compile() to further improve our system efficiency. + +# Getting Started: Try out DeepSpeed-Domino + +To try out DeepSpeed-Domino, please refer to [Domino tutorial](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/training/DeepSpeed-Domino/README.md) in our DeepSpeedExample repo. + +## Citation + +``` +@article{wang2024-deepspeed-domino, + title={{Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping}}, + author={Guanhua Wang and Chengming Zhang and Zheyu Shen and Ang Li and Olatunji Ruwase}, + journal={arXiv preprint arXiv:2409.15241}, + year={2024} +} +``` + +## Acknowledgements + +This work is the result of a deep collaboration between Microsoft DeepSpeed and our academia partners from University of Maryland, University of Houston. The contributors include [Guanhua Wang](https://www.microsoft.com/en-us/research/people/guanhuawang/), [Hongwei Chen](https://github.com/hwchen2017) and [Olatunji Ruwase](https://www.microsoft.com/en-us/research/people/olruwase/) from Microsoft DeepSpeed Team, [Chengming Zhang](https://chengmingzh8.github.io/) from University of Houston, [Zheyu Shen](https://www.linkedin.com/in/zheyushen/) and [Ang Li](https://www.ang-li.com/) from University of Maryland. diff --git a/blogs/deepspeed-domino/images/design-base.png b/blogs/deepspeed-domino/images/design-base.png new file mode 100644 index 000000000000..d347e9c2ba8b Binary files /dev/null and b/blogs/deepspeed-domino/images/design-base.png differ diff --git a/blogs/deepspeed-domino/images/design-column.png b/blogs/deepspeed-domino/images/design-column.png new file mode 100644 index 000000000000..a99ad3c6b461 Binary files /dev/null and b/blogs/deepspeed-domino/images/design-column.png differ diff --git a/blogs/deepspeed-domino/images/design-hybrid.png b/blogs/deepspeed-domino/images/design-hybrid.png new file mode 100644 index 000000000000..302e3f95e8fc Binary files /dev/null and b/blogs/deepspeed-domino/images/design-hybrid.png differ diff --git a/blogs/deepspeed-domino/images/design-row.png b/blogs/deepspeed-domino/images/design-row.png new file mode 100644 index 000000000000..551a54f4e651 Binary files /dev/null and b/blogs/deepspeed-domino/images/design-row.png differ diff --git a/blogs/deepspeed-domino/images/domino-hero.png b/blogs/deepspeed-domino/images/domino-hero.png new file mode 100644 index 000000000000..078b6472b42a Binary files /dev/null and b/blogs/deepspeed-domino/images/domino-hero.png differ diff --git a/blogs/deepspeed-domino/images/domino-logo.png b/blogs/deepspeed-domino/images/domino-logo.png new file mode 100644 index 000000000000..58be0990b944 Binary files /dev/null and b/blogs/deepspeed-domino/images/domino-logo.png differ diff --git a/blogs/deepspeed-domino/images/gpt3-scale.png b/blogs/deepspeed-domino/images/gpt3-scale.png new file mode 100644 index 000000000000..611b2221a73c Binary files /dev/null and b/blogs/deepspeed-domino/images/gpt3-scale.png differ diff --git a/blogs/deepspeed-domino/images/implement-bwd.png b/blogs/deepspeed-domino/images/implement-bwd.png new file mode 100644 index 000000000000..4b115222f387 Binary files /dev/null and b/blogs/deepspeed-domino/images/implement-bwd.png differ diff --git a/blogs/deepspeed-domino/images/implement-fwd.png b/blogs/deepspeed-domino/images/implement-fwd.png new file mode 100644 index 000000000000..51d3a73bae58 Binary files /dev/null and b/blogs/deepspeed-domino/images/implement-fwd.png differ diff --git a/blogs/deepspeed-domino/images/tp-ar.png b/blogs/deepspeed-domino/images/tp-ar.png new file mode 100644 index 000000000000..6dd01ccceed8 Binary files /dev/null and b/blogs/deepspeed-domino/images/tp-ar.png differ diff --git a/blogs/deepspeed-domino/images/tp-comm-overhead.png b/blogs/deepspeed-domino/images/tp-comm-overhead.png new file mode 100644 index 000000000000..947473ff5261 Binary files /dev/null and b/blogs/deepspeed-domino/images/tp-comm-overhead.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/README.md b/blogs/deepspeed-fastgen/2024-01-19/README.md new file mode 100644 index 000000000000..6494d8f9a303 --- /dev/null +++ b/blogs/deepspeed-fastgen/2024-01-19/README.md @@ -0,0 +1,187 @@ +
+ +# DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements. + +
+ +
+ + +
+ +# Table of Contents +1. [Introduction](#introduction) +2. [New Model Families](#new-model-families) +3. [Performance Optimizations](#performance-optimizations) +4. [Feature Enhancements](#stability-and-software-enhancements) +5. [Community Engagement](#community-engagement) +6. [Try Out DeepSpeed-FastGen](#try-out-deepspeed-fastgen) + + +# 1. Introduction + +[DeepSpeed-FastGen](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen) is an inference system framework that enables easy, fast, and affordable inference for large language models (LLMs). From general chat models to document summarization, and from autonomous driving to copilots at every layer of the software stack, the demand to deploy and serve these models at scale has skyrocketed. DeepSpeed-FastGen utilizes the Dynamic SplitFuse technique to tackle the unique challenges of serving these applications and offer higher effective throughput than other state-of-the-art systems like vLLM. + +Today, we are happy to share that we are improving DeepSpeed-FastGen along three areas: i) three new model families, ii) performance optimizations, and iii) feature enhancements: +- **New Model Families** + + We introduce support for Mixtral (MoE), Falcon, and Phi-2 model families in DeepSpeed-FastGen. Our inference optimizations for these models provide up to 2.5X improvement in effective throughput over other state-of-the-art frameworks like vLLM. + +- **Performance Optimizations** + + We drastically reduced the scheduling overhead of Dynamic SplitFuse and increased the efficiency of token sampling. As a result, we see higher throughput and lower latency, particularly when handling concurrent requests from many clients. We demonstrate the performance optimizations with benchmarks and evaluation of DeepSpeed-FastGen against vLLM for the newly added model families. The benchmark results can be seen in [Performance Evaluation](#performance-optimizations) and the benchmark code is available at [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/inference/mii). + +- **Feature Enhancements** + + DeepSpeed-FastGen contains a rich set of features for running inference with many different model families and over 20,000 HuggingFace hosted models. We extend this feature set for all models to include a RESTful API, more generation options, and support for models using the safetensor checkpoint format. Additionally, we improve on overall stability and address bugs in our original DeepSpeed-FastGen release. + +We now dive into the details of the new model families, performance optimizations, and software improvements. If you would like to get started right away please see [Try Out DeepSpeed-FastGen](#try-out-deepspeed-fastgen). This new release is available in [DeepSpeed versions >= 0.13.0](https://github.com/deepspeedai/DeepSpeed/tree/v0.13.0) and [DeepSpeed-MII versions >= 0.2.0](https://github.com/deepspeedai/DeepSpeed-MII/tree/v0.2.0). + +# 2. New Model Families + +Today we introduce support for three new model families: i) [Mixtral (MoE)](https://arxiv.org/pdf/2401.04088.pdf), ii) [Phi-2](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/), and iii) [Falcon](https://arxiv.org/pdf/2311.16867v1.pdf) + +## Mixtral + +Mixtral model, a language model based on sparse mixture of experts (MoE), has demonstrated promising performance across multiple benchmarks. The Mixtral model operates by applying a router network at each layer for every token, selecting two distinct experts for processing the current state and combine their outputs. This process is dynamic, with the possibility of different experts being chosen at each timestep. This architecture ensures that while each token is exposed to a broad spectrum of parameters, it actively utilizes only a subset during inference. + +In this release, we are pleased to announce the support for Mixtral models. We've enhanced our FastGen codebase by the integration of the Mixtral model implementation, refinements to our high-performance kernels for efficient top-k gating, and updates to Rotary Positional Encoding (RoPE) implementation. These advancements ensure that users can fully exploit the capabilities of DeepSpeed-FastGen for executing Mixtral model inference, thereby achieving heightened performance and efficiency. + +## Phi-2 + +Microsoft Research has introduced a suite of small language models (SLMs) named "Phi," notable for their exceptional performance across a spectrum of benchmarks. The latest addition to this suite, [Phi-2](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/), is a language model boasting 2.7 billion parameters. It stands out as a testament to outstanding reasoning and language understanding capabilities, exemplifying state-of-the-art performance within the realm of base language models featuring fewer than 13 billion parameters. Notably, Phi-2 achieves parity with or surpasses models up to 25 times its size on complex benchmarks, a feat attributed to pioneering innovations in model scaling and meticulous training data curation. + +Owing to its compact size, Phi-2 emerges as an ideal model for both researchers and deployment scenarios, promising a reduction in inference costs. To efficiently support the Phi-2 model family, we introduce partial RoPE support in our DeepSpeed-FastGen kernels. + +## Falcon + +Falcon is a family of large language models (LLMs) developed by the Technology Innovation Institute (TII). The Falcon models include Falcon 7B, Falcon-40B and its larger counterpart, Falcon-180B, the largest openly available language model to date. + +A closer examination of the architectural nuances within the Falcon series reveals notable distinctions. Specifically, the Falcon 7B model diverges slightly from Falcon-40B; notably, Falcon-40B incorporates an additional layer norm preceding the parallel MLP layer, a feature absent in the Falcon 7B model. In contrast, Falcon-180B adheres to the same architecture as Falcon-40B but stands out as a scaled-up version. + +# 3. Performance Optimizations and Evaluation + +SplitFuse effectively enhances utilization by simultaneously computing prompts and decoding (generating tokens). However, we observed a significant overhead for scheduling ragged batching, especially when generating a large number of tokens from numerous concurrent requests. In this release, we've minimized this scheduling overhead for querying KV cache states. As a result, there's a notable improvement in the performance for scenarios with a large number of generation steps. + +In general for long prompts and a smaller number of generated tokens, we can fully utilize the benefits of SplitFuse, which combines prompt processing and decoding (token generation) in a single forward pass. This provides a significant advantage over vLLM in these scenarios as shown in our [previous blog](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen). For short prompts and a larger number of generated tokens, where most forward passes run purely for decoding, our highly optimized engine and the efficient scheduler for ragged batching demonstrate impressive performance. + +We follow the benchmarking methodology we presented in our [previous blog](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen#a-benchmarking-methodology). + +*NOTE: All the benchmarks in this blog use the recommended DeepSpeed-FastGen persistent deployment mode.* + +### Mixtral + +We developed a new MoE module, which contains kernels optimized for our inference engine. The enhancements in the decoding phase, included in this release, significantly improve throughput and efficiency in generating a large number of tokens as shown in *Figure 1*. + +
+
+ + *Figure 1: Throughput-latency curve of Mixtral using A100. A normal distribution was applied to prompt and generation lengths with averages of (1200, 2600) and (60, 128), respectively, and a 30% variance*
+
+ +We show the throughput-latency of Mixtral-8x7B-v0.1 running on A100 with tensor parallelism degree of 4. First, we show the scenarios where the prompt lengths are longer than the number of generation steps (i.e., tokens), which is typical of popular use cases like chatbots. From *Figure 1*, DeepSpeed-FastGen provides 2.4X higher throughput for a prompt length of 1200 and 60 generation steps. In addition to the performance for the long prompt scenarios, we present new results for shorter prompts and larger number of generation steps in *Figure 2*. Our performance advantage still holds. + +
+
+ + *Figure 2: Throughput-latency curve of Mixtral using A100. A normal distribution was applied to prompt and generation lengths with averages of 500 and (150, 500, 1024), respectively, and a 30% variance*
+
+ +As we can see in *Figure 2*, DeepSpeed-FastGen is showing higher throughput and lower latency thanks to the scheduling performance improvements presented in this blog. + +### Phi-2 + +
+
+ + *Figure 3: Throughput-latency curve of Phi-2 using A100. A normal distribution was applied to prompt and generation lengths with averages of (1200, 1900) and (60, 128), respectively, and a 30% variance*
+
+ +From *Figure 3*, DeepSpeed-FastGen provides 1.5X higher throughput for a prompt length of 1900 and 60 generation steps. For other scenarios our throughput-latency evaluation of the Phi-2 model show a similar pattern, with DeepSpeed-FastGen providing equivalent latency with greater throughput or lower latency for the same throughput. + +### Falcon + +Given the substantial size of the Falcon-40B and Falcon-180B models, the majority of computations are dedicated to forward passes, while the overhead of scheduling and token sampling is relatively minor. + +
+
+ + *Figure 4: Throughput-latency curve of Falcon 40B using A100. A normal distribution was applied to prompt and generation lengths with averages of (1200, 1900) and (60, 128), respectively, and a 30% variance*
+
+ +
+
+ + *Figure 5: Throughput-latency curve of Falcon 180B using A100. A normal distribution was applied to prompt and generation lengths with averages of (1200, 1900) and (60, 128), respectively, and a 30% variance*
+
+ +As seen in *Figure 4* and *Figure 5*, DeepSpeed-FastGen is able to provide higher throughput and lower latency compared to vLLM for Falcon-40B and Falcon-180B. + +# 4. Feature Enhancements + +In this section we introduce several feature enhancements that have been released since we first introduced DeepSpeed-FastGen. + +## Performance improvements +We achieve a notable improvement in performance by minimizing the scheduling overhead for querying KV cache states as discussed in [Performance Optimizations](#performance-optimizations). + +See [PR-4965](https://github.com/deepspeedai/DeepSpeed/pull/4965), [PR-377](https://github.com/deepspeedai/DeepSpeed-MII/pull/377) for more details. + +## Support for safetensor checkpoints +Some HuggingFace-hosted model checkpoint weights are provided only in the safetensor format. We extend our HuggingFace checkpoint engine to work with the safetensor format to support even more models! + +See [PR-4659](https://github.com/deepspeedai/DeepSpeed/pull/4659), [PR-296](https://github.com/deepspeedai/DeepSpeed-MII/pull/296) for more details. + +## Added RESTful API + +We add the option to automatically stand up a RESTful API when creating DeepSpeed-FastGen persistent deployments in DeepSpeed-MII. This API provides a way for users to send prompts to their deployments and receive responses using HTTP POST methods and tools like `curl` or python's `request` package. The RESTful API provides the same high throughput and low latency performance as our python APIs. For more information, please see [MII RESTful API](https://github.com/deepspeedai/DeepSpeed-MII#restful-api). + +See [PR-348](https://github.com/deepspeedai/DeepSpeed-MII/pull/348), [PR-328](https://github.com/deepspeedai/DeepSpeed-MII/pull/328), [PR-294](https://github.com/deepspeedai/DeepSpeed-MII/pull/294) for more details. + +## Added deployment and generate options + +We extend the customizability of DeepSpeed-FastGen deployments and text-generation. Users can now specify a `device_map` when creating non-persistent pipelines and persistent deployments that controls which GPUs to use for hosting a model. Additionally, the interfaces between pipelines and deployments now match and include options for setting top-p, top-k, and temperature values. For additional information about the user-exposed options, please see [MII Pipeline](https://github.com/deepspeedai/DeepSpeed-MII#non-persistent-pipeline) and [MII Deployment](https://github.com/deepspeedai/DeepSpeed-MII#persistent-deployment). + +See [PR-331](https://github.com/deepspeedai/DeepSpeed-MII/pull/331), [PR-280](https://github.com/deepspeedai/DeepSpeed-MII/pull/280), [PR-275](https://github.com/deepspeedai/DeepSpeed-MII/pull/275), [PR-268](https://github.com/deepspeedai/DeepSpeed-MII/pull/268), [PR-295](https://github.com/deepspeedai/DeepSpeed-MII/pull/295), for more details. + +## Mitigate risk of deadlock + +In use cases where many prompts are sent to a deployment in a small time window, deadlock can occur in the DeepSpeed-FastGen inference engine, resulting in no text-generation progress is made on any prompts. To mitigate this, we ensure that there is a sufficient margin in the KV cache when scheduling requests. While not completely resolved, we continue to investigate a fix for these situations that arrive when the deployment is under heavy load. + +See [PR-274](https://github.com/deepspeedai/DeepSpeed-MII/pull/274) for more details. + +## Inference Checkpoints + +We add the capability to create inference engine snapshots to DeepSpeed-FastGen. This reduces the loading time for large models in future deployments. + +See [PR-4664](https://github.com/deepspeedai/DeepSpeed/pull/4664) for more details. + +## General stability and bug fixes + +We include many bug fixes and stability improvements to DeepSpeed-FastGen. This includes fixing issues with some OPT model size variants, bugs with MII configuration options, and improved error messages. + +See [PR-4938](https://github.com/deepspeedai/DeepSpeed/pull/4938), [PR-4920](https://github.com/deepspeedai/DeepSpeed/pull/4920), [PR-4739](https://github.com/deepspeedai/DeepSpeed/pull/4739), [PR-4694](https://github.com/deepspeedai/DeepSpeed/pull/4694), [PR-4634](https://github.com/deepspeedai/DeepSpeed/pull/4634), [PR-367](https://github.com/deepspeedai/DeepSpeed-MII/pull/367), [PR-350](https://github.com/deepspeedai/DeepSpeed-MII/pull/350), for more details. + +# 5. Community Engagement + +DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) page. Please see our [contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, and companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to info@deepspeed.ai. + +*We would like to recognize the contribution from our user community in adding support for the [Qwen](https://arxiv.org/abs/2309.16609) model family to DeepSpeed-FastGen in [PR-4913](https://github.com/deepspeedai/DeepSpeed/pull/4913).* + +# 6. Try Out DeepSpeed-FastGen + +We are very excited to share this DeepSpeed-FastGen release. + +* To get started, please visit our GitHub page for DeepSpeed-MII: [GitHub Landing Page](https://github.com/deepspeedai/DeepSpeed-MII) + +DeepSpeed-FastGen is part of the bigger DeepSpeed ecosystem comprising a multitude of Deep Learning systems and modeling technologies. To learn more, + +* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation. +* You can also follow us on our [English Twitter](https://twitter.com/DeepSpeedAI), [Japanese Twitter](https://twitter.com/DeepSpeedAI_JP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for latest news on DeepSpeed. + +The following items are on our roadmap and we plan to engage with our community on these through our GitHub issues and PRs: + +* Performance improvements +* Quantization support +* New hardware backends through collaboration with partners + +**"Star" our [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) and [DeepSpeed-MII GitHub](https://github.com/deepspeedai/DeepSpeed-MII/) repositories if you like our work!** diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/fastgen-hero-dark.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/fastgen-hero-dark.png new file mode 100644 index 000000000000..1121fa9dafd6 Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/fastgen-hero-dark.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/fastgen-hero-light.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/fastgen-hero-light.png new file mode 100644 index 000000000000..35f60788331c Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/fastgen-hero-light.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_falcon-180B_tp8.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_falcon-180B_tp8.png new file mode 100644 index 000000000000..6ccfcb0fe17f Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_falcon-180B_tp8.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_falcon-40b_tp2.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_falcon-40b_tp2.png new file mode 100644 index 000000000000..b08401cca7cb Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_falcon-40b_tp2.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_mistralai-Mixtral-8x7B-v0.1_tp4_1.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_mistralai-Mixtral-8x7B-v0.1_tp4_1.png new file mode 100644 index 000000000000..519781246289 Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_mistralai-Mixtral-8x7B-v0.1_tp4_1.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_mistralai-Mixtral-8x7B-v0.1_tp4_2.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_mistralai-Mixtral-8x7B-v0.1_tp4_2.png new file mode 100644 index 000000000000..f2bf11cda74b Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_mistralai-Mixtral-8x7B-v0.1_tp4_2.png differ diff --git a/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_phi-2_tp1.png b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_phi-2_tp1.png new file mode 100644 index 000000000000..7e92417a64fe Binary files /dev/null and b/blogs/deepspeed-fastgen/2024-01-19/assets/images/th_lat_curve_phi-2_tp1.png differ diff --git a/blogs/deepspeed-fastgen/README.md b/blogs/deepspeed-fastgen/README.md new file mode 100644 index 000000000000..af2fa085bf21 --- /dev/null +++ b/blogs/deepspeed-fastgen/README.md @@ -0,0 +1,309 @@ +
+ +# DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference + +
+ +
+ + +
+ +## Table of Contents +1. [Introduction](#introduction) +2. [Key LLM Serving Techniques](#background) +3. [Dynamic SplitFuse: A Novel Prompt and Generation Composition Strategy](#technical-approach) +4. [Performance Evaluation](#performance-evaluation) +5. [DeepSpeed-FastGen: Implementation and Usage](#using-deepspeed-fastgen) +6. [Try out DeepSpeed-FastGen](#try) +7. [Acknowledgements](#acknowledgements) + + +## 1. Introduction + +Large language models (LLMs) like GPT-4 and LLaMA have emerged as a dominant workload in serving a wide range of applications infused with AI at every level. From general chat models to document summarization, and from autonomous driving to copilots at every layer of the software stack, the demand to deploy and serve these models at scale has skyrocketed. While frameworks like DeepSpeed, PyTorch, and several others can regularly achieve good hardware utilization during LLM training, the interactive nature of these applications and the poor arithmetic intensity of tasks like open-ended text generation have become the bottleneck for inference throughput in existing systems. + +To this end, frameworks like [vLLM](https://arxiv.org/pdf/2309.06180.pdf) powered by PagedAttention and research systems like [Orca](https://www.usenix.org/system/files/osdi22-yu.pdf) have significantly improved the performance of inference for LLMs. However, these systems still struggle to provide consistent quality of service, particularly for workloads with longer prompts. These long prompt workloads are becoming increasingly important as more and more models, like [MPT-StoryWriter](https://www.mosaicml.com/blog/mpt-7b), and systems, such as [DeepSpeed Ulysses](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ulysses), support context windows stretching to tens of thousands of tokens. To better understand the problem space, we provide detailed examples of how text generation works for LLMs in two distinct phases called prompt processing and generation. When systems treat them as distinct phases, generation will be preempted by prompt processing that risks breaking the service level agreements (SLAs). + +Today, we are glad to present DeepSpeed-FastGen, a system that overcomes these limitations by leveraging the proposed Dynamic SplitFuse technique and offers up to 2.3x higher effective throughput compared to state-of-the-art systems like vLLM. DeepSpeed-FastGen leverages the combination of DeepSpeed-MII and DeepSpeed-Inference to provide an easy-to-use serving system. + +**Quick Start:** Trying DeepSpeed-FastGen is as simple as installing the latest [DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII) release: + +```bash +pip install deepspeed-mii +``` + +To generate text using a simple non-persistent pipeline deployment, run the following code. For more details, please see [Section 5](#using-deepspeed-fastgen). + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +## 2. Existing LLM Serving Techniques in Literature + +A text generation workload for a single sequence consists of two phases: 1) prompt processing, in which the user-provided text is efficiently processed as a batch of tokens to build a key-value (KV) cache for attention, and 2) token generation, which will add a single token to that cache and generate a new token. Over the course of generating a sequence of text, the model will make many forward calls to the model to generate the full sequence of text. Two major techniques have been proposed in the literature and deployed in systems that address various limitations and bottlenecks that may arise during these phases. + +_ Blocked KV Caching: _ + +vLLM identified that memory fragmentation due to large monolithic KV-caches significantly reduced the concurrency of LLM serving systems and proposed [Paged Attention](https://arxiv.org/pdf/2309.06180.pdf) to enable non-contiguous caches and increase total system throughput. Rather than assign individual variable-sized contiguous chunks of memory, the underlying storage in the KV cache is fixed-sized blocks (also known as pages). The blocked KV-cache increases system throughput by increasing the amount of potential sequence concurrency by eliminating KV-cache induced memory fragmentation. Non-contiguous KV cache implementations are also included in [HuggingFace TGI](https://github.com/huggingface/text-generation-inference) and [NVIDIA TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). + +_ Continuous Batching: _ + +In the past, dynamic batching, in which a server would wait for multiple requests to process in phase with each other, was used to improve GPU utilization. However, this approach has drawbacks, as it typically requires padding inputs to identical lengths or stalling the system to wait to construct a larger batch. + +Recent advancement in large language model (LLM) inference and serving has been focusing on fine granularity scheduling and optimizing memory efficiency. For instance, Orca proposes _iteration-level scheduling_ (also known as continuous batching) which makes distinct scheduling decisions at each forward pass of the model. This allows requests to join/leave the batch as needed, eliminating the need for padding requests thus improving the overall throughput. In addition to Orca, continuous batching has been implemented in NVIDIA TRT-LLM, HuggingFace TGI, and vLLM. + +In current systems, there are two primary approaches to implement continuous batching. In TGI and vLLM, the generation phase is preempted to perform prompt processing (called infill in TGI) before continuing with generation. In Orca, these phases are not distinguished; instead, Orca will add a prompt into the running batch so long as the total number of sequences doesn't reach a fixed bound. Both of these approaches to varying degrees need to stall generation to process long prompts (see [Section 3B](#splitfuse)). + +To address these shortcomings, we propose a novel prompt and generation composition strategy, Dynamic SplitFuse. + +## 3. Dynamic SplitFuse: A Novel Prompt and Generation Composition Strategy + +DeepSpeed-FastGen is built to leverage continuous batching and non-contiguous KV caches to enable increased occupancy and higher responsivity for serving LLMs in the data center, similar to existing frameworks such as TRT-LLM, TGI, and vLLM. In order to achieve a new level of performance, DeepSpeed-FastGen introduces SplitFuse which leverages dynamic prompt and generation decomposition and unification to further improve continuous batching and system throughput. + +### A. Three Performance Insights +Before describing Dynamic SplitFuse, we answer three key performance questions that together motivate its design. + +*__1. What factors impact the forward pass of a single LLM?__* In order to effectively schedule, it is necessary to understand what are the relevant independent variables the scheduling loop should control. We observe below that the composition of sequences in a forward pass (the batch size in sequences) has a negligible impact on performance compared to the raw number of tokens in the forward pass. This means an effective scheduler can be built around a single signal, the number of tokens in the forward pass. + +
+
+
+ +*__2. How does a model's throughput respond to changing the number of tokens in the forward pass?__* An LLM has two key operating regions with a relatively steep transition. With a small number of tokens, the GPU bottleneck is reading the model from memory and so throughput scales with the number of tokens, whereas with many tokens the model is throughput bound by compute and sees near-constant throughput. The model should run highly efficiently if all forward passes are in the throughput-saturating region. + +
+
+
+ +*__3. How should a pool of tokens be scheduled across multiple forward passes?__* We observe above that for well-aligned inputs the token-throughput curve is concave, which means the second derivative is bound to be less than or equal to 0. As an example, let $f(x)$ be a concave function of latency to throughput for a given model. For a concave function $f(x)$, the following holds: + + $$0 \geq \lim_{h \to 0} \frac{f(x + h) - 2f(x) + f(x - h)}{h^2}$$ + + $$0 \geq f(x + h) - 2f(x) + f(x - h)$$ + + $$2f(x) \geq f(x + h) + f(x - h)$$ + +This states that for a given pool of `2x` tokens to process, the manner that maximizes throughput is that which evenly splits them between two batches. More generally, in a system that must consume and process P tokens over F forward passes, the ideal partitioning scheme will divide them equally. + +### B. Dynamic SplitFuse + +Dynamic SplitFuse is a novel token composition strategy for prompt processing and token generation. DeepSpeed-FastGen utilizes Dynamic SplitFuse to run at a consistent forward size by leveraging the capability to take partial tokens from prompts and compose this with generation. In particular, Dynamic SplitFuse performs two key behaviors: + +1. Long prompts are decomposed into much smaller chunks and scheduled across multiple forward passes (iterations) with only the final pass performing any generation. +2. Short prompts will be composed to exactly fill a target token budget. Even short prompts may be decomposed to ensure the budget is precisely met and the forward sizes are well-aligned. + +Together, these two techniques provide concrete benefits on all user metrics: + +1. *__Better Responsiveness__:* Since long prompts no longer require extremely long forward passes to process, the model will provide lower client latency. More forward passes are performed within the same window of time. +2. *__Higher Efficiency:__* Fusion of short prompts to larger token budgets enables the model to consistently operate in the high throughput regime. +3. *__Lower variance and better consistency:__* Since forward passes are of consistent size and forward pass size is the primary determinant of performance, the latency of each forward pass is much more consistent than competing systems as is the perceived generation frequency. There are no pre-emption or long-running prompts to increase the latency as in other prior work. + +Consequently, DeepSpeed-FastGen will consume tokens from incoming prompts at a rate that permits fast ongoing generation while adding tokens to the system that increase system utilization, providing lower latency and higher throughput streaming generation to all clients as compared to other state-of-the-art serving systems. + +
+ +
+ + *Figure 1: Illustration of continuous batching strategies. Each block shows the execution of a forward pass. An arrow indicates that the forward pass has sequences with one or more tokens generated. vLLM performs either token generations or prompt processing in a forward pass; token generation preempts prompt processing. Orca runs prompts at their complete length alongside generation. Dynamic SplitFuse performs dynamic composition of fixed-sized batches composed of both generation and prompt tokens.* + +
+ +## 4. Performance Evaluation + +DeepSpeed-FastGen provides state-of-the-art LLM serving performance leveraging its blocked KV cache and Dynamic SplitFuse continuous batching. We evaluate DeepSpeed-FastGen against vLLM on a range of models and hardware configurations following the benchmarking methodology discussed below. + +### A. Benchmarking Methodology + +We use two primary quantitative schemes for measuring performance. + +**Throughput-Latency Curves:** Two key metrics for production readiness are throughput (measured in requests per second) and latency (the responsiveness of each request). To measure this, we instantiate multiple clients (ranging from 1 to 32) concurrently and send requests (512 in total) to the server. The resulting latency of each request is measured at the endpoint and throughput is measured by the end-to-end time to complete the experiment. + +**Effective Throughput:** Interactive applications, such as chat applications, can have more stringent and complex requirements than can be captured by top-level metrics like end-to-end latency. In particular, we focus on the increasingly popular chat user scenario: + + 1. A user initiates a task by sending a prompt. + 2. The system processes the prompt and returns the first token. + 3. Subsequent tokens are streamed to the user as they are produced. + +At each point in this process there is an opportunity for a system to provide an adverse user experience; for example, if the first token arrives too slowly or the generation appears to stop for some time. We propose an SLA framework that considers both of these dimensions. + +As the lengths of prompts and generated texts vary significantly, affecting computational costs, it is impractical to set rigid SLA values for throughput and latency. Therefore, we define the SLA for prompt latency as |tokens in prompt| / 512 seconds (= 512 tokens/s). Additionally, considering humans' reading speed, we set the SLA for generation latency on the Exponential Moving Average (EMA) to 2, 4, or 6 tokens/sec. Requests that adhere to these SLAs are deemed successful, and the throughput of these successful requests is referred to as **effective throughput**. + +We evaluate vLLM and DeepSpeed-FastGen on both Llama-2 7B, Llama-2 13B, and Llama-2 70B on NVIDIA A100, H100, and A6000. + +### B. Throughput-Latency Analysis + +In this experiment, DeepSpeed-FastGen outperforms vLLM in both throughput and latency, providing equivalent latency with greater throughput or more responsive latency and the same throughput. On Llama-2 70B with 4 A100x80GB, DeepSpeed-FastGen demonstrates up to 2x higher throughput (1.36 rps vs. 0.67 rps) at identical latency (9 seconds) or up to 50% latency reduction (7 seconds vs. 14 seconds) while achieving the same throughput (1.2 rps), as shown in Figure 2. These trends hold when evaluating Llama-2 13B as shown in Figure 3. + +
+
+ + *Figure 2: Throughput and latency of text generation using Llama 2 70B (Tensor parallelism across 4 A100-80GB GPUs). A normal distribution was applied to prompt and generation lengths with averages of 1200/2600 and 128/60, respectively, and a 30% variance* +

+ +
+
+ + *Figure 3: Throughput and latency of text generation using Llama 2 13B (A100-80GB GPU, no tensor parallelism). A normal distribution was applied to prompt and generation lengths with averages of 1200/2600 and 60/128, respectively, and a 30% variance* +
+ +### C. Effective Throughput Analysis + +Under the effective throughput analysis that considers both first token latency and the rate at which generation occurs, DeepSpeed-FastGen provides up to 2.3x higher throughput than vLLM. Figure 4 presents a comparative analysis of the effective throughputs of DeepSpeed-FastGen and vLLM. Each plotted point denotes the effective throughput derived from a specific number of clients. As we scaled the number of clients, we initially observed an increase in effective throughput. However, the latency also significantly increases as the number of clients approaches the system's capacity, causing many requests to fail in meeting the SLA. Consequently, the effective throughput will either saturate or decrease at some point. From a usability perspective, it's not particularly relevant how many clients are required to achieve the max effective throughput; the maximum point of the line is the optimal serving point. + +
+ + + *Figure 4: Effective throughput of DeepSpeed-FastGen and vLLM (Llama 2 70B/A100-80GB using tensor parallelism across 4 A100-80GB GPUs. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance)* +

+ +When vLLM preempts the ongoing generation of previous requests, the generation latency experiences a notable increase. This leads to vLLM's effective throughput appearing lower than its directly measured throughput. At vLLM's peak, the effective throughput was 0.63 queries/sec and around 28% of requests failed to meet the 4 tokens/s SLA. At the same SLA, DeepSpeed-FastGen achieved 1.42 queries/sec (less than 1% of requests failed to meet the SLA), which is 2.3x higher than vLLM. + +### D. Token Level Timing Analysis + +Figure 5 displays the P50, P90, and P95 latencies of the generation processes. Both vLLM and DeepSpeed-FastGen exhibit similar P50 latencies, but vLLM demonstrates significantly higher latencies for P90 and P95. +Regarding the P95 latencies, DeepSpeed-FastGen achieved a reduction of 3.7 times. + +This discrepancy is due to a noticeable spike in vLLM's generation latency when it preempts the ongoing generation to process new prompts. +In contrast, DeepSpeed-FastGen typically processes the prompt and generation for previous requests concurrently, leading to much more consistent generation latency. + + +
+
+ + *Figure 5: Per-Token generation Latency of Llama 2 70B/A100-80GB using tensor parallelism across 4 A100-80GB GPUs, 16 clients. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 128, respectively, and a 30% variance.* +

+ + +### E. Scalability using Load Balancing + +DeepSpeed-FastGen offers replica-level load balancing that evenly distributes requests across multiple servers, allowing you to effortlessly scale up your application. + +Figure 6 illustrates the scalability of DeepSpeed-FastGen when employing the load balancer and up to 16 replicas. Note that we utilized 4 A100 GPUs to compute the Llama 2 70B model. In total, we employed 8 nodes to run the 16 replicas. The results demonstrate nearly perfect scalability with DeepSpeed-FastGen. +Given that the throughput of a single replica is 1.46 queries/sec, the throughput with 16 replicas reaches 23.7 queries/sec, marking a linear 16x increase compared to a single replica. + +
+
+ + *Figure 6: Scalability using the load balancing feature. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+ +### F. Other Hardware Platforms + +In addition to the deep analysis on A100, we provide additional benchmarking results for H100 and A6000. The same performance trends were observed on both A6000 and H100 as A100. + +
+
+ + *Figure 7: Throughput-latency curve and effective throughput of Llama 2 70b using 8 H100 GPUs. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+ +
+
+ + *Figure 8: Throughput-latency curve and effective throughput of Llama 2 7b using A6000. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+ +## 5. DeepSpeed-FastGen: Implementation and Usage + +DeepSpeed-FastGen is the synergistic composition of [DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII) and [DeepSpeed-Inference](https://github.com/deepspeedai/DeepSpeed) as illustrated in the figure below. Together, both of these software packages provide various components of the system including the frontend APIs, the host and device infrastructure to schedule batches using Dynamic SplitFuse, optimized kernel implementations, and the tools to construct new model implementations. + +
+ + +
+ + +The fastest way to get started with our alpha release of DeepSpeed-FastGen is: `pip install deepspeed-mii`. + +Please follow our [Getting Started](https://github.com/deepspeedai/deepspeed-mii#getting-started-with-mii) guide for more details. For usage and reporting issues, please use the [DeepSpeed-MII Github repository](https://github.com/deepspeedai/DeepSpeed-MII). + +### A. Supported Models + +We currently support the following model architectures in this alpha release of DeepSpeed-FastGen: + +* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2) +* [Mistral](https://huggingface.co/models?other=mistral) +* [OPT](https://huggingface.co/models?other=opt) +* [Falcon](https://huggingface.co/models?other=falcon) +* [Mixtral](https://huggingface.co/models?other=mixtral) +* [Phi-2](https://huggingface.co/models?other=phi-msft) +* [Phi-3](https://huggingface.co/models?other=phi3) +* [Qwen](https://huggingface.co/models?other=qwen) +* [Qwen2](https://huggingface.co/models?other=qwen2) +* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe) + +All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer. + +We plan to add additional models in the coming weeks and months after the initial release. If there are specific model architectures you would like supported, please [file an issue](https://github.com/deepspeedai/DeepSpeed-MII/issues) and let us know. + +### B. Deployment options +All of the examples below are runnable in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/inference/mii). Once installed you have two options for deployment: an interactive non-persistent pipeline or a persistent serving deployment: + +#### Non-persistent pipeline + +The non-persistent pipeline deployment is a great and fast way to get started and can be done with only a few lines of code. Non-persistent models are only around for the duration of the python script you are running but are useful for temporary interactive sessions. + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +#### Persistent deployment + +A persistent deployment is ideal for use with long-running and production applications. The persistent deployment uses a lightweight GRPC server that can be created using the following 2 lines: + + +```python +import mii +mii.serve("mistralai/Mistral-7B-v0.1") +``` + +The above server can be queried by multiple clients at once thanks to the built-in load balancer from DeepSpeed-MII. Creating a client also just takes 2 lines of code: + +```python +client = mii.client("mistralai/Mistral-7B-v0.1") +output = client.generate("Deepspeed is", max_new_tokens=128) +print(output) +``` + +A persistent deployment can be terminated when it is no longer needed: + +```python +client.terminate_server() +``` + +### C. Advanced Installation Information + +For ease of use and a significant reduction in lengthy compile times that many projects require in this space, we distribute a pre-compiled Python wheel covering the majority of our custom kernels through a new library called [DeepSpeed-Kernels](https://github.com/deepspeedai/DeepSpeed-Kernels). We have found this library to be very portable across environments with NVIDIA GPUs with compute capabilities 8.0+ (Ampere+), CUDA 11.6+, and Ubuntu 20+. In most cases, you shouldn't even need to know this library exists as it is a dependency of DeepSpeed-MII and will be installed with it. However, if for whatever reason you need to compile our kernels manually please see our [advanced installation docs](https://github.com/deepspeedai/DeepSpeed-Kernels#source). + + +# 6. Try Out DeepSpeed-FastGen +We are very excited to share this DeepSpeed-FastGen alpha release. + +* To get started, please visit our GitHub page for DeepSpeed-MII: [GitHub Landing Page](https://github.com/deepspeedai/DeepSpeed-MII) + +DeepSpeed-FastGen is part of the bigger DeepSpeed ecosystem comprising a multitude of Deep Learning systems and modeling technologies. To learn more, + +* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation. +* You can also follow us on our [English Twitter](https://twitter.com/DeepSpeedAI), [Japanese Twitter](https://twitter.com/DeepSpeedAI_JP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for latest news on DeepSpeed. + +DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) page. Please see our [contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, and companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to info@deepspeed.ai. + +The following items are on our roadmap and we plan to engage with our community on these through our GitHub issues and PRs: + +- Performance improvements +- Broader model support +- New hardware backends through collaboration with partners +- Release performance benchmarks (used to generate plots in this blog) + +**"Star" our [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) and [DeepSpeedMII GitHub](https://github.com/deepspeedai/DeepSpeed-MII/) repositories if you like our work!** + +# 7. Acknowledgements + +We would like to thank various open-source community projects including HuggingFace, vLLM, and HuggingFace TGI. We have leveraged HF APIs to support models and tokenizers in our alpha release and will continue to add more models. We especially acknowledge and thank the developers of [Flash Attention](https://github.com/Dao-AILab/flash-attention) for their great work. We have extensively leveraged FlashAttention kernels in our system with modifications that have been acknowledged in our code repositories at appropriate file headers. Finally, we want to thank the developers of [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) kernels that we have used in our MoE kernels (released as part of DeepSpeed-Kernels repository). diff --git a/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png b/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png new file mode 100644 index 000000000000..9d4ab55f5f7a Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png b/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png new file mode 100644 index 000000000000..89fb9ca3e1ce Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/effective_throughput.png b/blogs/deepspeed-fastgen/assets/images/effective_throughput.png new file mode 100644 index 000000000000..11c7f82bc54f Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/effective_throughput.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png b/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png new file mode 100644 index 000000000000..1b9a38306e8e Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg b/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg new file mode 100644 index 000000000000..2affbf8a4cc3 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png new file mode 100644 index 000000000000..9b90357a3f1b Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png new file mode 100644 index 000000000000..9e754abde85d Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png new file mode 100755 index 000000000000..6ac1a775805b Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png new file mode 100755 index 000000000000..af8f1defe653 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png new file mode 100755 index 000000000000..dde598a985d8 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png new file mode 100755 index 000000000000..bdb5f8df483e Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png new file mode 100644 index 000000000000..6d45880588d9 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png new file mode 100644 index 000000000000..7c14e2bf6e53 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png b/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png new file mode 100644 index 000000000000..834c06dfb07a Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png new file mode 100644 index 000000000000..df16b5bebc53 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png new file mode 100644 index 000000000000..8b69a8a1718b Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/throughput_latency.png b/blogs/deepspeed-fastgen/assets/images/throughput_latency.png new file mode 100644 index 000000000000..aaceebde7038 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/throughput_latency.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png b/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png new file mode 100644 index 000000000000..cc7b8ec1ec05 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/token_latency.png b/blogs/deepspeed-fastgen/assets/images/token_latency.png new file mode 100644 index 000000000000..405a3c0d06ed Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/token_latency.png differ diff --git a/blogs/deepspeed-fastgen/chinese/README.md b/blogs/deepspeed-fastgen/chinese/README.md new file mode 100644 index 000000000000..9a65be5c4e49 --- /dev/null +++ b/blogs/deepspeed-fastgen/chinese/README.md @@ -0,0 +1,303 @@ +
+ +# DeepSpeed-FastGen:通过 MII 和 DeepSpeed-Inference 实现 LLM 高吞吐量文本生成 + +
+ +
+ + +
+ +## 目录 +1. [引言](#introduction) +2. [关键的 LLM 服务技术](#background) +3. [动态 SplitFuse:一种新颖的提示和生成组合策略](#technical-approach) +4. [性能评估](#performance-evaluation) +5. [DeepSpeed-FastGen:实现与使用](#using-deepspeed-fastgen) +6. [尝试 DeepSpeed-FastGen](#try) +7. [致谢](#acknowledgements) + + +## 1. 引言 + +GPT-4 和 LLaMA 这样的大型语言模型(LLMs)已在各个层次上成为了集成 AI 的主流服务应用。从常规聊天模型到文档摘要,从自动驾驶到各个软件中的Copilot功能,这些模型的部署和服务需求正在迅速增加。像 DeepSpeed、PyTorch 和其他几个框架可以在 LLM 训练期间实现良好的硬件利用率。但它们在与用户互动及处理开放式文本生成等任务时,受限于这些操作的计算密集度相对较低,现有系统往往在推理吞吐量上遇到瓶颈。 + +为了解决这一问题, [vLLM](https://arxiv.org/pdf/2309.06180.pdf) 这样由 PagedAttention 驱动的框架和 [Orca](https://www.usenix.org/system/files/osdi22-yu.pdf) 这样的系统显著提高了 LLM 推理的性能。然而,这些系统在面对长提示的工作负载时,依旧难以提供良好的服务质量。随着越来越多的模型(例如 [MPT-StoryWriter](https://www.mosaicml.com/blog/mpt-7b))和系统(例如[DeepSpeed Ulysses](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ulysses))支持延伸到数万个令牌的上下文窗口,这些长提示工作负载变得越来越重要。为了更好地理解问题,我们在下文中提供了详细的示例来说明 LLM 的文本生成是如何在“提示处理”和“生成”的这两个阶段中工作的。当系统将它们视为不同的阶段时,生成阶段将被提示处理所抢占,这可能会破坏服务级别协议(SLAs)。 + +今天,我们很高兴地介绍 DeepSpeed-FastGen 框架,它通过采用我们提出的动态 SplitFuse 技术,能够提供比vLLM 等先进系统高出多达 2.3 倍的有效吞吐量。DeepSpeed-FastGen 是 DeepSpeed-MII 和 DeepSpeed-Inference 的结合,提供了一个易于使用的服务系统。 + +**快速开始:** 要使用 DeepSpeed-FastGen 只需安装最新的 [DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII) 发行版: + +```bash +pip install deepspeed-mii +``` + +要使用简单的非持久性管道部署并生成文本,请运行以下代码。更多详情,请参见[第 5 节](#using-deepspeed-fastgen)。 + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +## 2. 现有 LLM 服务技术 + +单个序列的文本生成工作负载包含两个阶段:1)提示处理,此阶段系统处理用户输入的文本,将其转换成一系列令牌并构建用于注意力机制的键值(KV)缓存;2)生成令牌,即向缓存中添加单个令牌并产生新的令牌。在生成文本序列的过程中,系统将对模型进行多次前向调用以生成完整的文本序列。现有文献和系统中已经提出了两种主要技术,它们解决了这些阶段中可能出现的各种限制和瓶颈。 + +_ 分块 KV 缓存:_ + +vLLM识别出大型单体KV缓存导致的内存碎片化显著降低了大型语言模型服务系统的并发性,并提出了“分页注意力”[Paged Attention](https://arxiv.org/pdf/2309.06180.pdf) 机制来实现非连续KV缓存,并增加整个系统的总吞吐量。此技术采用分页缓存机制,从而提升了系统的整体吞吐量。不同于之前分配各个不同大小的连续内存块的做法,分块 KV 缓存中的底层存储是固定大小的块(也称为页面)。分块 KV 缓存通过消除 KV 缓存引起的内存碎片化,增加了潜在的序列并发量,从而增加了系统吞吐量。非连续 KV 缓存也被 [HuggingFace TGI](https://github.com/huggingface/text-generation-inference) 和 [NVIDIA TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) 等框架所实现。 + +_ 连续批处理:_ + +过去,动态批处理(服务器等待多个请求以同步处理)被用来提高 GPU 利用率。然而,这种方法有缺点,因为它通常需要将输入填充到相同长度或使系统等待以构建更大的批次(batch)。 + +近期大型语言模型(LLM)推理和服务的优化一直专注于细粒度调度和优化内存效率。例如,Orca 提出了 _迭代级调度_(也称为连续批处理),它在模型的每次前向传递时作出独特的调度决策。这允许请求根据需要加入/离开批次,从而消除了填充请求的需要,提高了总体吞吐量。除了 Orca,NVIDIA TRT-LLM、HuggingFace TGI 和 vLLM 也实现了连续批处理。 + +在当前系统中,有两种主要方法来实现连续批处理。在 TGI 和 vLLM 中,生成阶段被抢占以执行提示处理(在 TGI 中称为填充)然后继续生成。在 Orca 中,这些阶段不被区分;相反,只要总序列数没有达到固定限制,Orca 就会将提示加入正在运行的批次中。这两种方法都在不同程度上需要暂停生成以处理长提示(参见[第 3B 节](#splitfuse))。 + +为了解决这些缺点,我们提出了一种新颖的提示和生成组合策略,动态 SplitFuse。 + +## 3. 动态 SplitFuse:一种新颖的提示和生成组合策略 + +类似于现有的框架如 TRT-LLM、TGI 和 vLLM,DeepSpeed-FastGen 的目标是利用连续批处理和非连续 KV 缓存技术,以提升数据中心服务大型语言模型(LLM)的硬件利用率和响应速度。为了实现更高的性能,DeepSpeed-FastGen 提出了 SplitFuse 技术,它利用动态提示和生成分解, 统一来进一步改善连续批处理和系统吞吐量。 + +### A. 三个性能见解 +在描述动态 SplitFuse 之前,我们回答三个关键的性能问题,这些问题解释了SplitFuse背后的逻辑。 + +*__1. 哪些因素影响单个 LLM 的前向传递?__* 为了有效地调度,我们必须首先了解调度过程中应考虑的独立变量有哪些。我们观察到,在前向传递中序列的组成(序列中的批次大小)对性能的影响可以忽略不计。这意味着我们可以围绕单一变量--即前向传递中的令牌数量--构建一个高效的调度器。 + +
+
+
+ +*__2. 模型的吞吐量与前向传递中令牌数量的关系如何?__* 一个 LLM 有两个关键的运行区间,并且过渡相对陡峭。当令牌数量较少时,GPU 的瓶颈是从内存中读取模型,因此吞吐量会随着令牌数量的增加而上升,而当令牌数量很多时,模型的吞吐量受GPU计算能力限制,吞吐量近乎恒定。因此如果我们能将所有前向传递都保持在吞吐量饱和区间,则模型运行效率最高。 + +
+
+
+ +*__3. 如何在多个前向传递中调度一组令牌?__* 我们在上图中观察到,对于对齐良好的输入,令牌吞吐量曲线是凹的,这意味着第二导数必定小于或等于 0。设 $f(x)$ 为给定模型的延迟至吞吐量的凹函数。则对于凹函数 $f(x)$,以下关系成立: + + $$0 \geq \lim_{h \to 0} \frac{f(x + h) - 2f(x) + f(x - h)}{h^2}$$ + + $$0 \geq f(x + h) - 2f(x) + f(x - h)$$ + + $$2f(x) \geq f(x + h) + f(x - h)$$ + +这表明,对于给定的 `2x` 个总令牌来说,最大化吞吐量的方式是将它们均匀分割到两个批次之间。更一般地说,在一个系统中,如果要在 F 个前向传递中处理 P 个令牌,最理想的分区方案是均匀分配它们。 + +### B. 动态分割融合(Dynamic SplitFuse) + +动态分割融合是一种用于提示处理和令牌生成的新型令牌组成策略。DeepSpeed-FastGen 利用动态分割融合策略,通过从提示中取出部分令牌并与生成过程相结合,使得模型可以保持一致的前向传递大小(forward size)。具体来说,动态分割融合执行两个关键行为: + +1. 将长提示分解成更小的块,并在多个前向传递(迭代)中进行调度,只有在最后一个传递中才执行生成。 +2. 短提示将被组合以精确填满目标令牌预算。即使是短提示也可能被分解,以确保预算被精确满足,前向大小(forward sizes)保持良好对齐。 + +动态分割融合(Dynamic SplitFuse)提升了以下性能指标: + +1. **更好的响应性:** 由于长提示不再需要极长的前向传递来处理,模型将提供更低的客户端延迟。在同一时间窗口内执行的前向传递更多。 +2. **更高的效率:** 短提示的融合到更大的令牌预算使模型能够持续运行在高吞吐量状态。 +3. **更低的波动和更好的一致性:** 由于前向传递的大小一致,且前向传递大小是性能的主要决定因素,每个前向传递的延迟比其他系统更加一致。生成频率也是如此,因为DeepSpeed-FastGen不需要像其他先前的系统那样抢占或长时间运行提示,因此延迟会更低。 + +因此,与现有最先进的服务系统相比,DeepSpeed-FastGen 将以允许快速、持续生成的速率消耗来自提示的令牌,同时向系统添加令牌,提高系统利用率,提供更低的延迟和更高的吞吐量流式生成给所有客户端。 + +
+ +
+ + *图 1: 连续批处理策略的示意图。每个块显示一个前向传递的执行。箭头表示前向传递有一个或多个生成的令牌序列。vLLM 在一个前向传递中要么生成令牌,要么处理提示;令牌生成抢占提示处理。Orca 在生成过程中以完整长度处理提示。DeepSpeed-FastGen动态分割融合则执行固定大小批次的动态组合,包括生成和提示令牌。* + +
+ +## 4. 性能评估 + +DeepSpeed-FastGen 利用分块 KV 缓存和动态分割融合连续批处理,提供了最先进的 LLM 服务性能。我们以下述的基准测试方法对 DeepSpeed-FastGen 和 vLLM 在一系列模型和硬件配置上进行评估。 + +### A. 基准测试方法论 + +我们采用两种主要的定量方法来衡量性能。 + +**吞吐量-延迟曲线:** 生产环境的两个关键指标是吞吐量(以每秒请求计)和延迟(每个请求的响应性)。为了衡量这一点,我们模拟了多个客户端(数量从 1 到 32 不等)同时向服务器发送请求(总计 512 个)的情况。每个请求的结果延迟在端点测量,吞吐量通过完成实验的端到端时间来测量。 + +**有效吞吐量:** 诸如聊天应用程序之类的交互式应用程序可能有比上述指标(如端到端延迟)更严格和复杂的要求。以越来越受欢迎的聊天应用为例: + + 1. 用户通过发送提示(输入)来开始对话。 + 2. 系统处理提示并返回第一个令牌。 + 3. 随着生成的进行,后续令牌被流式传输给用户。 + +在这个过程的每个阶段,系统都有可能提供不利的用户体验;例如,第一个令牌到达得太慢;或生成似乎停止了一段时间。我们提出了一个考虑这两个维度的 SLA 框架。 + +由于提示和生成文本的长度差异很大,影响计算成本,因此设定同一个 SLA 值对于吞吐量和延迟是不切实际的。因此,我们将提示延迟的 SLA 定义为 “|提示中的令牌|/512” 秒(= 512 令牌/秒)。此外,考虑到人类的阅读速度,我们将生成延迟的 SLA 设置在指数移动平均(EMA)上为 2、4 或 6 令牌/秒。能够达到这些 SLA 的请求被认为是成功的,这些成功请求的吞吐量被称为**有效吞吐量**。 + +我们通过在 NVIDIA A100、H100 和 A6000 上运行 Llama-2 7B、Llama-2 13B 和 Llama-2 70B 对 vLLM 和 DeepSpeed-FastGen进行了评估。 + +### B. 吞吐量-延迟分析 + +在这个实验中,DeepSpeed-FastGen 在吞吐量和延迟方面都优于 vLLM,在相同的延迟下DeepSpeed-FastGen的吞吐量更大;在相同的吞吐量下DeepSpeed-FastGen的响应延迟更小。如图 2 所示,在 Llama-2 70B 运行于 4 个 A100x80GB 的情况下,DeepSpeed-FastGen 展示了高达 2 倍的吞吐量(1.36 rps 对比 0.67 rps)在相同的延迟(9 秒)下;或高达 50% 的延迟减少(7 秒对比 14 秒)同时实现相同的吞吐量(1.2 rps)。评估 Llama-2 13B 时DeepSpeed-FastGen也呈现了这些趋势,如图 3 所示。 + +
+
+ + *图 2: 使用 Llama 2 70B 进行文本生成的吞吐量和延迟(使用 4 个 A100-80GB GPU 的张量并行)。提示和生成长度遵循正态分布,平均值分别为 1200/2600 和 128/60,方差为 30%* +

+ +
+
+ + *图 3: 使用 Llama 2 13B 进行文本生成的吞吐量和延迟(A100-80GB GPU,无张量并行)。提示和生成长度遵循正态分布,平均值分别为 1200/2600 和 60/128,并且有 30% 的方差* +
+ +### C. 有效吞吐量分析 + +在考虑了首个令牌的延迟和生成速率的有效吞吐量分析下,DeepSpeed-FastGen 提供的吞吐量比 vLLM 高出多达 2.3 倍。图 4 展示了 DeepSpeed-FastGen 和 vLLM 的有效吞吐量的比较分析。每个绘制的点表示从特定数量的客户端得出的有效吞吐量。当我们扩大客户端数量时,我们最初观察到有效吞吐量的增加。然而,当客户端数量接近系统容量时,延迟也显著增加,导致许多请求未能满足 SLA。因此,有效吞吐量将在某个点上饱和或减少。从可用性角度来看,达到最大有效吞吐量所需的客户端数量并不特别重要;线条的最高点是最优的服务点。 + +
+ + + *图 4: DeepSpeed-FastGen 和 vLLM 的有效吞吐量(Llama 2 70B/A100-80GB 使用张量并行在 4 个 A100-80GB GPU 上。提示和生成长度遵循正态分布,平均值分别为 2600 和 60,并且有 30% 的方差)* +

+ +当 vLLM 抢占正在进行的先前请求的生成时,生成延迟会明显增加。这导致 vLLM 的有效吞吐量看起来低于其直接测量的吞吐量。在 vLLM 的峰值时,有效吞吐量为 0.63 查询/秒,大约 28% 的请求未能满足 4 令牌/秒的 SLA。在相同的 SLA 下,DeepSpeed-FastGen 达到了 1.42 查询/秒(不到 1% 的请求未能满足 SLA),这是 vLLM 的 2.3 倍。 + +### D. 令牌级时间分析 + +图 5 显示了生成过程的 P50、P90 和 P95 延迟。vLLM 和 DeepSpeed-FastGen 展示了类似的 P50 延迟,但 vLLM 的 P90 和 P95 延迟显著更高。 + +这种差异是由于 vLLM 在抢占正在进行的生成以处理新提示时,生成延迟出现显著增加所导致的。 +相比之下,DeepSpeed-FastGen 通常会同时处理之前请求的提示和生成,导致生成延迟更加一致。 + +
+
+ + *图 5: 使用张量并行在 4 个 A100-80GB GPU 上的 Llama 2 70B/A100-80GB 的每令牌生成延迟,16 客户端。提示和生成长度遵循正态分布,平均值分别为 2600 和 128,并且有 30% 的方差。 +

+ + +### E. 使用负载均衡的可扩展性 + +DeepSpeed-FastGen 提供了副本级负载均衡,可以将请求均匀分布在多个服务器上,让您轻松扩展应用程序。 + +图 6 展示了 DeepSpeed-FastGen 在使用负载均衡器和最多 16 个副本时的可扩展性。请注意,我们使用了 4 个 A100 GPU 来计算每个 Llama 2 70B 模型。总共,我们使用了 8 个节点来运行 16 个副本。结果展示了 DeepSpeed-FastGen 几乎完美的可扩展性。 +单个副本时DeepSpeed-FastGen的吞吐量为 1.46 查询/秒,而16 个副本的吞吐量达到了 23.7 查询/秒,与单个副本相比标志着线性的 16 倍增长。 + +
+
+ + *图 6: 使用负载均衡功能的可扩展性。提示和生成长度遵循正态分布,平均值分别为 2600 和 60,并且有 30% 的方差*
+
+ +### F. 其他硬件平台 + +除了对 A100 的深入分析,我们还提供了 H100 和 A6000 的基准测试结果。在 A6000 和 H100 上观察到的性能趋势与 A100 相同。 + +
+
+ + *图 7: 使用 8 个 H100 GPU 的 Llama 2 70b 的吞吐量-延迟曲线和有效吞吐量。提示和生成长度遵循正态分布,平均值分别为 2600 和 60,并且有 30% 的方差*
+
+ +
+
+ + *图 8: 使用 A6000 的 Llama 2 7b 的吞吐量-延迟曲线和有效吞吐量。提示和生成长度遵循正态分布,平均值分别为 2600 和 60,并且有 30% 的方差*
+
+ +## 5. DeepSpeed-FastGen:软件实现与使用指南 + +DeepSpeed-FastGen 是 [DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII) 和 [DeepSpeed-Inference](https://github.com/deepspeedai/DeepSpeed) 的协同组合,如下图所示。这两个软件包共同提供了系统的各个组成部分,包括前端 API、用于使用动态 SplitFuse 调度批次的主机和设备基础设施、优化的内核实现,以及构建新模型实现的工具。 + +
+ + +
+ + +使用我们的 alpha 版 DeepSpeed-FastGen 最快的入门方式是:`pip install deepspeed-mii`。 + +请按照我们的 [入门指南](https://github.com/deepspeedai/deepspeed-mii#getting-started-with-mii) 获取更多细节。如需使用和报告问题,请使用 [DeepSpeed-MII Github 仓库](https://github.com/deepspeedai/DeepSpeed-MII)。 + +### A. 支持的模型 + +在 DeepSpeed-FastGen 的当前 alpha 版本中,我们目前支持以下模型架构: + +* [LLaMA](https://huggingface.co/models?other=llama) 和 [LLaMA-2](https://huggingface.co/models?other=llama-2) +* [Mistral](https://huggingface.co/models?other=mistral) +* [OPT](https://huggingface.co/models?other=opt) +* [Falcon](https://huggingface.co/models?other=falcon) +* [Mixtral](https://huggingface.co/models?other=mixtral) +* [Phi-2](https://huggingface.co/models?other=phi-msft) +* [Qwen](https://huggingface.co/models?other=qwen) + +所有当前模型都利用了后端的 [HuggingFace](https://github.com/huggingface) API 来提供模型权重和模型对应的分词器。 + +> 我们计划在最初发布后的几周和几个月内添加更多模型。如果您希望支持特定的模型架构,请[提交问题](https://github.com/deepspeedai/DeepSpeed-MII/issues)来让我们知道。 + +### B. 部署选项 +以下所有示例均可在 [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/inference/mii) 中运行。安装后,您有两种部署方式:交互式非持久管道或持久化服务部署: + +#### 非持久管道 + +非持久管道部署是快速入门的好方法,只需几行代码即可完成。非持久模型只在您运行的 python 脚本期间存在,适用于临时交互式会话。 + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +#### 持久部署 + +持久部署非常适合用于长时间运行和生产的应用。持久部署使用了轻量级的 GRPC 服务器,可以使用以下两行代码创建: + +```python +import mii +mii.serve("mistralai/Mistral-7B-v0.1") +``` + +上述服务器可以同时被多个客户端查询,这要归功于 DeepSpeed-MII 内置的负载平衡器。创建客户端也只需要两行代码: + +```python +client = mii.client("mistralai/Mistral-7B-v0.1") +output = client.generate("Deepspeed is", max_new_tokens=128) +print(output) +``` + +持久部署可以在不再需要时终止: + +```python +client.terminate_server() +``` + +### C. 高级安装方式 + +为了使用方便并显著减少许多其他框架所需的冗长编译时间,我们通过名为 [DeepSpeed-Kernels](https://github.com/deepspeedai/DeepSpeed-Kernels) 的新库分发了覆盖我们大部分自定义内核的预编译 Python wheel。我们发现这个库在环境中非常便携,只要这些环境具有 NVIDIA GPU 计算能力 8.0+(Ampere+)、CUDA 11.6+ 和 Ubuntu 20+。在大多数情况下,您甚至不需要知道这个库的存在,因为它是 DeepSpeed-MII 的依赖项,并将自动与之一起安装。然而,如果您因任何原因需要手动编译我们的内核,请参阅我们的[高级安装文档](https://github.com/deepspeedai/DeepSpeed-Kernels#source)。 + + +# 6. 尝试 DeepSpeed-FastGen +我们非常高兴分享 DeepSpeed-FastGen 的首个 alpha 版本。 + +* 要开始,请访问我们的 DeepSpeed-MII GitHub 页面: [GitHub 登陆页面](https://github.com/deepspeedai/DeepSpeed-MII) + +DeepSpeed-FastGen 是更大的 DeepSpeed 生态系统的一部分,该生态系统包含了多种深度学习系统和建模技术。要了解更多, + +* 请访问我们的[网站](https://www.deepspeed.ai/),详细查看博客文章、教程和有用的文档。 +* 您也可以通过我们的[英文 Twitter](https://twitter.com/DeepSpeedAI)、[日本 Twitter](https://twitter.com/DeepSpeedAI_JP) 和[中文知乎](https://www.zhihu.com/people/deepspeed) 关注我们,以获取 DeepSpeed 的最新消息。 + +DeepSpeed 欢迎您的贡献!我们鼓励您在 [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) 页面上报告问题、贡献 PR,并参与讨论。有关更多详细信息,请参见我们的[贡献指南](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)。我们愿意与大学、研究实验室和公司合作,比如那些在深度学习研究上共同工作,应用 DeepSpeed 来赋能真实世界的 AI 模型和应用等。对于那些不适合在 GitHub 上提出的请求(以及其他请求),请直接发送电子邮件至 info@deepspeed.ai。 + +以下项目在我们的路线图上,我们计划通过我们的 GitHub 问题和 PR 与我们的社区在这些项目上进行交流: + +- 性能改进 +- 更广泛的模型支持 +- 通过与合作伙伴的合作支持新硬件后端 +- 发布性能测试套件(例如此博客中生成的图表) + +如果您喜欢我们的工作,请为我们的 [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) 和 [DeepSpeedMII GitHub](https://github.com/deepspeedai/DeepSpeed-MII/) 仓库打上“星标”! + +# 7. 致谢 + +我们要对包括 HuggingFace、vLLM 和 HuggingFace TGI 在内的多个开源社区项目表示感谢。在 alpha 版本中, 我们利用 HF API 来调用模型和分词器,并计划未来添加更多模型。我们特别感谢 [Flash Attention](https://github.com/Dao-AILab/flash-attention) 开发者的出色工作。我们在系统中广泛利用了 FlashAttention 内核,并已经在我们的代码库的对应的文件头部进行了致谢。最后,我们要感谢我们在 MoE 内核(作为 DeepSpeed-Kernels 仓库的一部分发布)中使用的 [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) 内核的开发者。 diff --git a/blogs/deepspeed-fastgen/japanese/README.md b/blogs/deepspeed-fastgen/japanese/README.md new file mode 100644 index 000000000000..b90b3be92a45 --- /dev/null +++ b/blogs/deepspeed-fastgen/japanese/README.md @@ -0,0 +1,315 @@ +
+ +# DeepSpeed-FastGen: MIIとDeepSpeed-InferenceによるLLMのための高速なテキスト生成 + +
+ +
+ + +
+ +## Table of Contents +1. [概要](#introduction) +2. [LLMのためのテキスト生成の既存技術](#background) +3. [Dynamic SplitFuse: プロンプト処理と生成を組み合わせる新しいアプローチ](#technical-approach) +4. [パフォーマンス評価](#performance-evaluation) +5. [DeepSpeed-FastGen: 実装と使い方](#using-deepspeed-fastgen) +6. [DeepSpeed-FastGenを使ってみる](#try) +7. [謝辞](#acknowledgements) + + +## 1. 概要 + +AIを様々な目的に利用する幅広いアプリケーションで、GPT-4やLLaMAのような大規模言語モデル(LLM)が、主要なワークロードになってきています。一般的なチャットモデルから、文書の要約、自動運転、ソフトウェアスタックの各層におけるプログラミングの補助まで、これらのモデルを大規模に展開・提供する需要が急増しています。DeepSpeedやPyTorchをはじめとするフレームワークは、一般に、LLMの訓練では良好なハードウェアの利用効率を達成できるものの、オープンエンドのテキスト生成などの課題では、GPUなどのハードウェア上で一度に実行される計算量が少ないことが、既存システムにおいて推論スループットのボトルネックとなっています。 + +PagedAttentionを搭載した [vLLM](https://arxiv.org/pdf/2309.06180.pdf) や [Orca](https://www.usenix.org/system/files/osdi22-yu.pdf) のような既存システムは、こうした課題を解決するために設計され、LLMの推論性能を大幅に向上させました。しかしこれらのシステムは依然として、特に長いプロンプトを含むワークロードにおいて、一貫したサービス品質の提供という点で課題を残しています。 +数千トークンに及ぶコンテキストウィンドウをサポートするモデルやシステム、例えば [MPT-StoryWriter](https://www.mosaicml.com/blog/mpt-7b) や [DeepSpeed Ulysses](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ulysses) などが増えるにつれて、これらの長いプロンプトのワークロードはますます重要になってきています。 +これらの問題をより深く理解するために、LLMによるテキスト生成がどのように機能するか説明します。LLMによるテキスト生成は、プロンプト処理と生成と呼ばれる2つの異なるフェーズから構成されます。システムがこれらを全く独立に扱うと、生成のフェーズは、プロンプト処理によって中断されることになります。その結果、システムのレイテンシなどを定めた SLA (Service Level Agreement) に違反する可能性が高くなります。 + +このブログで紹介するDeepSpeed-FastGenは、新たに提案するDynamic SplitFuse技術などを活用することでこうした課題を解決し、vLLMなどの最新の既存システムと比較して最大2.3倍の実効スループットを実現するシステムです。 +DeepSpeed-FastGenは、DeepSpeed-MIIとDeepSpeed-Inferenceの組み合わせにより、使いやすいテキスト生成機能を実現します。 + + +**クイックスタート:** 最新の[DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII)をインストールするだけで、 DeepSpeed-FastGenを試すことができます。 + + +```bash +pip install deepspeed-mii +``` + +より簡単に利用できる、非永続型(推論サーバを起動しない)のパイプラインを使用してテキストを生成するには、次のコードを実行します。詳細については、[セクション5](#using-deepspeed-fastgen) をご覧ください。 + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +## 2. LLMのためのテキスト生成の既存技術 + +テキスト系列を生成するためのワークロードは、次の2つのフェーズで構成されます。 1. プロンプト処理: ここでユーザーが与えたテキストは、アテンション機構におけるキーとバリューのキャッシュ(KVキャッシュ)を構築するために、トークンのバッチとして効率的に処理されます。 2. トークン生成: このフェーズで、KVキャッシュに単一のトークンが追加され、新たなトークンが生成されます。テキスト系列を生成する過程では、モデルは完全なテキストの系列を生成するために多くのフォワードパスの呼び出しを行います。これらのフェーズにおける様々な制限やボトルネックを解決するため、既存システムでは従来提案されてきた以下の2つの主要な技術が採用されています。 + +_ ブロックKVキャッシュ: _ + +vLLMは、KVキャッシュにモノリシックの巨大なメモリ領域を割り当てることが、LLMによるテキスト生成システムの同時実行性を大幅に低下させる原因であるとし、その解決として、非連続的に確保されたメモリ領域をKVキャッシュとして利用することで、システム全体のスループットを増加させる [Paged Attention](https://arxiv.org/pdf/2309.06180.pdf) を提案しました。リクエストごとに様々なサイズの連続メモリ領域を割り当てるのではなく、固定されたサイズのメモリブロック(ページとも呼ばれる)を割り当てるようにします。このブロックKVキャッシュは、KVキャッシュによるメモリ断片化を解決することで、潜在的に処理可能な系列の同時実行数を増やし、システムのスループットを増加させます。こうした非連続KVキャッシュの実装は、[HuggingFace TGI](https://github.com/huggingface/text-generation-inference) と [NVIDIA TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) にも含まれています。 + + +_ 連続バッチ(Continuous Batching): _ + +従来は、サーバーが複数のリクエストを一緒に処理するために待つという 動的バッチ(Dynamic Batching)が GPU利用率を改善するために使用されていました。しかし、このアプローチには欠点があります。通常は入力を同一の長さにパディングするか、より大きなバッチを構築するために、十分な数のリクエストが到着するまで処理を止めて待つ必要があります。 + +最近の大規模言語モデル(LLM)の推論と、それをサービスとして提供するための技術は、より細かい粒度でのスケジューリングとメモリ効率の最適化に焦点を当てています。例えば、Orcaは _イテレーションレベルのスケジューリング_ (連続バッチまたは Continuous Batching とも呼ばれる)を提案しており、これはモデルの各フォワードパスごとにスケジューリングの判断を行います。これにより、必要に応じてあるリクエストをバッチに含めたり除いたりすることができるため、パディングが不要になり、全体のスループットを向上させます。この連続バッチは、Orcaだけでなく、NVIDIAのTRT-LLM、HuggingFaceのTGI、およびvLLMにも実装されています。 + +現在のシステムでは、連続バッチ処理を実装するには二つの主要なアプローチがあります。TGIとvLLMでは、生成フェーズが中断されてプロンプト処理(TGIではインフィルと呼ばれる)が行われ、その後で生成を続けます。Orcaでは、これらのフェーズは区別されず、代わりにシーケンスの総数が一定の制限に達しない限り、実行中のバッチにプロンプトを追加します。これらのアプローチは、長いプロンプトを処理するために生成を一時停止する必要があるという点で、程度の差こそあれ似ています([セクション3B](#splitfuse)参照)。 + + +これらの課題に対処するために、私たちはDynamic SplitFuseと呼ばれる、プロンプト処理と生成を組み合わせる新しい手法を提案します。 + + +## 3. Dynamic SplitFuse: プロンプト処理と生成を組み合わせる新しいアプローチ + +DeepSpeed-FastGenは、データセンターでのLLMの提供において、TRT-LLM、TGI、vLLMなどの既存のフレームワークと同様に、連続バッチと非連続なKVキャッシュを活用して、より高い占有率と応答性を実現するために開発されました。より高いレベルのパフォーマンスを実現するために、DeepSpeed-FastGenはSplitFuseを導入し、動的にプロンプトの分解し、生成と組み合わせることで、連続バッチとシステムスループットをさらに改善します。 + + +### A. パフォーマンスに関する三つの知見 + +Dynamic SplitFuseについて説明する前に、その設計を動機付ける三つの重要なパフォーマンスに関する質問とその回答を示します。 + +*__1. 単一のLLMのフォワードパスに影響を与える要因は何ですか?__* 効果的にスケジューリングを行うためには、反復的に実行されるスケジューリングで制御すべき、関連する独立変数が何であるかを理解することが必要です。我々は以下に示すように、フォワードパス内のシーケンスの構成(シーケンスでのバッチサイズ)がフォワードパスのトークンの生数に比べてパフォーマンスにほとんど影響を与えないことを観察しました。これは、効果的なスケジューラを構築するには、主にフォワードパスのトークン数という単一の要素のみに注目すればよいことを意味しています。 + +
+
+
+ +*__2. フォワードパスのトークン数の変化に対して、モデルのスループットはどのように反応しますか?__* LLMには比較的急に振る舞いが変化する、二つの主要な動作領域があります。トークン数が少ない場合、GPUのボトルネックはメモリからのモデルの読み出しであるため、スループットはトークン数に応じてスケールしますが、トークンが多い場合はモデルのスループットは計算によって制限され、ほぼ一定のスループットを示します。効率的な実行のために、すべてのフォワードパスが、スループットが飽和するような領域で実行されるのが望ましいと言えます。 + +
+
+
+ +*__3. トークンのプールは複数のフォワードパスにどのようにスケジュールされるべきですか?__* 上記で述べたように、入力が適切に整列している場合、トークンのスループット曲線は凹形であり、これは二次導関数が0以下であることを意味します。例として、あるモデルの遅延からスループットへの凹関数を $f(x)$ としましょう。凹関数 $f(x)$ に対しては、以下が成り立ちます: + + $$0 \geq \lim_{h \to 0} \frac{f(x + h) - 2f(x) + f(x - h)}{h^2}$$ + + $$0 \geq f(x + h) - 2f(x) + f(x - h)$$ + + $$2f(x) \geq f(x + h) + f(x - h)$$ + +これは、処理する `2x` トークンのプールに対して、スループットを最大化する方法は、それらを二つのバッチに均等に分割することであると述べています。より一般的には、`P` トークンを `F` 回のフォワードパスで処理する必要があるシステムでは、理想的な分割スキームはそれらを均等に分割するものになります。 + +### B. Dynamic SplitFuse + +Dynamic SplitFuseは、プロンプト処理とトークン生成を組み合わせるための新しいアプローチです。DeepSpeed-FastGenは、プロンプトからの一部のトークンを取り出し、これを生成と組み合わせることで、一貫したフォワードサイズで実行するためにDynamic SplitFuseを利用します。Dynamic SplitFuseは以下の2つの主要な動作からなります: + +1. 長いプロンプトは、はるかに小さなチャンクに分解され、複数のフォワードパス(イテレーション)にわたってスケジュールされます。生成は、最後のフォワードパスでのみ実行されます。 +2. 短いプロンプトは、フォワードパスのための目標トークン数を正確に満たすようにスケジュールされます。短いプロンプトであっても、フォワードパスに与える目標のトークン数を正確に満たし、複数のフォワードパスでトークン数が均等になるように分解されることがあります。 + +これら2つの技術を組み合わせることで、以下のすべてのユーザー指標において、具体的な利点が得られます: + +1. *__より良い応答性__*: 長いプロンプトによりフォワードパスで非常に長い時間がかかることがなくなり、モデルはクライアントから見てより低いレイテンシが実現できます。これは、同じ時間枠内でより多くのフォワードパスが実行されていることになります。 +2. *__高い効率__*: 短いプロンプトを、その他のリクエストのトークンと一緒に実行することで、モデルは一貫して高スループットで動作します。 +3. *__レイテンシ変動の減少と一貫性の向上__*: 1回のフォワードパスに与えるトークン数の変動が少なくなります。フォワードパスに与えるトークン数がパフォーマンスの主要な決定要因であるため、各フォワードパスのレイテンシは競合するシステムよりもはるかに一貫したものとなります。他の先行研究のように、プリエンプションや長時間実行されるプロンプトによって遅延が増加することはありません。 + +結果として、DeepSpeed-FastGenは、システムの利用率を高めるためにトークンをフォワードパスに加えていくことで、到着するリクエストのプロンプト処理を、進行中の生成フェーズを高速に実行しながら行えます。これにより、 +他の最先端のテキスト生成システムと比較して、すべてのクライアントに対してより低レイテンシかつ高スループットのストリーミング生成を実現できます。 + + +
+ +
+ +*図1: 連続バッチ処理戦略のイラスト。各ブロックはフォワードパスの実行を示しています。矢印は、1つ以上のトークンが生成されたシーケンスを持つフォワードパスを示しています。vLLMはフォワードパスでトークン生成またはプロンプト処理のいずれかを実行し、トークン生成はプロンプト処理をプリエンプトします。Orcaは生成と同時に完全な長さのプロンプトを実行します。Dynamic SplitFuseは、生成トークンとプロンプトトークンの両方で構成された固定サイズのバッチの動的構成を実行します。* +
+ +## 4. パフォーマンス評価 + +DeepSpeed-FastGenは、ブロックKVキャッシュとDynamic SplitFuseのcontinuous batchingを活用し、最先端のLLMサービング性能を提供します。我々は、以下で議論されるベンチマーク手法に従って、さまざまなモデルとハードウェア構成でDeepSpeed-FastGenとvLLMを評価します。 + +### A. ベンチマーク手法 + +パフォーマンスを測定するために、我々は2つの主要な定量的スキームを使用します。 + +**スループット-レイテンシカーブ**: 実サービス利用のための2つの主要な指標は、スループット(秒間リクエスト数で測定)とレイテンシ(各リクエストの応答性)です。これを測定するために、我々は複数のクライアント(1から32まで)を同時に起動し、サーバーにリクエスト(合計512)を送信します。各リクエストの結果としてのレイテンシは各リクエストの単位で測定され、スループットは実験を完了するためのエンドツーエンドの時間で測定されます。 + +**実効スループット**: チャットアプリケーションのようなインタラクティブなアプリケーションは、エンドツーエンドのレイテンシのようなトップレベルの指標では捉えきれない、より厳格で複雑な要件を持っている場合があります。特にここでは、急速に広く使われつつあるチャットアプリケーションのユーザシナリオに焦点を当てます: + +1. ユーザーがプロンプトを送信してタスクを開始します。 +2. システムがプロンプトを処理し、最初のトークンを返します。 +3. 続くトークンは、生成されると同時に、ユーザーにストリーミングで送信されます。 + +このプロセスの各ポイントで、ユーザーにとって望ましくない体験になる可能性があります。例えば、最初のトークンが遅すぎる場合や、生成がしばらくの間停止するように見える場合です。我々は、これらの2つの観点を考慮に入れたSLAのフレームワークを提案します。 + +プロンプトと生成されたテキストの長さには、非常に広い幅があり、またそれが計算コストに影響を与えるため、スループットとレイテンシに厳格なSLA値を設定することは非現実的です。したがって、我々はプロンプトのレイテンシのSLAをプロンプト内の|トークン数| / 512秒(= 512トークン/秒)と定義します。さらに、人間の読む速度を考慮して、生成レイテンシのSLAを、指数移動平均(EMA)で秒間2、4、または6トークンに設定します。これらのSLAを満たすリクエストは成功と見なし、これらの成功したリクエストのスループットを **実効スループット** とします。 + +我々は、NVIDIA A100、H100、およびA6000上のLlama-2 7B、Llama-2 13B、およびLlama-2 70BでvLLMとDeepSpeed-FastGenを評価しました。 + +### B. スループット・レイテンシ分析 + +この実験では、DeepSpeed-FastGenは、vLLMをスループットとレイテンシの両方で上回り、同じスループットでより低レイテンシを提供するか、あるいはより高スループットで同じレイテンシを提供します。4台の A100 GPU(メモリ80GB)とLlama-2 70Bを使用したテキスト生成では、DeepSpeed-FastGenは同じレイテンシ(9秒)で2倍高いスループット(それぞれ1.36 rpsと0.67 rps)を示すか、同じスループット(1.2 rps)を達成しながら最大50%のレイテンシ削減(それぞれ7秒と14秒)を実現します。この結果は図2に示されています。またこの傾向は、図3に示されるLlama-2 13Bでの評価でも同様です。 + + +
+
+ + *図2: テキスト生成のスループットとレイテンシ(4台のA100-80GB GPUでのテンソル並列を使用したLlama 2 70B)。プロンプトと生成の長さは、平均1200/2600と128/60の正規分布(30%の分散)に基づいて設定。* +

+ +
+
+ + *図3: テキスト生成のスループットとレイテンシ(1台のA100-80GB GPUでのテンソル並列なしでのLlama 2 13B)。プロンプトと生成の長さは、平均1200/2600と60/128の正規分布の正規分布(30%の分散)に基づいて設定。* +
+ +### C. 実効スループット分析 + +最初のトークンのレイテンシと、生成が行われる速度の両方を考慮した実効スループットにおいて、DeepSpeed-FastGenはvLLMに比べて最大2.3倍の性能を示しています。図4はDeepSpeed-FastGenとvLLMの実効スループットの比較分析を示しています。プロットされたそれぞれの点は、特定のクライアント数で得られた実効スループットを表します。クライアント数を増やすと初めは実効スループットが増加することが観察されました。しかし、クライアント数がシステムの容量に近づくとレイテンシも大幅に増加し、多くのリクエストがSLAを満たすことができなくなります。その結果、実効スループットはいずれかのポイントを上限として、その後減少します。使用性の観点から、最大実効スループットを達成するために必要なクライアント数は特に重要ではありません。ラインの最高点が、サービス提供における最適な点になります。 + +
+ + + *図4: DeepSpeed-FastGenとvLLMの実効スループット。Llama 2 70B/A100-80GBを使用し、4台のA100-80GB GPU間でテンソル並列を使用。プロンプトと生成の長さは、それぞれ平均2600と60の正規分布(30%の分散)に基づいて設定。* +

+ +vLLMが、新たなプロンプトを処理するために進行中の前のリクエストの生成を中断すると、生成のレイテンシは顕著に増加します。これにより、vLLMの実効スループットは直接測定されたスループットよりも低く見えます。vLLMのピーク時、実効スループットは0.63クエリ/秒であり、リクエストの約28%が4トークン/秒のSLAを満たすことができませんでした。同じSLAで、DeepSpeed-FastGenは1.42クエリ/秒(SLAを満たさなかったリクエストは1%未満)を達成し、これはvLLMの2.3倍です。 + +### D. トークン単位のレイテンシ分析 + +図5は生成プロセスのP50、P90、P95のレイテンシを表示しています。vLLMとDeepSpeed-FastGenを比べると、P50レイテンシに大きな違いはありませんが、vLLMはP90とP95で著しく高いレイテンシを示しています。 +P95レイテンシに関しては、DeepSpeed-FastGenは3.7倍の削減を達成しています。 + +この差異は、vLLMが進行中の生成を中断して新しいプロンプトを処理する際に、生成レイテンシに顕著なスパイクが生じるためです。 +対照的に、DeepSpeed-FastGenは通常、前のリクエストのプロンプトと生成を同時に処理するため、はるかに一貫した生成のレイテンシを実現します。 + +
+
+ + *図5: トークンごとの生成レイテンシ。Llama 2 70B/A100-80GBを使用し、4台のA100-80GB GPU間でテンソル並列を使用。クライアント数16。プロンプトと生成の長さは、それぞれ平均2600と128の正規分布(30%の分散)に基づいて設定。* +

+ + +### E. ロードバランシングを使用したスケーラビリティ +DeepSpeed-FastGenはレプリカ単位のロードバランシングの機能を備えており、複数のサーバーにリクエストを均等に分散させることで、アプリケーションを簡単にスケールアップすることができます。 + +図6は、ロードバランサーを使用し、最大16のレプリカを適用したときのDeepSpeed-FastGenのスケーラビリティを示しています。Llama 2 70Bモデルの計算には、レプリカ一つあたりで、4台のA100 GPUを使用しました。合計で16のレプリカを実行するために8ノードを使用しました。その結果はDeepSpeed-FastGenのほぼ完璧なスケーラビリティを示しています。1つのレプリカのスループットが1.46クエリ/秒である場合、16のレプリカでのスループットは23.7クエリ/秒に達し、1つのレプリカに比べて16倍の線形増加を示しています。 + +
+
+ + *図6: ロードバランシング機能を使用したスケーラビリティ。プロンプトと生成の長さは、それぞれ平均2600と60の正規分布(30%の分散)に基づいて設定。* +
+ +### F. 他のハードウェアプラットフォーム + +A100 GPUを用いた分析に加えて、H100とA6000を使用したベンチマーク結果を提供します。A6000とH100の両方で、A100と同様のパフォーマンスの傾向が観察されました。 + +
+
+ + *図7: 8つのH100 GPUを使用したLlama 2 70bのスループット・レイテンシカーブと実効スループット。プロンプトと生成の長さは、それぞれ平均2600と60の正規分布(30%の分散)に基づいて設定。* +
+ +
+
+ + *図8: A6000を使用したLlama 2 7bのスループット・レイテンシカーブと実効スループット。プロンプトと生成の長さは、それぞれ平均2600と60の正規分布(30%の分散)に基づいて設定。* +
+ +## 5. DeepSpeed-FastGen: 実装と使い方 + +DeepSpeed-FastGenは、以下の図に示されているように、[DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII)と[DeepSpeed-Inference](https://github.com/deepspeedai/DeepSpeed)を融合的に組み合わせたものです。これらのソフトウェアパッケージは、フロントエンドAPI、Dynamic SplitFuseを使用してバッチをスケジュールするホストおよびデバイスインフラストラクチャ、最適化されたカーネル実装、新しいモデル実装を構築するためのツールなど、システムの様々なコンポーネントを提供します。 + + +
+ + +
+ +DeepSpeed-FastGenのアルファリリースを使い始める最も簡単な方法は、 ``pip install deepspeed-mii`` を実行することです。 + +詳細については、[Getting Started](https://github.com/deepspeedai/deepspeed-mii#getting-started-with-mii)ガイドを参照してください。使用法や問題の報告には、[DeepSpeed-MII Github リポジトリ](https://github.com/deepspeedai/DeepSpeed-MII)を使用してください。 + +### A. 対応モデル + +現在、DeepSpeed-FastGenのこのアルファリリースでは、以下のモデルアーキテクチャをサポートしています: + +* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2) +* [Mistral](https://huggingface.co/models?other=mistral) +* [OPT](https://huggingface.co/models?other=opt) + +現在のすべてのモデルは、モデルの重みとモデルに対応するトークナイザーの両方を提供するために、バックエンドで [HuggingFace](https://github.com/huggingface) を利用しています。 + +初期リリース後の数週間と数ヶ月に追加のモデルを追加する予定です。サポートを希望する特定のモデルアーキテクチャがある場合は、[issue](https://github.com/deepspeedai/DeepSpeed-MII/issues) を登録してください。。 + +### B. デプロイメントのオプション + +以下の例はすべて [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/inference/mii) で実行可能です。インストール後、デプロイメントのオプションとして、対話型の非永続パイプラインまたは永続的なサービス提供デプロイメントの2つのオプションがあります。 + +#### 非永続パイプライン + +非永続パイプラインデプロイメントは、非常に簡単に使い始めることができ、わずか数行のコードで実行可能です。 +非永続モデルは、Pythonスクリプトの実行中だけ起動しますが、一時的な対話型セッションには便利です。 + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +#### 永続デプロイメント + +永続デプロイメントは、長時間実行されるアプリケーションや本番アプリケーションに使用するためのものです。永続デプロイメントでは、以下の2行を使用して軽量なGRPCサーバーを起動できます。 + +```python +import mii +mii.serve("mistralai/Mistral-7B-v0.1") +``` + +上記のサーバーは、DeepSpeed-MIIの組み込みロードバランサーのおかげで、複数のクライアントから一度にクエリを受け取ることができます。クライアントも、以下の2行のコードだけで利用できます: + +```python +client = mii.client("mistralai/Mistral-7B-v0.1") +output = client.generate("Deepspeed is", max_new_tokens=128) +print(output) +``` + +永続デプロイメントは、必要なくなったときに、以下の方法で終了できます: + +```python +client.terminate_server() +``` + +### C. インストールの詳細情報 + +類似の他のプロジェクトでは、カスタムカーネルのコンパイルに非常に時間がかかることがよくあります。 +DeepSpeed-FastGenでは、このコンパイル時間を大幅に短縮し、利便性を向上するため、主要なカスタムカーネルの大部分を事前コンパイルしたPython wheelを、[DeepSpeed-Kernels](https://github.com/deepspeedai/DeepSpeed-Kernels)という新しいライブラリを通じて配布しています。 +このライブラリは、NVIDIA GPUのコンピュート能力が8.0以上(Ampere+)、CUDA 11.6以上、Ubuntu 20以上の環境で非常に移植性が高いことがわかっています。 +このライブラリは、DeepSpeed-MIIの依存関係としてインストールされるため、ほとんどの場合では、このライブラリの存在を知る必要はありません。しかし、何らかの理由でカーネルを手動でコンパイルする必要がある場合は、インストールに関する[詳細ドキュメント](https://github.com/deepspeedai/DeepSpeed-Kernels#source)をご覧ください。 + +# 6. DeepSpeed-FastGen を使ってみる + +このDeepSpeed-FastGenアルファリリースをユーザの皆さんと共有できることを非常に嬉しく思います。 + +* 使用を始めるにあたっては、DeepSpeed-MIIのGitHubページをご覧ください: [GitHubランディングページ](https://github.com/deepspeedai/DeepSpeed-MII) + +DeepSpeed-FastGenは、Deep Learningシステムやモデリングテクノロジーを数多く含む、より大きなDeepSpeedエコシステムの一部です。さらに詳しい情報が必要な方は、 +[詳細なブログ記事]、チュートリアル、役立つドキュメントがある私たちの [ウェブサイト](https://www.deepspeed.ai/) をご覧ください。 +DeepSpeedの最新情報については、[英語のTwitter](https://twitter.com/DeepSpeedAI)、[日本語のTwitter](https://twitter.com/DeepSpeedAI_JP)、[中国語の知乎](https://www.zhihu.com/people/deepspeed)をフォローしてください。 + +DeepSpeedは、皆様の開発への参加を歓迎しています。DeepSpeedのGitHubページで、バグ報告、Pull Request、ディスカッションへの参加が可能です。詳細は[ガイドライン](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)をご覧ください。[contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) にはより詳細な情報があります。 +また、深層学習の研究や、実世界のAIモデルやアプリケーションへのDeepSpeedの適用に取り組む大学、研究所、企業とのコラボレーションも行っています。こうしたコラボレーションについてのご要望(およびGitHubには適さないその他の話題)については まで直接メールをお送りください。 + +以下の項目は、今後のロードマップです。GitHubの問題やPRを通じてコミュニティと協力して取り組む予定です: + +- パフォーマンスの改善 +- より広範なモデルサポート +- パートナーとのコラボレーションによる新しいハードウェアバックエンド +- ブログに掲載したプロットを生成するパフォーマンスベンチマークのリリース + +このプロジェクトが気に入ったら、ぜひ [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) と [DeepSpeedMII GitHub](https://github.com/deepspeedai/DeepSpeed-MII/) のリポジトリに "スター" をつけてください。 + +# 7. 謝辞 + +HuggingFace、vLLM、HuggingFace TGIを含むさまざまなオープンソースコミュニティプロジェクトに感謝します。私たちはアルファリリースでのモデルとトークナイザーをサポートするためにHF APIを活用し、今後もさらに多くのモデルを追加する予定です。特に、[Flash Attention](https://github.com/Dao-AILab/flash-attention) の開発者の素晴らしい成果に感謝します。私たちはシステムでFlashAttentionカーネルを広範囲に活用しており、コードリポジトリに含まれる適切なファイルヘッダーにそのことを記載しています。最後に、私たちのMoEカーネルで使用している [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) カーネルの開発者に感謝します(DeepSpeed-Kernelsリポジトリの一部としてリリースされました)。 diff --git a/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md b/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md new file mode 100644 index 000000000000..0eaa256bac69 --- /dev/null +++ b/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md @@ -0,0 +1,143 @@ +
+ +# DeepSpeed-FP6:大型语言模型中以FP6为核心的强大推理服务 + +
+ +
+ +DeepSpeed-VisualChat! + +
+ + +要引用DeepSpeed-FP6,请引用以下两篇arxiv报告 - ZeroQuant(4+2) 和 FP6-LLM: + +``` +@article{wu2023zeroquant, + title={Zeroquant(4+2): Redefining llms quantization with a new fp6-centric strategy for diverse generative tasks}, + author={Wu, Xiaoxia and Xia, Haojun and Youn, Stephen and Zheng, Zhen and Chen, Shiyang and Bakhtiari, Arash and Wyatt, Michael and Aminabadi, Reza Yazdani and He, Yuxiong and Ruwase, Olatunji and Song, Leon and others}, + journal={arXiv preprint arXiv:2312.08583}, + year={2023} +} + +@article{xia2024fp6, + title={FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design}, + author={Xia, Haojun and Zheng, Zhen and Wu, Xiaoxia and Chen, Shiyang and Yao, Zhewei and Youn, Stephen and Bakhtiari, Arash and Wyatt, Michael and Zhuang, Donglin and Zhou, Zhongzhu and others}, + journal={arXiv preprint arXiv:2401.14112}, + year={2024} +} +``` + + +# Table of Contents +1. [为什么选择6位浮点(FP6)](#introduction) +2. [FP6的系统支持](#system-fp6) +3. [FP6的LLMs服务系统](#serving-llm) +4. [如何开始](#how-to-start) +5. [软件改进](#software-improvements) +6. [致谢和贡献](#ac) + +# 1. 为什么选择6位浮点 +大型语言模型(LLMs)领域正处于迅猛发展之中,模型量化是提升推理服务性能的关键技术之一。 我们的研究旨在提高计算效率和存储空间,同时保持模型质量。 + +**深入研究INT4的挑战** 在最近的研究成果 ZeroQuant(4+2)[1] 中, 我们探索了INT4量化技术(如GPTQ算法) 在大语言模型(LLMs)中的表现能力。虽然这些技术可以减小模型大小和参数存储量,但由于过拟合问题, 它们在更一般的许多任务中往往表现不佳,包括代码生成和摘要等更多生成任务。因此, 当前迫切需要新的方法来提高LLMs的效率和有效性。 + + **FP6的突破** 我们对不同量化方法的探索将我们带到了FP6精度标准。尽管FP6数据格式在当前AI硬件的高效支持中存在挑战(我们将在下一节中解决这一挑战),该格式在各种任务的性能和灵活性方面均表现出色。值得注意的是,使用FP6量化的模型,如StarCoder-15B,在代码生成方面达到了与FP16模型相当的结果,而较小的模型(如BART-406M)在摘要方面达到了标准FP16性能水平。为了提高FP6在当前主流AI硬件上的执行效率,我们提出了一种4+2新颖的FP6 GPU kernel方案。这一创新使FP6成为提高LLMs效率的有效途径。更多详细信息请参阅我们的研究论文 ZeroQuant(4+2)[1]。 + + +# 2. FP6的系统支持 + +**开创性的全栈GPU KERNEL设计** FP6量化的一个挑战是缺乏针对这种不规则位宽的高效GPU KERNEL设计。在我们最近的研究中(FP6-LLM[2]),我们设计并实现了TC-FPx,第一个具有Tensor Core支持的用于FP6和各种量化位宽(6位、5位、3位等)的浮点权重的GPU系统设计方案,缓解了LLM推理期间的“内存墙”问题。TC-FPx打破了底层GPU硬件的限制,允许GPU支持涉及任意位宽模型权重的矩阵乘法计算。在TC-FPx中,Tensor Cores用于矩阵乘法的密集计算,而SIMT cores在运行时有效地用于权重反量化,将模型权重反量化为FP16类型,Tensor Core基于此进行计算。它具有以下关键创新: +
+ fp6 design + +
+ +* 运行前比特层级的数据排布转换。用以解决权重具有不规则位宽时不友好的内存访问挑战,实现GPU内存的最优访问; + +* 运行时的高效SIMT计算。用以最小化权重反量化的运行时开销; + +* 全栈的高效流水线设计。其SIMT计算、Tensor Core计算和GPU内存访问进行高效调度,最大程度提升性能。 + + + +平均而言,我们的FP6 kernel在NVIDIA A100 GPU上进行(因decoder的矩阵形状狭长而导致参数矩阵的访存成为瓶颈的)矩阵乘法时,处理速度比FP16 cuBLAS基准提高了2.1倍。值得注意的是,通过FP6量化实现的FP6内核使LLaMA-70b模型能够在单个A100 GPU上运行。这一显著成就使得其在batch小于32的LLM推理任务中,性能比FP16基准高出1.69到2.65倍。目前,TC-FPx内核仅支持NVIDIA Ampere GPU,并且仅在A100 GPU上进行了测试和验证。 + + +# 3. 使用FP6服务LLMs + +我们已成功将FP6量化内核[3]集成到DeepSpeed-FastGen中,实现了运行时的即时量化。这一增强功能允许通过DeepSpeed-FastGen中的统一配置选项来高效量化和部署大型语言模型。通过我们的接口,用户可以输入HuggingFace模型名称或本地checkpoint目录。输入后,我们的系统将启动指定模型的加载,对每个线性层实现FP6量化,并将量化的权重进行比特层级的数据排布转换。转换后的张量随后作为更新后的权重,而原始的FP16权重被丢弃以优化内存使用。在推理阶段,FP6内核将利用这些6位的权重进行计算。 + +我们在两个A100 GPU-80G上评估了LLaMA-2-70b模型使用FP6量化的服务性能,实现了1.5倍的推理延迟减少和3.5倍的推理吞吐量增加,与FP16基线相比。FP6量化为模型推理提供了两个关键好处:它使大型语言模型(LLMs)能够在更少的GPU上部署——例如,LLaMA-70b在单个A100-80G GPU上就能以FP6形式运行,而FP16模型至少需要两个GPU。此外,它显著加快了小batch之下内存访问为瓶颈的线性层计算。此外,FP6量化减少了模型权重的GPU内存需求,允许同时服务更多查询,从而提高了服务吞吐量。 + +我们的系统在处理长序列生成时表现出很高的效率。如图1所示,对于超过提示长度的生成长度,我们的系统展现出显著的性能优势。随着生成序列长度的延伸,FP6与FP16之间的性能差异加大。这一趋势主要归因于解码长度扩展时,推理过程变得越来越受内存访问瓶颈限制,有利于我们的权重量化的GPU kernel,相对于FP16实现更大的kernel速度提升。需要强调的是,较长解码场景中内存访问瓶颈增强的两个因素如下: + +首先,KV缓存的内存使用随序列长度增加而增加,减少了可容纳的batch大小并导致线性层的矩阵计算瓶颈变为参数的访存。 + +其次,在DeepSpeed-FastGen的prefill-decoding-mixed-batch技术背景下,对于decoding较长的情况,用于和decoding进行mixed-batching的prefill切块会相对不足,这导致纯粹用于decoding的batch频率增加,进一步加剧了访存的瓶颈。 +

+ Caption1 + Caption2 + Caption3 +

+ +图1:在DeepSpeed-MII中,使用128个请求和32个客户端,对LLaMA-2-70B模型在2xA100-80g上进行端到端服务性能测试。我们尝试了128、256和512之间不同数量的请求,发现加速效果相似。 + +尽管FP6量化带来了显著的好处,但当前实现仍面临一些限制。值得注意的是,在GEMM因batch较大或有充足的GPU内存而使得瓶颈变为Tensor Core计算时,我们的仅限权重的量化kernel可能无法保持其性能优势,尤其是与厂商的优化库如cuBlas相比。然而,我们系统的低内存占用仍是一个关键优势。目前的支持限于非混合专家(Non-MoE)结构,我们正在努力将支持扩展到MoE结构。此外,当前系统仅与FP16输入模型兼容,因为当前实现的FP6 kernel仅支持处理FP16的激活。 + +
+ +# 4. 如何开始 + +DeepSpeed-FP6的量化和推理体验简单方便。这里我们以LLaMa-2-70B模型为例: +```python +import mii +pipe = mii.pipeline("NousResearch/Llama-2-70b-hf", quantization_mode='wf6af16') +response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128) +print(response) +``` + +您需要安装以下内容 + +``` +pip install deepspeed-mii +pip install qtorch +``` + +要使用我们的DeepSpeed-FP6进行基准测试,请访问以下脚本: +```bash +https://github.com/deepspeedai/DeepSpeedExamples/blob/master/benchmarks/inference/mii/run_fp6.sh +``` + +也请访问[FP6-LLM github](https://github.com/usyd-fsalab/fp6_llm) 获取FP6的独立kernel。不要忘了给仓库加星标以表达您的支持! + + +# 5. 软件改进 + + +我们的DeepSpeed-FP6目前仅支持线性GEMM。我们期待未来能够支持MoE GEMM。我们将继续根据您的反馈和支持改进DeepSpeed-FP6。DeepSpeed-FP6是更大DeepSpeed生态系统的一部分,包括一系列深度学习系统和建模技术。要了解更多, + +* 请访问我们的 [网站](https://www.deepspeed.ai/) 了解详细的博客文章、教程和文档。 +* 在我们的 [英文 X(Twitter)](https://twitter.com/DeepSpeedAI)、[日语 X(Twitter)](https://twitter.com/DeepSpeedAI_JP) 和 [中文知乎](https://www.zhihu.com/people/deepspeed) 上关注我们,以获取 DeepSpeed 的最新消息。 + +我们欢迎您为 DeepSpeed 做出贡献!我们鼓励您报告问题、贡献 PRs、并在 [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) 页面上参加讨论。有关更多详细信息,请查看我们的 [贡献指南](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)。我们对与大学、研究实验室、公司等进行合作持开放态度,例如共同进行深度学习研究、应用 DeepSpeed 为现实世界的 AI 模型和应用提供支持等等。对于此类请求(以及其他不适合 GitHub 的请求),请直接发送电子邮件至 info@deepspeed.ai。 + +* 如果你喜欢我们的工作,请在[DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/), [DeepSpeed-MII GitHub](https://github.com/deepspeedai/DeepSpeed-MII/) 和 [DeepSpeedExamples GitHub](https://github.com/deepspeedai/DeepSpeedExamples/)仓库“点赞”! + + +# 6. 致谢和贡献 +我们感谢悉尼大学和罗格斯大学的合作。我们还感谢开源库 [aspuru-guzik-group/qtorch](https://github.com/aspuru-guzik-group/qtorch). + +贡献: +Xiaoxia Wu\* $^1$, Zhen Zheng\* $^1$, Haojun Xia\* $^2$, Arash Bakhtiari $^1$, Michael Wyatt $^1$, Shiyang Chen $^3$, Stephen Youn $^1$, Reza Yazdani Aminabadi, Yuxiong He, Olatunji Ruwase $^1$, Zhewei Yao, Leon Song $^1$ $^2$(项目负责人) + +\* 平等贡献 1: 微软 2: 悉尼大学 3: 罗格斯大学 + +文献: + +[1] ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks. arXiv. https://arxiv.org/abs/2312.08583 + +[2] FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design. arXiv. https://arxiv.org/abs/2401.14112 + +[3] FP6-LLM kernel release. GitHub. https://github.com/usyd-fsalab/fp6_llm diff --git a/blogs/deepspeed-fp6/03-05-2024/README.md b/blogs/deepspeed-fp6/03-05-2024/README.md new file mode 100755 index 000000000000..a6c4060a42dc --- /dev/null +++ b/blogs/deepspeed-fp6/03-05-2024/README.md @@ -0,0 +1,147 @@ +
+ +# DeepSpeed-FP6: The Power of FP6-Centric Serving for Large Language Models + +
+ +
+ +DeepSpeed-VisualChat! + +
+ + +To cite DeepSpeed-FP6, please cite the following two arxiv reports - ZeroQuant(4+2) and FP6-LLM: + +``` +@article{wu2023zeroquant, + title={Zeroquant(4+2): Redefining llms quantization with a new fp6-centric strategy for diverse generative tasks}, + author={Wu, Xiaoxia and Xia, Haojun and Youn, Stephen and Zheng, Zhen and Chen, Shiyang and Bakhtiari, Arash and Wyatt, Michael and Aminabadi, Reza Yazdani and He, Yuxiong and Ruwase, Olatunji and Song, Leon and others}, + journal={arXiv preprint arXiv:2312.08583}, + year={2023} +} + +@article{xia2024fp6, + title={FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design}, + author={Xia, Haojun and Zheng, Zhen and Wu, Xiaoxia and Chen, Shiyang and Yao, Zhewei and Youn, Stephen and Bakhtiari, Arash and Wyatt, Michael and Zhuang, Donglin and Zhou, Zhongzhu and others}, + journal={arXiv preprint arXiv:2401.14112}, + year={2024} +} +``` + + +# Table of Contents +1. [Why 6-bit Floating Point (FP6)](#introduction) +2. [System Support for FP6](#system-fp6) +3. [LLMs Serving with FP6](#serving-llm) +4. [How to Start](#how-to-start) +5. [Software Improvements](#software-improvements) +6. [Acknowledgments and Contributions](#ac) + +# 1. Why 6-bit Floating Point (FP6) + + +In the evolving landscape of Large Language Models (LLMs) like GPT, our research aims to boost computational efficiency and storage while preserving model quality. This focus brings us to tackle the complex challenges of 4-bit quantization, where optimizing performance, efficiency, and accuracy is crucial. + +**Exploring the Challenges of 4-bit Quantization** In our recent research findings -- ZeroQuant (4+2)[1], we explore the capabilities of INT4 quantization techniques (like the GPTQ algorithm) for serving Large Language Models (LLMs). While these techniques reduce memory and computational requirements, they often perform poorly on a broad array of tasks, including generative tasks such as code generation and summarization, due to overfitting issues. This highlights the urgent need for new quantization approaches that simultaneously improve both the efficiency and effectiveness of LLMs. + +**Breakthroughs with FP6 Precision** Our exploration of different quantization methods led us to the FP6 precision standard. Despite the challenges in integrating and accelerating FP6 with current AI hardware -- which we will address in the next section - this format excels in performance and flexibility across various tasks. Notably, we observe that for generative tasks, FP6 quantization can match the performance of the half-precision (FP16) format. For example, with FP6 quantization, StarCoder-15B achieves comparable code generation results to the FP16 variant, while a smaller model, such as BART-460M, achieves comparable summarization performance to the standard FP16 equivalent. In order to preserve these quality gains, while matching the system efficiency of INT4 quantization on AI hardware, we propose a novel 4+2 FP6 scheme. This innovation makes FP6 a promising direction for improving the efficiency of LLMs, marking a significant leap in AI technology advancement. For more details, please refer to our research paper - ZeroQuant (4+2)[1]. + + +# 2. System Support for FP6 + +**Pioneering Full-Stack GPU Kernel Design** A key challenge of FP6 quantization is the lack of efficient GPU kernel designs for this irregular, i.e., "non-power of 2", bit-width. In our recent research — FP6-LLM [2], we introduce TC-FPx, the first full-stack GPU system design scheme with unified Tensor Core support of floating point weights for FP6 and other irregular quantization bit-widths (6-bit, 5-bit, 3-bit, etc.). TC-FPx breaks the limitations of the underlying GPU hardware, allowing the GPU to support linear layer calculations on model weights of arbitrary bit width. By increasing the number of bit-width options for efficient quantization, TC-FPx significantly mitigates the "memory wall" challenges of LLM inference. In TC-FPx, Tensor Cores are utilized for intensive computation of matrix multiplications, while SIMT cores are effectively leveraged for weight dequantization, transforming the x-bit model weights to FP16 type during runtime before feeding them to Tensor Cores. It has the following key innovations: +
+ fp6 design + +
+ +* *Ahead-of-time Bit-level Pre-packing*: resolve the challenge of unfriendly memory access for weights with irregular bit-width, and enable optimal GPU memory access. + +* *SIMT-Efficient GPU Runtime*: minimize the runtime overhead of weight de-quantization. + +* *The software pipeline of TC-FPx kernel*: efficiently utilize SIMT cores, Tensor Cores, and the GPU memory hierarchy for high performance. + + + +On average, the TC-FPx kernel demonstrates a 2.1-fold improvement in processing speed over the FP16 cuBLAS benchmark during memory-intensive General Matrix Multiply (GEMM) operations on NVIDIA A100 GPUs. Notably, the implementation of the FP6 kernel through FP6 quantization facilitates the operation of LLaMA-70b on a solitary A100 GPU. This remarkable feat results in a normalized inference throughput that is 1.69 to 2.65 times superior to the FP16 benchmark when conducting inference tasks with batch-size under 32. Currently, TC-FPx kernel only supports NVIDIA Ampere GPUs and is only tested and verified on A100 GPUs + + +# 3. LLMs serving with FP6 + +We have successfully integrated the FP6 quantization kernel [3] into DeepSpeed-FastGen, facilitating on-the-fly, weight-only quantization. This enhancement permits the efficient quantization and deployment of large language models (LLMs) through a unified configuration option within DeepSpeed-FastGen. Detailed information regarding this feature will be provided in due course. Through our interface, users have the flexibility to load a model checkpoint from either HuggingFace hub or a local directory. While loading the checkpoint, our system applies FP6 round-to-nearest quantization on each linear layer, and transforms the quantized weights into 6-bit prepacked tensors. These tensors will serve as the model weights for inference, while the original FP16 weights are discarded to release memory. Throughout the inference stage, the FP6 kernels leverage the 6-bit prepacked weights, ensuring a seamless experience for users engaging with our platform. + +We assessed the LLaMA-70b model's serving performance using FP6 quantization on two A100 GPUs-80G, and observed a *1.5x* reduction in inference latency and a *3.5x* increase in inference throughput compared to the FP16 baseline. FP6 quantization offers two key benefits for model inference: it enables the deployment of large language models (LLMs) on fewer GPUs — for instance, LLaMA-70b fits on a single A100-80G GPU with FP6, versus at least two GPUs required for the FP16 baseline. Additionally, it significantly accelerates linear layers in memory-bound scenarios, which are common in LLM inference. Moreover, FP6 quantization reduces the GPU memory requirements for model weights, allowing for more queries to be served simultaneously, and thus increasing serving throughput. + +Our system demonstrates exceptional efficiency in handling long generation sequences. As illustrated in Figure 1, for generation lengths surpassing the prompt length, our system exhibits a notable performance superiority. The disparity in performance between FP6 and the FP16 baseline widens with the extension of the generation sequence length. This trend is primarily attributed to the inference process becoming increasingly memory-constrained as the decoding length expands, favoring our weight-quantized GPU kernels by facilitating faster compute compared to the FP16 baseline. It is important to highlight two factors contributing to the increased memory constraints in longer decoding scenarios. + - Firstly, the memory usage for the KV cache escalates with the sequence length, reducing the feasible batch sizes and leading to memory-bound GEMM operations. + - Secondly, within the context of DeepSpeed-FastGen's prefill-decoding-mixed-batch technique, scenarios involving extended token generation encounter a reduction in prefill-chunks available for mixing with decodings. This results in a higher frequency of batches dedicated solely to decodings, further intensifying the memory-bound conditions. + +

+ Caption1 + Caption2 + Caption3 +

+ + *Figure 1*: End-to-end serving performances in DeepSpeed-MII with 32 clients and total of 128 requests, for LLaMA-2-70B model on 2xA100-80g with two-way tensor parallelism. We experimented with different number of requests between 128, 256 and 512 and found that the speedup is simillar. + +Despite the significant benefits of FP6 quantization, the current implementation faces limitations. Notably, in scenarios where GEMM operations become compute-bound due to large batch sizes or sufficient GPU memory, our weight-only quantization kernel may not sustain its latency advantage, especially against optimized libraries like cuBlas. However, our system's memory efficiency remains a key benefit. Currently, support is limited to Non-Mixture of Experts (Non-MoE) structures, with efforts underway to extend support to MoE structures. Additionally, the system is compatible only with FP16 input models, as the FP6 kernel processes FP16 activations exclusively. + +
+ +# 4. How to begin with DeepSpeed-FP6 + +The quantization-and-inference experience of DeepSpeed-FP6 is straightforward and convenient. Here we give an example based on LLaMa-2-70B model: + +```python +import mii +pipe = mii.pipeline("NousResearch/Llama-2-70b-hf", quantization_mode='wf6af16') +response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128) +print(response) +``` + +You need to install the following: +``` +pip install deepspeed-mii +pip install qtorch +``` + +To benchmark with our DeepSpeed-FP6, please visit the following script: +```bash +https://github.com/deepspeedai/DeepSpeedExamples/blob/master/benchmarks/inference/mii/run_fp6.sh +``` + +Please also visit the [FP6-LLM github](https://github.com/usyd-fsalab/fp6_llm) for the standalone kernel of FP6. Don't forget to star the repo to show your support! + + +# 5. Software Improvements + + +Currently, DeepSpeed-FP6 supports only dense models with MoE models support upcoming. We will continue to improve DeepSpeed-FP6 with your feedback and support. DeepSpeed-FP6 is a component of the larger DeepSpeed ecosystem, which includes a range of Deep Learning systems and modeling technologies. To learn more, + +* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation. +* Follow us on our [English X(Twitter)](https://twitter.com/DeepSpeedAI), [Japanese X(Twitter)](https://twitter.com/DeepSpeedAI_JP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for latest news on DeepSpeed. + +We welcome your contributions to DeepSpeed! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) page. Please see our [contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to info@deepspeed.ai. + +* "Star" our [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) and [DeepSpeed-MII GitHub](https://github.com/deepspeedai/DeepSpeed-MII/) and [DeepSpeedExamples GitHub](https://github.com/deepspeedai/DeepSpeedExamples/) repositories if you like our work! + + +# 6. Acknowledgments and Contributions +We thank the collaboration of the University of Sydney and Rutgers University. We also thank the open-source library [aspuru-guzik-group/qtorch](https://github.com/aspuru-guzik-group/qtorch). + +Contributions: +Xiaoxia Wu\* $^1$, Zhen Zheng\* $^1$, Haojun Xia\* $^2$, Arash Bakhtiari $^1$, Michael Wyatt $^1$, Shiyang Chen $^3$, Stephen Youn $^1$, Reza Yazdani Aminabadi, Yuxiong He, Olatunji Ruwase $^1$, Zhewei Yao, Leon Song $^1$ $^2$ (project lead) + +\* Equal Contribution +1: Microsoft +2: University of Sydney +3: Rutgers University + +Reference: + +[1] ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks. arXiv. https://arxiv.org/abs/2312.08583 + +[2] FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design. arXiv. https://arxiv.org/abs/2401.14112 + +[3] FP6-LLM kernel release. GitHub. https://github.com/usyd-fsalab/fp6_llm diff --git a/blogs/deepspeed-fp6/03-05-2024/assets/fp6-design.png b/blogs/deepspeed-fp6/03-05-2024/assets/fp6-design.png new file mode 100644 index 000000000000..5024332a8f33 Binary files /dev/null and b/blogs/deepspeed-fp6/03-05-2024/assets/fp6-design.png differ diff --git a/blogs/deepspeed-fp6/03-05-2024/assets/hero-figure.png b/blogs/deepspeed-fp6/03-05-2024/assets/hero-figure.png new file mode 100644 index 000000000000..61a5061dc954 Binary files /dev/null and b/blogs/deepspeed-fp6/03-05-2024/assets/hero-figure.png differ diff --git a/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-1000.png b/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-1000.png new file mode 100644 index 000000000000..c1095ee0053b Binary files /dev/null and b/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-1000.png differ diff --git a/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-250.png b/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-250.png new file mode 100644 index 000000000000..aeeaab55466d Binary files /dev/null and b/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-250.png differ diff --git a/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-500.png b/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-500.png new file mode 100644 index 000000000000..eb3c1ac12a7b Binary files /dev/null and b/blogs/deepspeed-fp6/03-05-2024/assets/servingllm/100-500.png differ diff --git a/blogs/deepspeed-offloadpp/README.md b/blogs/deepspeed-offloadpp/README.md new file mode 100644 index 000000000000..f58173b7bc8b --- /dev/null +++ b/blogs/deepspeed-offloadpp/README.md @@ -0,0 +1,52 @@ +# DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow + +Deep learning has been successfully adopted in a wide range of applications such as speech recognition, chatbot, text and image generation, etc. To achieve better model serving accuracy, model size grows significantly. Take language models as example, from BERT with 110 million parameters to Megatron-Turing NLG with 530 billion parameters, the model size grows almost 5000x. Given limited GPU memory size, we need to efficiently utilize GPU memory to achieve good system throughput. + +ZeRO offers memory efficient data parallel training scheme. For training large models like LLMs using ZeRO, GPU memory size is still often insufficient to hold all the model parameters. Thus, ZeRO-Offload is introduced to solve this insufficient GPU memory issue. ZeRO-Offload releases GPU memory pressure by offloading data and compute to the CPU side while minimizing CPU-GPU data copy overhead. Given CPU memory is often orders-of-magnitude larger than GPU memory, ZeRO-Offload was the first piece of work that enables billion-level parameter training even with very limited GPU memory resources (e.g., to an extreme: single GPU). ZeRO-Offload provides excellent performance when model size is multiple times larger than total GPU memory size. + +However, system efficiency is still far from optimal when adopting ZeRO-Offload in some scenarios. Especially in the cases like small batch training, model that could not fit into GPU memory but not orders-of-magnitude bigger than GPU memory capacity, CPU offload not only introduce long end-to-end latency, but also underutilized GPU computation resources. To reduce memory copy latency as well as inefficient utilization of GPU introduced in these offload cases, we propose ZeRO-Offload++, which leverages both CPU and GPU coherently. ZeRO-Offload++ mainly includes 3 new features as _Twin-Flow_, MemCpy reduction, CPUAdam optimization. Now we release our __Twin-Flow__ feature. + +The key benefits are: +* With _Twin-Flow_, ZeRO-Offload++ achieves up to **6x** training speedup compared with ZeRO-Offload. +* High-level API provided in DeepSpeed config JSON makes it easy to use and fine-tune. + +![h100-img](./images/h100-8.png) + +## Twin-Flow + +In DeepSpeed, when training using popular optimizer like Adam, optimizer offloading follows an all-or-nothing policy. For simplifed example shown as Figure below, without offloading, all the parameters will be updated using GPU adam as FusedAdam optimizer. On the other hand, if offloading is enabled, all model weights use CPUAdam to update. + +![cpu-offload-img](./images/cpu-offload.png) + +The major downside of this all-or-nothing offloading is, when offload all optimizer states to CPU side, both GPU memory and compute resources remain under-utilized. Although increasing batch size improves GPU utilization rate, each training iteration time is still super long compared with no-offloading case. To improve GPU compute and memory utilization rate as well as decrease training iteration time, we introduce a new feature in our DeepSpeed training engine called _Twin-Flow_. + +In comparison, _Twin-Flow_ allows a portion of optimizer states to be held in CPU memory and the other portion of optimizer states remaining in GPU memory. When optimization step is triggered, both CPU and GPU can do parameter updates simultaneously. Once offloading is enabled, we provide an offload ratio configuration which allows users to adjust how many percentages of model weights are updated on CPU side and the rest are happened on GPU side. "_Twin_" comes from the idea that both CPU and GPU are using the same optimizer function here. "_Flow_" means parameters are not only hold in both host and device memory, but also computed using both CPU and GPU cores. + +As shown in Figure below, with ZeRO-Offload enabled and we set _Twin-Flow_ ratio of 0.4 (40%). DeepSpeed Training engine will automatically assign first 40% (i.e. 0-40%) of weights step procedure on the CPU side using CPUAdam, and use GPU side FusedAdam to update the rest 60% (i.e., 40-100%) model parameters jointly. Therefore, with _Twin-Flow_, we can achieve decent GPU memory and core utilization rate, at the same time reduce training iteation time in optimizer offloading cases. + +![_Twin-Flow_-img](./images/twin-offload.png) + +Note that this _Twin-Flow_ ratio can be adjusted based on how much GPU idle memory is available. The smaller this ratio is, the more GPU memory and cores are used and the shorter training iteration time it achieves. The ideal case is to be as near as GPU memory upper bound in order to minimize training iteration time. +Note that _Twin-Flow_ is not limited to Adam optimizer only, it can be applied to any optimizer (e.g., AdaGrad) from the user side. + +## Performance Evaluation + +We conduct our performance evaluations over both A100 and H100 DGX machine and test for OPT model with 13B and 30B parameters. We run 13B OPT model training on a 8 A100 DGX machine, and run OPT-30B model training using a 8 H100 DGX machine. With some tuning on offload ratio in ZeRO-Offload++, we achieve 6x and 3x training speedup of Meta OPT models on single DGX-H100-80GB and DGX-A100-40GB, respectively (top-most figure and bottom figure here). + +![a100-img](./images/a100-8.png) + +## On-going Optimizations + +* Reduce uncessary D2H/H2D memcpy + +* On-the-fly fp16 to fp32 casting for CPUAdam + +## Tutorials + +Examples and Tutorials are [here](https://github.com/deepspeedai/Megatron-DeepSpeed/blob/main/examples_deepspeed/offload_pp/README.md) + +## Contributors: + +This project was made possible by the contributions of the following people from DeepSpeed Team: + +[Guanhua Wang](https://www.microsoft.com/en-us/research/people/guanhuawang/), Masahiro Tanaka, Xiaoxia Wu, Lok Chand Koppaka, Samyam Rajbhandari, [Olatunji Ruwase](https://www.microsoft.com/en-us/research/people/olruwase/), [Yuxiong He](https://www.microsoft.com/en-us/research/people/yuxhe/) (team lead) diff --git a/blogs/deepspeed-offloadpp/images/a100-8.png b/blogs/deepspeed-offloadpp/images/a100-8.png new file mode 100644 index 000000000000..22b787f69e1e Binary files /dev/null and b/blogs/deepspeed-offloadpp/images/a100-8.png differ diff --git a/blogs/deepspeed-offloadpp/images/cpu-offload.png b/blogs/deepspeed-offloadpp/images/cpu-offload.png new file mode 100644 index 000000000000..cc4dae505cd3 Binary files /dev/null and b/blogs/deepspeed-offloadpp/images/cpu-offload.png differ diff --git a/blogs/deepspeed-offloadpp/images/h100-8.png b/blogs/deepspeed-offloadpp/images/h100-8.png new file mode 100644 index 000000000000..938625d52aaf Binary files /dev/null and b/blogs/deepspeed-offloadpp/images/h100-8.png differ diff --git a/blogs/deepspeed-offloadpp/images/twin-offload.png b/blogs/deepspeed-offloadpp/images/twin-offload.png new file mode 100644 index 000000000000..1c8c3ef92454 Binary files /dev/null and b/blogs/deepspeed-offloadpp/images/twin-offload.png differ diff --git a/blogs/deepspeed-superoffload/README.md b/blogs/deepspeed-superoffload/README.md new file mode 100644 index 000000000000..e9a61b4f9dff --- /dev/null +++ b/blogs/deepspeed-superoffload/README.md @@ -0,0 +1,211 @@ +# SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips + +**Efficient full-parameter fine-tuning of GPT-OSS-20B & Qwen3-14B models on a single NVIDIA GH200 Superchip and Llama3-70B on four NVIDIA GH200 Superchips, while delivering up to 600 TFLOPS training throughput** + +**Authors** +[Xinyu Lian](https://xinyulian.tech/)1, [Masahiro Tanaka](https://tohtana.github.io/)2, [Olatunji Ruwase](https://www.snowflake.com/en/blog/authors/olatunji--tunji--ruwase/)3, [Minjia Zhang](https://minjiazhang.github.io/)1 + +1SSAIL Lab, University of Illinois Urbana-Champaign · 2Anyscale · 3Snowflake + +--- + +## Table of Content + +- [SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips](#superoffload-unleashing-the-power-of-large-scale-llm-training-on-superchips) + - [SuperOffload Highlights](#superoffload-highlights) + - [Introduction](#introduction) + - [How SuperOffload Works](#how-superoffload-works) + - [1. Speculation-then-Validation (STV)](#1-speculation-then-validation-stv) + - [2. Heterogeneous Optimizer Computation](#2-heterogeneous-optimizer-computation) + - [3. Superchip-Aware Casting](#3-superchip-aware-casting) + - [4. GraceAdam for Optimizer Efficiency](#4-graceadam-for-optimizer-efficiency) + - [Experience and Insights](#experience-and-insights) + - [Getting Started](#getting-started) + - [Acknowledgements](#acknowledgements) + +--- + +## SuperOffload Highlights + +- **Single GH200:** Full fine-tuning of GPT-OSS-20B, Qwen3-14B, achieving up to 600 TFLOPS (seq len 4K, batch size 4). +- **Multi-GPU:** Qwen3-30B-A3B & Seed-OSS-36B on 2× NVIDIA GH200; Llama-70B on 4× NVIDIA GH200. +- **Faster Training:** Up to 4× higher throughput compared to prior work such as ZeRO-Offload under modest settings. +- **Increased GPU Utilization:** Boost GPU utilization from ~50% to >80%. +- **Engineering & Composability:** Works with ZeRO-3 and Ulysses; operational tips (e.g., NUMA binding, MPAM) are documented in the tutorial. + +--- + +## Introduction + +The emergence of tightly coupled heterogeneous GPU/CPU architectures (a.k.a., Superchips), such as NVIDIA GH200, GB200, and AMD MI300A, offers new optimization opportunities for large-scale AI. Yet it remains under-explored in terms of how to make the best use of these new hardware for large-scale LLM training. Existing offloading solutions were designed for traditional loosely coupled architectures, and are suboptimal on Superchips suffering high overheads and low GPU utilization. To address this gap and to make the best use of Superchips for efficient LLM training, we have developed and open-sourced **SuperOffload**. + +SuperOffload introduces a set of novel techniques that make the best use of Hopper GPU, Grace CPU, and NVLink-C2C, simultaneously for LLM training. Unlike prior offloading solutions which assume slow GPU-CPU interconnects (e.g., 64GB/sec for PCIe-Gen4), SuperOffload exploits the much faster interconnects (e.g., 900GB/sec for NVLink-C2C) to boost GPU and CPU utilization, and training throughput. With SuperOffload, models such as **GPT-OSS-20B**, **Qwen3-14B**, and **Phi-4** can be fully fine-tuned on a single GH200, delivering up to **600 TFLOPS** training throughput under modest settings (sequence length 4k, batch size 4). This delivers up to **4×** higher throughput compared to prior work such as ZeRO-Offload. SuperOffload enables scaling to even larger models, including Qwen3-30B-A3B and Seed-OSS-36B on two GH200s and Llama-70B on four GH200s. + +SuperOffload is built on top of DeepSpeed ZeRO Stage 3, and is available in DeepSpeed versions >= [0.18.0](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.18.0). To enable easy integration into LLM finetuning pipelines, SuperOffload is compatible with Hugging Face Transformers and does not require any changes to modeling code. + + + + + +
+SuperOffload system overview +

Figure 1: SuperOffload delivers up to 4× higher throughput than ZeRO-Offload for large-model fine-tuning across varying sequence lengths and batch sizes, achieving up to 600 TFLOPS throughput.

+
+ +--- + +## How SuperOffload Works + + + +SuperOffload consists of four composable offloading optimization techniques: (1) Speculation-then-Validation, (2) GPU/CPU Optimizer Computation, (3) Superchip-Aware Casting, and (4) GraceAdam. We provide brief descriptions of these techniques below. + + +### 1. Speculation-then-Validation (STV) + +In most offloading solutions, synchronizations between CPU and GPU are needed in the optimizer step to ensure numerical robustness. For example, clipping the gradient norm requires calculating the global gradient norm, and mixed precision training requires a global check of NaN and INF values. These operations require the CPU to wait until all gradients have been received before the optimizer step and weight updates. STV avoids this bottleneck by breaking this dependency but still preserves the semantics of training by overlapping speculative optimizer computation on CPU with backward propagation on GPU. When gradient post-processing eventually completes, the speculative optimizer computations are either committed, discarded, or correctly replayed as appropriate. STV's post-validation of training stability enables it to safely reduce the critical path compared to prior pre-validation approaches. The figure below illustrates how SuperOffload schedules backward propagation and optimizer computation differently from traditional approaches, such as ZeRO-Offload. + +
+Schedule comparison +

Figure 2: Previous offloading approach suffers from global gradient norm and global check of NAN and INF values, which expose the optimizer step to the critical path and prevent overlapping opportunities. In SuperOffload, we introduce a speculation-then-validation schedule to address this issue.

+
+ +We evaluated the effectiveness of STV by measuring the frequency of undoing speculative optimizer computations in a pre-training run of a BLOOM-176B model. As shown in the figure below, such rollbacks (e.g., due to gradient clipping, etc.) are rare after warmup, making the associated overheads negligible over the entire training run. This makes STV practical for accelerating large-scale training. + +
+Gradient clipping data +

Figure 3: Red points indicate gradient clipping triggered during BLOOM pre-training — rare after warm-up, indicating that SuperOffload's STV mechanism effectively eliminates stalls caused by gradient clipping and NaN/INF check-induced synchronizations.

+
+ +--- + +### 2. Heterogeneous Optimizer Computation + +SuperOffload improves optimizer efficiency beyond STV by partitioning optimizer computation across GPU and CPU. The GPU is used for optimizer computations of gradients created in the latter stages of the backward pass, while the CPU handles the rest. This partitioning scheme has multiple benefits. First, the GPU avoids idly waiting for optimizer computation to complete on the CPU. Second, optimizer computation is reduced by leveraging both GPU and CPU compute. Third, GPU-CPU transfers of parameters and gradients corresponding to GPU optimizer computations can be avoided. + + + +--- + +### 3. Superchip-Aware Casting + +In mixed precision training with offloading, tensor transfers between GPU and CPU require casting between the low-precision format on GPU (e.g., BF16, FP16, etc.) and the high-precision format on CPU (i.e., FP32). To address the bandwidth limitations of PCIe interconnects, prior offloading solutions transfer tensors in low-precision and type cast tensors on both GPU and CPU as appropriate. However, this is a suboptimal strategy on Superchip architectures because GPU compute throughput is ~100X higher than CPU, and high-bandwidth interconnects (e.g., NVLink-C2C) makes the transfer costs negligible. As an illustration, Figure 4 below shows that the optimal strategy on GH200 is tensor casting on the GPU and transferring in high-precision format. + + + +
+Tensor casting optimization +

Figure 4: GH200: Tensor casting to lower/higher precision on GPU and transferring in higher-precision is more efficient on Superchips.

+
+ +--- + +### 4. GraceAdam for Optimizer Efficiency + +Existing offloading solutions for LLM training require CPU implementations of the popular Adam optimizer, such as PyTorch Adam and DeepSpeed CPU-Adam. However, these are inadequate for Superchips because they are not optimized for the Grace CPU architecture. To address this issue, we created GraceAdam, a highly efficient Adam optimizer implementation for Grace CPUs. GraceAdam achieves high performance exploiting the underlying ARM architecture features such as Scalable Vector Extension (SVE), explicit memory hierarchy management, and instruction-level parallelism. Figure 5 below shows that on GH200 Superchip, GraceAdam is 3× faster than PyTorch Adam (PT-CPU) and 1.3× faster than CPU-Adam. To our knowledge, GraceAdam is the first open sourced Adam optimizer implementation for Grace CPU. + +
+GraceAdam +

Figure 5: Using GraceAdam for efficient Adam optimizer computation on GH200.

+
+ + +## Experience and Insights + +- **NUMA Binding:** + Pair each GPU with its directly associated CPU to maximize bandwidth. In DeepSpeed: + ```bash + --bind_cores_to_rank + ``` + +- **MPAM (Memory System Resource Partitioning and Monitoring):** + Reduces interference between CPU and GPU tasks. + + **How to enable MPAM on NVIDIA Superchips:** + 1. Install the kernel from [NVIDIA NV-Kernels](https://github.com/NVIDIA/NV-Kernels/tree/24.04_linux-nvidia-adv-6.11). + 2. Check MPAM support: + ```bash + grep MPAM /boot/config-$(uname -r) + ``` + Expected output: + ``` + CONFIG_ARM64_MPAM=y + CONFIG_ACPI_MPAM=y + CONFIG_ARM64_MPAM_DRIVER=y + CONFIG_ARM64_MPAM_RESCTRL_FS=y + ``` + Verify resctrl filesystem: + ```bash + ls -ld /sys/fs/resctrl + ``` + 3. Mount resctrl: + ```bash + mount -t resctrl resctrl /sys/fs/resctrl + ``` + 4. Create partitions: + ```bash + mkdir /sys/fs/resctrl/p1 /sys/fs/resctrl/p2 + ``` + 5. Set CPU cores & memory configs (example from experiments): + ```bash + /sys/fs/resctrl/p1/cpus_list: + 0-6 + /sys/fs/resctrl/p2/cpus_list: + 7-71 + /sys/fs/resctrl/p1/schemata: + MB:1=100 + L3:1=ff0 + /sys/fs/resctrl/p2/schemata: + MB:1=20 + L3:1=f + ``` + +--- + +## Getting Started + +End-to-end finetuning examples using SuperOffload are available in our tutorial/readme: [DeepSpeedExamples: SuperOffload](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-SuperOffload#readme). To enable SuperOffload quickly, add the following switch to your DeepSpeed config (see tutorial for full context): + +
+Enable SuperOffload +

Figure 6: Enable SuperOffload with a single line in the DeepSpeed config.

+
+ + + +Tip: On Superchip platforms (e.g., GH200/GB200/MI300A), combine NUMA binding and MPAM settings from "Experience and Insights" to stabilize bandwidth and improve end-to-end performance. + + + +--- + +## Acknowledgements + +This work is a close collaboration among [University of Illinois Urbana-Champaign (UIUC)](https://supercomputing-system-ai-lab.github.io/), [Anyscale](https://www.anyscale.com/), and [Snowflake](https://www.snowflake.com/en/blog/authors/snowflake-ai-research/). + +We also gratefully acknowledge William Gropp, Brett Bode, and Gregory H. Bauer from the National Center for Supercomputing Applications (NCSA), as well as Dan Ernst, Ian Karlin, Giridhar Chukkapalli, Kurt Rago, and others from NVIDIA for their valuable discussions and guidance on MPAM support on Grace CPU. + +Community feedback and contributions are welcome. For enablement and examples, see "Getting Started" above. + +--- + +## BibTeX + +```bibtex +@inproceedings{superoffload, + author = {Xinyu Lian and Masahiro Tanaka and Olatunji Ruwase and Minjia Zhang}, + title = "{SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips}", + year = {2026}, + booktitle = {Proceedings of the 31st ACM International Conference on Architectural Support for Programming Languages and Operating System (ASPLOS'26)} +} +``` diff --git a/blogs/deepspeed-superoffload/README_cn.md b/blogs/deepspeed-superoffload/README_cn.md new file mode 100644 index 000000000000..173e4c57be44 --- /dev/null +++ b/blogs/deepspeed-superoffload/README_cn.md @@ -0,0 +1,188 @@ +# SuperOffload: 释放超级芯片上大规模LLM训练的潜力 + +**在单个英伟达GH200超级芯片上高效完成GPT-OSS-20B和Qwen3-14B模型的全参数微调,并在四块英伟达GH200超级芯片上实现Llama3-70B模型的训练,同时提供高达600TFLOPS的训练吞吐量。** + +**作者** +[Xinyu Lian](https://xinyulian.tech/)1, [Masahiro Tanaka](https://tohtana.github.io/)2, [Olatunji Ruwase](https://www.snowflake.com/en/blog/authors/olatunji--tunji--ruwase/)3, [Minjia Zhang](https://minjiazhang.github.io/)1 + +1SSAIL Lab, University of Illinois Urbana-Champaign · 2Anyscale · 3Snowflake + +--- + +## 目录 + +- [SuperOffload:释放超级芯片上大规模LLM训练的潜力](#superoffload释放超级芯片上大规模llm训练的潜力) + - [SuperOffload的亮点](#superoffload的亮点) + - [介绍](#介绍) + - [SuperOffload的工作原理](#superoffload的工作原理) + - [1. 推测验证机制(STV)](#1-推测验证机制stv) + - [2. 异构优化器计算](#2-异构优化器计算) + - [3. 超级芯片感知的类型转换](#3-超级芯片感知的类型转换) + - [4. GraceAdam:提升优化器效率](#4-graceadam提升优化器效率) + - [经验与洞察](#经验与洞察) + - [快速使用指南](#快速使用指南) + - [致谢](#致谢) + +--- + +## SuperOffload的亮点 + +- 在**一块GH200**上能够对GPT-OSS-20B和Qwen3-14B进行全参数微调,达到600TFLOPS的运算速度(Seqlen=4K,BS=4)。 +- **多卡训练**:在两块英伟达GH200上训练Qwen3-30B-A3B和Seed-OSS-36B,在四块英伟达GH200上训练Llama-70B。 +- **训练速度**:在合理的设置下,比ZeRO-Offload快四倍的训练吞吐量。 +- **提高显卡利用率**:将显卡利用率从约50%提高到大于80%。 +- **灵活组合性**:支持ZeRO-3和Ulysses;一些操作技巧如NUMA绑定和MPAM等已在教程中详细说明。 + +--- + +## 介绍 + +紧密耦合的异构GPU/CPU架构(又称超级芯片)的出现,例如NVIDIA GH200、GB200和AMD MI300A,为大规模AI提供了新的优化机遇。然而,如何充分利用这些新硬件进行大规模LLM训练仍处于探索不足的状态。现有的offloading解决方案是为传统松散耦合架构设计的,在超级芯片上表现欠佳,存在高开销和低GPU利用率的问题。为弥补这一空白并充分利用超级芯片实现高效LLM训练,我们开发并开源了**SuperOffload**。 + +SuperOffload引入了一系列创新技术,可同时充分利用Hopper GPU、Grace CPU和NVLink-C2C进行LLM训练。与先前假设GPU-CPU互连速度较慢(如PCIe-Gen4的64GB/秒)的offloading解决方案不同,SuperOffload利用更高速的互连技术(如NVLink-C2C的900GB/秒)来提升GPU和CPU利用率及训练吞吐量。借助SuperOffload,诸如**GPT-OSS-20B**、**Qwen3-14B**和**Phi-4**等模型可在单台GH200上完成全参数微调,在常规设置下(序列长度4k,批次大小4)实现高达**600 TFLOPS**的训练吞吐量。与ZeRO-Offload等先前工作相比,此举可实现高达**4倍**的吞吐量提升。SuperOffload还能支持扩展至更大模型,包括在两台GH200上运行Qwen3-30B-A3B和Seed-OSS-36B,以及在四台GH200上运行Llama-70B。 + +SuperOffload构建于DeepSpeed ZeRO Stage 3之上,并在DeepSpeed [0.18.0]((https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.18.0)及以上版本中提供。为便于集成到LLM微调流程中,SuperOffload与Hugging Face Transformers兼容,且无需对模型代码进行任何修改。 + +
+SuperOffload system overview +

图1:在不同序列长度和批次大小的大型模型微调中,SuperOffload相比ZeRO-Offload可实现高达4倍的吞吐量提升,最高达到600 TFLOPS的吞吐量。

+
+ +--- + +## SuperOffload的工作原理 + +SuperOffload包含四项可组合的offloading优化技术:(1) 推测验证机制,(2) GPU/CPU优化器计算,(3) 超级芯片感知的类型转换,以及(4) GraceAdam优化器。以下我们将简要介绍这些技术。 + + +### 1. 推测验证机制(STV) + +在大多数offloading解决方案中,优化器步骤需要CPU和GPU之间的同步以确保数值鲁棒性。例如,梯度norm裁剪需要计算全局梯度norm,混合精度训练需要全局检查NaN和INF值。这些操作要求CPU等待直到收到所有梯度后才能执行优化器步骤和权重更新。STV通过打破这种依赖性来避免此瓶颈,同时通过将CPU上的推测性优化器计算与GPU上的反向传播重叠来保持训练语义。当梯度后处理最终完成时,推测性优化器计算会根据情况被提交、丢弃或正确重放。STV对训练稳定性的后验证使其能够相比先前的前验证方法安全地缩短关键路径。下图展示了SuperOffload如何以不同于传统方法(如ZeRO-Offload)的方式调度反向传播和优化器计算。 + +
+Schedule comparison +

图2:以往的offloading方法受限于全局梯度范数计算及全局NaN/INF值检查,导致优化器步骤暴露在关键路径中且无法实现计算重叠。SuperOffload通过引入推测验证调度机制来解决这一问题。

+
+ +我们通过测量BLOOM-176B模型预训练过程中推测性优化器计算被撤销的频率来评估STV的有效性。如下图所示,这类回滚(例如由于梯度裁剪等原因引起)在预热阶段后很少发生,使得相关开销在整个训练过程中可忽略不计。这使得STV在加速大规模训练方面具有实用性。 + +
+Gradient clipping data +

图3:红色数据点表示BLOOM预训练过程中触发梯度裁剪的时刻——在预热阶段后极少出现,这表明SuperOffload的STV机制有效消除了由梯度裁剪和NaN/INF检查引起的同步停顿。 +

+
+ +--- + +### 2. 异构优化器计算 + +SuperOffload通过将优化器计算分区到GPU和CPU上来提升STV之外的优化器效率。GPU用于处理反向传播后期阶段产生的梯度对应的优化器计算,而CPU则负责其余部分。这种分区方案具有多重优势:首先,GPU无需闲置等待CPU完成优化器计算;其次,通过同时利用GPU和CPU的计算资源减少了优化器计算时间;第三,避免了与GPU优化器计算对应的参数和梯度在GPU-CPU间的传输。 + +--- + +### 3. 超级芯片感知的类型转换 + +在采用offloading的混合精度训练中,GPU与CPU之间的张量传输需要在GPU低精度格式(如BF16、FP16等)与CPU高精度格式(即FP32)间进行类型转换。为应对PCIe互连的带宽限制,先前的offloading解决方案采用低精度传输张量,并在GPU和CPU上适时进行类型转换。然而这在超级芯片架构中并非最优策略,因为GPU计算吞吐量约为CPU的100倍,而高带宽互连(如NVLink-C2C)使得传输成本可忽略不计。如图4所示,GH200上的最优策略是在GPU上进行张量类型转换并采用高精度格式传输。 + +
+Tensor casting optimization +

图4:GH200:在超级芯片上,通过GPU进行张量高低精度转换并以高精度格式传输更为高效。

+
+ +--- + +### 4. GraceAdam:提升优化器效率 + +现有用于LLM训练的offloading解决方案需要流行Adam优化器(如PyTorch Adam和DeepSpeed CPU-Adam)的CPU实现版本。然而这些实现并不适用于超级芯片,因为它们未针对Grace CPU架构进行优化。为解决此问题,我们创建了GraceAdam——专为Grace CPU设计的高效Adam优化器实现。GraceAdam通过利用底层ARM架构特性(如可扩展向量扩展SVE、显式内存层次管理和指令级并行)实现高性能。图5显示在GH200超级芯片上,GraceAdam比PyTorch Adam快3倍,比CPU-Adam快1.3倍。据我们所知,GraceAdam是首个面向Grace CPU开源的Adam优化器实现。 + +
+GraceAdam +

图5:使用GraceAdam在GH200上实现高效Adam优化器计算。

+
+ + +## 经验与洞察 + +- **NUMA绑定:** + 将每个GPU与其直接关联的CPU进行配对以最大化带宽。在DeepSpeed中: + ```bash + --bind_cores_to_rank + ``` + +- **MPAM(内存系统资源分区与监控):** + 减少CPU与GPU任务间的相互干扰。 + + **如何在NVIDIA超级芯片上启用MPAM** + 1. 安装[NVIDIA NV-Kernels](https://github.com/NVIDIA/NV-Kernels/tree/24.04_linux-nvidia-adv-6.11)提供的内核。 + 2. 检查MPAM支持情况: + ```bash + grep MPAM /boot/config-$(uname -r) + ``` + 预期输出: + ``` + CONFIG_ARM64_MPAM=y + CONFIG_ACPI_MPAM=y + CONFIG_ARM64_MPAM_DRIVER=y + CONFIG_ARM64_MPAM_RESCTRL_FS=y + ``` + 检查resctrl文件系统: + ```bash + ls -ld /sys/fs/resctrl + ``` + 3. 挂载resctrl: + ```bash + mount -t resctrl resctrl /sys/fs/resctrl + ``` + 4. 建立分区: + ```bash + mkdir /sys/fs/resctrl/p1 /sys/fs/resctrl/p2 + ``` + 5. 设定CPU内核与内存配置: + ```bash + /sys/fs/resctrl/p1/cpus_list: + 0-6 + /sys/fs/resctrl/p2/cpus_list: + 7-71 + /sys/fs/resctrl/p1/schemata: + MB:1=100 + L3:1=ff0 + /sys/fs/resctrl/p2/schemata: + MB:1=20 + L3:1=f + ``` + +--- + +## 快速使用指南 + +我们已在教程/说明文档[DeepSpeedExamples: SuperOffload](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-SuperOffload#readme)中提供了SuperOffload的端到端微调示例。请在DeepSpeed配置中添加以下开关(完整上下文请参阅教程): + +
+Enable SuperOffload +

图6:通过在DeepSpeed配置中添加单行代码即可启用SuperOffload。

+
+ +提示:在超级芯片平台(如GH200/GB200/MI300A)上,结合"经验与洞察"章节中的NUMA绑定与MPAM设置,可稳定带宽并提升端到端性能。 + +--- + +## 致谢 + +本成果由[University of Illinois Urbana-Champaign (UIUC)](https://supercomputing-system-ai-lab.github.io/), [Anyscale](https://www.anyscale.com/)与[Snowflake](https://www.snowflake.com/en/blog/authors/snowflake-ai-research/)紧密协作完成。 + +我们同时衷心感谢美国国家超级计算应用中心的William Gropp、Brett Bode和Gregory H. Bauer,以及NVIDIA的Dan Ernst、Ian Karlin、Giridhar Chukkapalli、Kurt Rago等专家就Grace CPU的MPAM支持提供的宝贵讨论与指导。 + +欢迎社区反馈与贡献。具体启用方法与示例请参阅前文「快速开始」章节。 + +--- + +## BibTeX + +```bibtex +@inproceedings{superoffload, + author = {Xinyu Lian and Masahiro Tanaka and Olatunji Ruwase and Minjia Zhang}, + title = "{SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips}", + year = {2026}, + booktitle = {Proceedings of the 31st ACM International Conference on Architectural Support for Programming Languages and Operating System (ASPLOS'26)} +} +``` diff --git a/blogs/deepspeed-superoffload/images/superoffload_cast_transfer.jpg b/blogs/deepspeed-superoffload/images/superoffload_cast_transfer.jpg new file mode 100644 index 000000000000..08c0dca59c59 Binary files /dev/null and b/blogs/deepspeed-superoffload/images/superoffload_cast_transfer.jpg differ diff --git a/blogs/deepspeed-superoffload/images/superoffload_comparison.jpg b/blogs/deepspeed-superoffload/images/superoffload_comparison.jpg new file mode 100644 index 000000000000..15ab8e03915b Binary files /dev/null and b/blogs/deepspeed-superoffload/images/superoffload_comparison.jpg differ diff --git a/blogs/deepspeed-superoffload/images/superoffload_enable.jpg b/blogs/deepspeed-superoffload/images/superoffload_enable.jpg new file mode 100644 index 000000000000..471f00c9ccd8 Binary files /dev/null and b/blogs/deepspeed-superoffload/images/superoffload_enable.jpg differ diff --git a/blogs/deepspeed-superoffload/images/superoffload_grace_adam.png b/blogs/deepspeed-superoffload/images/superoffload_grace_adam.png new file mode 100644 index 000000000000..b0f7f3ebf3a0 Binary files /dev/null and b/blogs/deepspeed-superoffload/images/superoffload_grace_adam.png differ diff --git a/blogs/deepspeed-superoffload/images/superoffload_rollback.jpg b/blogs/deepspeed-superoffload/images/superoffload_rollback.jpg new file mode 100644 index 000000000000..861caa91f77e Binary files /dev/null and b/blogs/deepspeed-superoffload/images/superoffload_rollback.jpg differ diff --git a/blogs/deepspeed-superoffload/images/superoffload_schedule.jpg b/blogs/deepspeed-superoffload/images/superoffload_schedule.jpg new file mode 100644 index 000000000000..93341929ecc1 Binary files /dev/null and b/blogs/deepspeed-superoffload/images/superoffload_schedule.jpg differ diff --git a/blogs/deepspeed-triton/README.md b/blogs/deepspeed-triton/README.md new file mode 100644 index 000000000000..57922c5e1a23 --- /dev/null +++ b/blogs/deepspeed-triton/README.md @@ -0,0 +1,95 @@ +# DeepSpeed with Triton compiler + +# 1. Overview + +We have integrated [Triton](https://github.com/openai/triton), an open source compiler for GPU programming, into DeepSpeed, which further boosts the inference speed of BERT-like models in float16 precision. +By replacing some CUDA kernels or torch operators with Triton kernels, we achieved 1.14\~1.68x speedup (or 12\~41% latency reduction) for different models and GPUs, as shown in Table 1. + +
+ +| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large | +|----------|:------:|:------:|:------:|:------:| +| A100 |1.65x | 1.68x | 1.53x | 1.61x | +| V100 | 1.29x | 1.14x | 1.23x | 1.21x | + +Table 1. The average speedup (see NOTE below for more detail) + + +
+ +For those transformer operators in float16, we have implemented kernels written in Triton language that replace ordinary CUDA kernels or torch operators. +The Triton kernels we implemented include softmax, layer-normalization, residual-addition and all the matrix multiplications except MLP layers (see NOTE below for details). +In our experiments, Triton kernels help to reduce the average latency (over difference sequence lengths) by 6\~24% (depending on model and hardware) when compared to the latency with CUDA-only kernels. + + +Figures below show the latency reduction in more detail. +Figure 1 visualizes latency reduction in different sequence lengths in A100 GPU for Bert-base model. +The baseline (blue) is from Huggingface transformers without any kernel injection, the orange is from Deepspeed with CUDA-only kernels and the gray is from Deepspeed with Triton kernels. +Figure 2 shows the same plot for Bert-large model in A100 GPU. + +
+ +triton-bert-base-latency + +*Figure 1: Normalized P90 latency for Bert-base model in A100 GPU across different sequence lengths* + +triton-bert-large-latency + +*Figure 2: Normalized P90 latency for Bert-large model in A100 GPU across different sequence lengths* + +
+ + +Next, we dive deeper into this new feature in DeepSpeed. + +# 2. How to use Triton in Deepspeed + +You can enable Triton compilers to optimize these kernels by setting a flag in the DeepSpeed config file. + +``` +pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0) +pipe.model = deepspeed.init_inference(pipe.model, + dtype=torch.float16, + replace_with_kernel_inject=True, + enable_cuda_graph=True, + use_triton=True, + triton_autotune=True, + max_out_tokens=pipe.tokenizer.model_max_length) +``` + + +## Running BERT inference with Triton kernels + +We use an example of Bert-base here. + +```python +pip install deepspeed[triton] + +git clone https://github.com/deepspeedai/DeepSpeedExamples.git +cd DeepSpeedExamples/inference/huggingface/fill-mask + +deepspeed --num_gpus 1 test-bert.py --triton +``` + +To run a performance benchmark, you can use the following command: + +```python +pip install deepspeed[triton] + +git clone https://github.com/deepspeedai/DeepSpeedExamples.git +cd DeepSpeedExamples/benchmarks/inference + +deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton +``` + +# NOTE + +* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/deepspeedai/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation. + +* This feature is currently only supported for BERT, Roberta and other BERT-like models, and not for text-generation models yet. + +* To achieve the best performance with Triton optimization, you need to activate CUDA graph and ‘triton_autotune’ in the DeepSpeed config. CUDA graph prevents the overhead of JIT compilation and a deep call stack in Triton. ‘triton_autotune’ executes an initial step to find the most suitable parameters for Triton kernels, which may take some time. + +* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments. + +* In our experiments, we used a batch size of 1, a sequence length range of 8 to 512, and a ‘fill-mask’ task. Table 1 shows the average P90 latency over the entire sequence length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is the Huggingface transformers without any optimization. The speedup is calculated as (baseline P90 latency)/(DeepSpeed-Triton P90 Latency). We found that the CUDA kernel in MLP performed better than the Triton kernel in our experiments, so we used a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config. diff --git a/blogs/deepspeed-ucp/README.md b/blogs/deepspeed-ucp/README.md new file mode 100644 index 000000000000..3420e72c238e --- /dev/null +++ b/blogs/deepspeed-ucp/README.md @@ -0,0 +1,273 @@ +
+ +# DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training + +
+ + + +To cite DeepSpeed Universal Checkpoint, please cite our [arxiv report](https://arxiv.org/abs/2406.18820): + +``` +@article{lian2024-ucp, +title={Universal Checkpointing: Efficient and Flexible Checkpointing for +Large Scale Distributed Training}, +author={Xinyu Lian and Sam Ade Jacobs and Lev Kurilenko and Masahiro Tanaka +and Stas Bekman and Olatunji Ruwase and Minjia Zhang}, +journal={arxiv preprint arxiv:406.18820}, +year={2024}, + +} +``` + +# Introduction + +Checkpointing is a crucial technique for reducing the cost of training +machine learning models, as it enables saving the model state during the process. +This way, if the system fails, the training can resume from the most recent checkpoint +instead of from the beginning. Additionally, checkpointing allows for +evaluating the model performance at various stages of training, which +facilitates hyperparameter tuning and finetuning for different and +varied downstream tasks. + +However, there are challenges in the design, implementation and usage of +checkpointing especially in distributed training and finetuning +scenarios. Parallel training methods such as ZeRO data parallelism (ZeRO-DP), +pipeline parallelism (PP), tensor parallelism (TP) and sequence +parallelism (SP) are popular technologies for accelerating LLMs training. +However, elastic and flexible composition of these different parallelism +topologies with checkpointing is not currently available, in part, because +these techniques shard model and/or optimizer states making it difficult to +resume training with a checkpoint that was created on a different number of GPUs or +accelerators. + +In this release, we are excited to introduce DeepSpeed Universal +Checkpointing (*UCP*), a most comprehensive solution to the problem of +distributed checkpointing. *UCP* enables efficient checkpoint creation +while providing the flexibility of resuming on arbitrary parallelism +strategies and hardware configurations. *UCP* also unlocks unprecedented +capabilities for large-scale training such as improved resilience to +hardware failures through continued training on remaining healthy +hardware, and reduced training time through opportunistic exploitation +of elastic capacity. + +In summary, this release of *UCP* unlocks the following capabilities: + +- Flexible checkpoints reshape along any of the training parallelism + techniques (i.e., PP, TP, DP, ZeRO-DP, SP, MoE) + +- Elastic resource management, scale up or scale down of training and + finetuning accelerator resources + +- Real world examples with support for multiple commercial-scale models + (i.e., BLOOM, Megatron GPT, LLAMA, Microsoft Phi) + +# Core Design + +The key insight of DeepSpeed *UCP* is the selection of the optimal +representation in each phase of the checkpointing life cycle: +distributed representation for saving, and consolidated representation +for loading. This is achieved using two key mechanisms. First, the +universal checkpoint format, which consists of a consolidated +representation of each model parameter, and metadata for mapping +parameter fragments to the ranks of an arbitrary parallel training +configuration. Second, the universal checkpoint language, a simple but +powerful and robust specification language for converting distributed +checkpoints into the universal checkpoint format. + +## Universal Checkpoint Format + + + +Figure 1: UCP overview: top row and bottom row are Source and Target +parallelism configurations respectively. The middle row shows UCP as +an intermediate format of translating from Source to Target. + +Figure 1 shows high level schematic description of *UCP* conversion +process and format. Conversion starts with top block of checkpointing in +any parallel format e.g, DP, TP, PP, SP. Saving in the native format of parallel training avoids any overhead of +consolidating into a single global checkpoint. To ensure that +a checkpoint saved in one parallel configuration (herein called *Source*) can be +easily converted and loaded for continuous training in another parallel configuration (herein called *Target*), +we introduce the idea of atomic checkpoint as an intermediate format. + +The concept of atomic checkpoint is central to *UCP*. These are +fine-grained files containing the consolidated representation of each +model parameter, along with optimizer states. The atomic checkpoint +format is useful for three reasons. First, the atomic representation of +checkpoints decouples the dependencies of distributed checkpoints and +specific parallelism techniques and hardware configurations. As such, +one does not need to implement individual converters for each *Source* +and *Target* pair. Instead, *UCP* can act as a common interchange format +between different distributed training techniques, which then can be +easily transformed into other distributed training strategies, as shown +in Fig 2. By keeping the consolidated representation of each model +parameter, *UCP* enables easy splitting and flexible mapping of model states +or fragmented states to different GPUs on a parameter-by-parameter +basis, effectively reducing the working memory needed to load large +model checkpoints. Second, the *UCP* conversion happens lazily and +on-demand, e.g., when a training process detects a change of parallelism +technique and hardware configuration. In other words, the existing +distributed checkpoint saving logic does not need any change. Third, the +structure of the *UCP* also makes it easy to handle advanced techniques +in distributed training, such as mixed-precision training. In practice, +researchers and practitioners may switch between fp16 and bfloat16 mixed +precision training. By keeping the fp32 weight/optimizer values, the +training can resume either with fp16 or bfloat16. + +## Universal Checkpoint Language + + + +Figure 2: UCP language helps transform distributed checkpoints into the +UCP format and load UCP checkpoints based on the Target parallel +technique and new hardware configuration. + + +While *UCP* provides a common interface for different parallelism +strategies, the development of transformation from arbitrary distributed +checkpoints to *UCP* can still incur a high engineering and +implementation cost. This is because the number of distributed checkpoint files +and their contents can vary across the different parallel training techniques. + +To tackle this challenge, *UCP* provides *UCP* language, which is a +simple but powerful specification language for converting a distributed checkpoint +into the atomic checkpoint format, described in previous +section. *UCP* does this in two ways. First, it provides a declarative +system with pre-defined *parameter patterns*, which cover a wide range +of parallelism strategies for model states. Parameter patterns contain +runtime information about how a parameter is partitioned across GPUs. +For instance, *nopattern* means that a parameter is uniquely associated +with a GPU rank, which is the most common pattern seen in techniques +such as ZeRO-1/2 and PP (see our technical report for a completed list +of currently supported parameter patterns). Second, *UCP* language +provides a set of common operators that facilitate the transformation of +distributed checkpoints into consolidated atomic checkpoints. At a +high-level, as illustrated in Figure 3, *UCP* language is invoked when +support for a new *Target* is needed or the hardware +configuration changes. It first transforms distributed checkpoints into +the *UCP* format. It then loads the *UCP* checkpoints based on the +*Target* parallel technique and new hardware configuration. + +# Key Results + +We evaluate *UCP* through a series of experiments on training LLMs. We +focus on the decoder-only Transformers: an architecture chosen due to +its state-of-the-art performance. Some of the largest models are also +decoder-based, making flexible and efficient checkpointing especially +important. In this blog, we present results of correctness verification +across different models and parallel strategies. For more results on +parallel efficiency analysis, detailed system and model architectures +and training hyperparameters, please see our technical report referenced +above. + +*UCP* provides flexible checkpointing from a *Source* parallelism +strategy to a different *Target* with different hardware configurations. +To verify this capability, we conduct correctness tests of *UCP* with +two groups of experiments. + +## Single Source to Multiple Target + + + +Figure 3: Training curves of loading UCP checkpoints into different +Target at iteration 101 with various GPU counts and parallelism +strategies + +To test if UCP allows resuming training with different parallelism +strategies and hardware configuration, we first train the GPT-3 model +using a configuration of TP=2, PP=2, DP=2 (ZeRO-1), and SP=1. Due to +constraints in time and resources, we limited the experiment to the +first 200 iterations. We convert the checkpoints saved at the 100th +iteration to *UCP* checkpoints and resume training with these *UCP* +checkpoints using different GPU counts and parallelism strategies. We +record the LM loss (average losses across the data parallel group) for +each iteration. Figure 3 illustrates that the training can be seamlessly +resumed with *UCP* checkpoints using different *Target* parallelism +strategies, achieving consistent convergence if the training were to +continue with the *Source* strategy. + +## Multiple Source to Single Target + + + +Figure 4: Training curves of transforming different Source parallelism +strategies at iteration 100 to UCP and loading UCP with a different +Target. + +Figure 4 shows the training curves from multiple *Source* configurations +to a single *Target*. Given a fixed random seed, we first train the +GPT-3 model using different *Source* configurations. We then convert +their distributed checkpoints saved at the 100th iteration to *UCP* +checkpoints and resume training with a configuration of TP=2, PP=2, +DP=1, and SP=1. The results show that regardless of the different +*Source* configurations, their checkpoints can all be converted into +*UCP* and resume training with a different configuration. Most +importantly, the resumed training curves match the curves from the +*Source* at iterations 101--200. These results validate the +effectiveness of *UCP* of converting an arbitrary configuration to a +different configuration for resumed training. + +## Varying Model Architectures + +*UCP* is model architecture agnostic. As such, it is not only compatible +with GPT models but also flexible enough to support various other model +architectures and sizes. Figures 5, 6 and 7 show the training +convergence for LLaMA 7B, BLOOM 176B, and a variant of Mixtral-7x8B MoE, +when resuming from *UCP* at the middle of training with new parallelism +strategies. These figures show that training is seamlessly resumed with +*UCP*, achieving consistent convergence that aligns with the initial +training phase across these diverse models. These results suggest that +*UCP* is quite flexible for various model architectures and sizes. + + + +Figure 5: Training curve with LLaMA model architecture. Source is +TP=PP=DP=2. Training is resumed at iteration 101 with new Targets +TP=DP=2, PP=1 and TP=PP=2, DP=1 + + + +Figure 6: Training curve of BLOOM model architecture. Source is TP=2, +PP=24, DP=8. Training is resumed at iteration 94767 with a new Targets +TP=2, DP=4, PP=24. + + + +Figure 7: Training curve with a variant of the Mixtral-MoE model +architecture. Source is TP=1, PP=2, DP=4. Training is resumed at +iteration 501 with a new Target TP=PP=DP=2. + +# General Availability of DeepSpeed Universal Checkpoint + +We are excited to release DeepSpeed Universal Checkpoint. DeepSpeed +Universal Checkpoint is available in DeepSpeed versions >= +[0.14.4](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.14.4), +has been fully integrated with [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed) ([commit c3a13be](https://github.com/deepspeedai/Megatron-DeepSpeed/commit/c3a13be721da0d0de16c338d0d665b0f7d13d14f)). +Detailed tutorial on usage is available on +[DeepSpeed tutorial page](https://www.deepspeed.ai/tutorials/universal-checkpointing/). + +We welcome contributions and collaboration from the broader open-source +community. DeepSpeed Universal Checkpoint is part of the bigger +DeepSpeed ecosystem of large-scale AI training and inference. For more +details on all DeepSpeed technologies and innovations, please visit our +[website]((https://www.deepspeed.ai/)) and follow us +on X, formerly Twitter, ([English](https://twitter.com/DeepSpeedAI), +[Japanese](https://twitter.com/DeepSpeedAI_JP)) +and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed). + +# Acknowledgements and Contributions +We thank the collaboration of University of Illinois at Urbana-Champaign, +Statosphere, and Intel Habana. + +Contributions: +Xinyu Lian $^1$, Sam Ade Jacobs $^2$, Lev Kurilenko $^2$, Masahiro Tanaka $^2$, +Stas Bekman $^3$, Olatunji Ruwase $^2$, Minjia Zhang $^1$, Moshe Island $^4$ + +1: University of Illinois at Urbana-Champaign +2: Microsoft +3: StasoSphere +4: Intel Habana diff --git a/blogs/deepspeed-ucp/chinese/README.md b/blogs/deepspeed-ucp/chinese/README.md new file mode 100644 index 000000000000..4260595671dc --- /dev/null +++ b/blogs/deepspeed-ucp/chinese/README.md @@ -0,0 +1,124 @@ + +
+ +# DeepSpeed通用检查点:用于大规模分布式训练的高效灵活检查点系统 + +
+ + + +要引用DeepSpeed通用检查点,请引用我们的[arxiv报告](https://arxiv.org/abs/2406.18820): + +``` +@article{lian2024-ucp, +title={Universal Checkpointing: Efficient and Flexible Checkpointing for +Large Scale Distributed Training}, +author={Xinyu Lian and Sam Ade Jacobs and Lev Kurilenko and Masahiro Tanaka +and Stas Bekman and Olatunji Ruwase and Minjia Zhang}, +journal={arxiv preprint arxiv:406.18820}, +year={2024}, + +} +``` + +# 引言 + +检查点是降低训练大型语言模型成本的关键技术,它使我们在训练过程中可以保存模型状态。这样,如果训练失败,训练可以从最后保存的点继续,而不是从头开始。此外,检查点还允许在训练的不同阶段评估模型性能,从而便于进行超参数调整以及针对不同和多样化下游任务的微调。 + +然而,在分布式训练和微调场景中设计、实施和使用检查点存在困难。ZeRO数据并行(ZeRO-DP)、流水线并行(PP)、张量并行(TP)和序列并行(SP)等方法是加速大型语言模型训练的出色技术,但与传统的默认(Torch)保存和加载检查点机制不兼容。此外,目前尚无技术支持将这些不同的并行拓扑与检查点灵活组合,部分原因是这些技术将模型和/或优化器状态分片,使得在不同GPU或加速器数量上创建的检查点难以用于恢复训练。 + +在此,我们很高兴地发布DeepSpeed通用检查点(*UCP*),这是解决分布式检查点问题的最全面的解决方案。*UCP*在高效创建检查点的同时,提供了在任意并行策略和硬件配置上恢复的灵活性。*UCP*还解锁了大规模训练的前所未有的能力,例如通过在剩余健康硬件上继续训练来提高对硬件故障的抵抗力,以及通过机会性利用弹性容量来减少训练时间。 + +简单来说,当前版本的*UCP*解锁了以下功能: + +- 灵活的检查点可沿任何训练并行技术(即PP、TP、DP、ZeRO-DP、SP、MoE)重塑训练 + +- 弹性资源管理,在训练和微调中随意增加或减少硬件资源 + +- 支持多种商业规模模型的真实世界用例(例如BLOOM、Megatron GPT、LLAMA、Microsoft Phi) + +# 核心设计 + +DeepSpeed *UCP*的关键洞察是在检查点生命周期的每个阶段选择最佳表示:分布式表示用于保存,合并表示用于加载。这通过两个关键机制实现。首先,通用检查点格式,它包括每个模型参数的合并表示和用于将参数片段映射到任意模型并行配置的训练级别的元数据。其次,通用检查点语言,这是一个简单但强大且健壮的规范语言,用于将分布式检查点转换为通用检查点格式。 + +## 通用检查点格式 + + + +图1:UCP概述:顶部行和底部行分别为源并行配置和目标并行配置。中间行显示UCP作为从源到目标的转换中介块。 + +图1显示了*UCP*转换过程和格式的整体概念性描述。转换从任何并行策略格式的检查点顶部块开始。允许以训练的本地格式保存消除了可能因同步全局检查点保存而产生的任何开销。为确保保存的检查点(称为*源*)可以轻松转换并加载到任何并行策略以进行连续训练(称为*目标*),我们引入了作为中介块的原子检查点格式的概念。 + +原子检查点是*UCP*的核心概念。这些是包含每个模型参数的合并表示及其优化器状态的细粒度文件。原子检查点格式有三个用途。首先,原子检查点的表示解除了分布式检查点与特定并行技术和硬件配置的依赖。因此,无需为每个*源*到*目标*实现单独的转换器。相反,*UCP*可以充当不同分布式训练技术之间的通用交换格式,然后可以轻松地转换为其他分布式训练策略,如图2所示。通过保持每个模型参数的合并表示,*UCP*可以轻松地将模型状态或片段状态拆分并灵活地映射到不同GPU上,有效减少加载大型模型检查点所需的工作内存。其次,*UCP*转换是懒惰和按需进行的,例如,当训练过程检测到并行技术和硬件配置的变化时。换句话说,现有的分布式检查点保存逻辑不需要任何改变。第三,*UCP*的结构还易于处理分布式训练中的高级技术,例如混合精度训练。在实践中,研究人员和从业者可能在fp16和bfloat16混合精度训练之间切换。通过保持fp32的权重/优化器值,训练可以继续使用fp16或bfloat16恢复。 + +## 通用检查点语言 + + + +图2:UCP语言帮助将分布式检查点转换为UCP格式,并根据目标并行技术和新硬件配置加载UCP检查点。 + + +虽然*UCP*为不同的并行策略提供了一个公共接口,但从任意分布式检查点到*UCP*的转换仍然可能具有不菲的工程和实施成本。这是因为分布式训练中的每个GPU都调用一个持久方法(例如,在PyTorch中使用torch.save())将其拥有的GPU模型状态保存到磁盘上的检查点文件中,而每个检查点的具体内容在不同技术之间会有所不同。 + +为了应对这一挑战,*UCP*提供了*UCP*语言,这是一个简单但强大的规范语言,用于将几种类型的分布式检查点转换为前一节中描述的通用格式。*UCP*以两种方式实现这一点。首先,它提供了一个具有预定义*参数模式*的声明式系统,这些模式涵盖了模型状态的广泛并行 + +策略。参数模式包含有关参数如何在GPU之间分区的运行时信息。例如,*nopattern*表示一个参数与某个GPU唯一相关,这是ZeRO-1/2和PP等技术中最常见的模式(参见我们的技术报告,以获得当前支持的参数模式完整列表)。其次,*UCP*语言提供了一组常见操作符,以便将分布式检查点转换为合并的原子检查点。从高层次来看,如图3所示,当需要新的*目标*并行技术或硬件配置发生变化时,将调用*UCP*语言。它首先将分布式检查点转换为*UCP*格式。然后根据*目标*并行技术和新硬件配置加载*UCP*检查点。 + +# 关键结果 + +我们通过一系列实验评估*UCP*,专注于仅解码器的Transformers架构,这是由于其最先进的性能。一些最大的模型也是基于解码器的,这使得灵活高效的检查点尤为重要。在本博客中,我们展示了在不同模型和并行策略下正确性验证的结果。有关并行效率分析、详细的系统和模型架构以及训练超参数的更多结果,请参阅上面引用的技术报告。 + +*UCP*提供了从一个*源*并行策略到不同的*目标*和不同硬件配置的灵活检查点。为验证这一能力,我们进行了正确性测试的两组实验。 + +## 单源到多目标 + + + +图3:在第101次迭代时使用不同目标加载UCP检查点的训练曲线,具有不同GPU数量和并行策略 + +为测试UCP是否允许使用不同并行策略和硬件配置恢复训练,我们首先使用TP=2、PP=2、DP=2(ZeRO-1)和SP=1的配置训练GPT-3模型。由于时间和资源的限制,我们将实验限制在前200次迭代。我们将在第100次迭代保存的检查点转换为*UCP*检查点,并使用不同GPU数量和并行策略恢复训练。我们记录了每次迭代的LM损失(数据并行组的平均损失)。图3显示,训练可以使用不同的*目标*并行策略无缝地使用*UCP*检查点恢复,如果训练继续使用*源*策略,将实现一致的收敛。 + +## 多源到单目标 + + + +图4:在第100次迭代将不同源并行策略转换为UCP并加载UCP的训练曲线,具有不同的目标。 + +图4显示了从多个*源*配置到单一*目标*的训练曲线。在固定随机种子的情况下,我们首先使用不同的*源*配置训练GPT-3模型。然后我们将它们在第100次迭代保存的分布式检查点转换为*UCP*检查点,并使用TP=2、PP=2、DP=1和SP=1的配置恢复训练。结果显示,无论不同的*源*配置如何,它们的检查点都可以转换为*UCP*并使用不同的配置恢复训练。最重要的是,恢复的训练曲线与第101--200次迭代的*源*曲线匹配。这些结果验证了*UCP*将任意配置转换为不同配置以恢复训练的有效性。 + +## 不同模型架构的变化 + +*UCP*与模型架构无关。因此,它不仅与GPT模型兼容,而且足够灵活,可以支持各种其他模型架构和大小。图5、6和7显示了使用新并行策略从*UCP*中恢复训练时的训练收敛情况。这些图表显示,训练可以使用*UCP*无缝恢复,实现与初始训练阶段一致的收敛,这与这些不同模型相符。这些结果表明,*UCP*对于各种模型架构和大小都非常灵活。 + + + +图5:使用LLaMA模型架构的训练曲线。源是TP=PP=DP=2。训练在第101次迭代时使用新目标TP=DP=2, PP=1和TP=PP=2, DP=1恢复 + + + +图6:使用BLOOM模型架构的训练曲线。源是TP=2, PP=24, DP=8。训练在第94767次迭代时使用新目标TP=2, DP=4, PP=24恢复。 + + + +图7:使用Mixtral-MoE模型架构变种的训练曲线。源是TP=1, PP=2, DP=4。训练在第501次迭代时使用新目标TP=PP=DP=2恢复。 + +# DeepSpeed通用检查点的普遍可用性 + +我们很高兴发布DeepSpeed通用检查点。DeepSpeed通用检查点已与Megatron-DeepSpeed的重构版本完全集成,并可通过DeepSpeed和Megatron-DeepSpeed的GitHub仓库访问。详细的使用教程可在[DeepSpeed教程页面](https://www.deepspeed.ai/tutorials/universal-checkpointing/)上找到。 + +我们欢迎来自更广泛开源社区的贡献和合作。DeepSpeed通用检查点是大规模AI训练和推理DeepSpeed生态系统的一部分。有关所有DeepSpeed技术和创新的更多详细信息,请访问我们的[网站](https://www.deepspeed.ai/)并在X(前Twitter)([英文](https://twitter.com/DeepSpeedAI),[日文](https://twitter.com/DeepSpeedAI_JP))和[中文知乎](https://www.zhihu.com/people/deepspeed)上关注我们。 + +# 致谢和贡献 +我们感谢伊利诺伊大学厄巴纳-香槟分校、Statosphere和英特尔Habana的合作。 + +贡献者: +Xinyu Lian $^1$, Sam Ade Jacobs $^2$, Lev Kurilenko $^2$, Masahiro Tanaka $^2$, +Stas Bekman $^3$, Olatunji Ruwase $^2$, Minjia Zhang $^1$, Moshe Island $^4$ + +1: 伊利诺伊大学厄巴纳-香槟分校 +2: 微软 +3: Statosphere +4: 英特尔Habana diff --git a/blogs/deepspeed-ucp/japanese/README.md b/blogs/deepspeed-ucp/japanese/README.md new file mode 100644 index 000000000000..24da72298dd8 --- /dev/null +++ b/blogs/deepspeed-ucp/japanese/README.md @@ -0,0 +1,115 @@ +
+ +# DeepSpeed Universal Checkpointing: 大規模分散学習のための効率的かつ柔軟なチェックポイント + +
+ + + +DeepSpeed Universal Checkpointを引用する際は、こちらの[arxiv report](https://arxiv.org/abs/2406.18820)を参照してください。 + +``` +@article{lian2024-ucp, +title={Universal Checkpointing: Efficient and Flexible Checkpointing for +Large Scale Distributed Training}, +author={Xinyu Lian and Sam Ade Jacobs and Masahiro Tanaka and Lev +Kurilenko and Stas Bekman and Olatunji Ruwase and Minjia Zhang}, +journal={arxiv preprint arxiv:406.18820}, +year={2024}, + +} +``` + +# はじめに + +モデルの状態を保存するをチェックポイントは、システム障害が発生した場合に途中から学習を再開するために、LLMのトレーニングコストを削減するための重要な技術です。さらに、学習のさまざまな段階でモデルのパフォーマンスを評価することができるため、ハイパーパラメータの調整や異なる下流タスクのためのファインチューニングが容易になります。 + +しかし、特に分散学習やファインチューニングのシナリオにおいて、チェックポイントの設計、実装、および使用には多くの課題があります。DeepSpeedが備えるZeROを用いたデータ並列化(ZeRO-DP)、パイプライン並列化(PP)、テンソル並列化(TP)、およびシーケンス並列化(SP)などのいくつかの方法は、LLM学習を加速するための優れた技術ですが、一般的なチェックポイント保存と読み込みのメカニズムと互換性がありません。さらに、これらの異なる並列化を用いたエラスティックで柔軟な組み合わせは、現在サポートされていません。主な理由の一つは、こうした並列化技術がモデルおよび/またはオプティマイザの状態を分割するため、異なるGPUまたはアクセラレータの数に基づいて作成されたチェックポイントから学習を再開することが困難であるためです。 + +このリリースでは、分散チェックポイントの問題に対する包括的なソリューションであるDeepSpeed Universal Checkpointing (*UCP*) を紹介します。*UCP*は、任意の並列化戦略とハードウェア構成で再開する柔軟性を提供しながら、効率的なチェックポイント作成を可能にします。また、*UCP*は、ハードウェア障害の際にも、残りの正常なハードウェアでのトレーニングの継続を可能にするため、キャパシティがエラスティックに変化するハードウェアを活用でき、トレーニング時間を短縮するなど、大規模学習を最大限に効率化できます。 + +現在のリリースには、*UCP*の次の機能が含まれます。 + +- 任意のトレーニング並列技術(例:PP、TP、DP、ZeRO-DP、SP、MoE)に沿った柔軟なチェックポイントの再構成 +- ファインチューニングを含む学習およびアクセラレータリソースのエラスティックなリソース管理、スケールアップまたはスケールダウン +- BLOOM、Megatron GPT、LLAMA、Microsoft Phiなどの複数の商用規模モデルのサポートを伴う実利用例 + +# UCPの設計 + +DeepSpeed *UCP*における中心的な考え方は、チェックポイントライフサイクルの各段階で最適な表現を選択することです。保存のための分散表現と、読み込みのための統合表現です。これは、2つの重要なメカニズムを使用して実現されます。一つ目は、各モデルパラメータの統合表現と、パラメータのフラグメントを任意のモデル並列化構成におけるランク(プロセスのインデックス)にマッピングするためのメタデータからなるユニバーサルチェックポイントフォーマットです。二つ目は、分散チェックポイントをユニバーサルチェックポイント形式に変換するためのシンプルで強力かつ堅牢な仕様言語であるユニバーサルチェックポイント言語です。 + +## ユニバーサルチェックポイントフォーマット + + + +図1:*UCP*の概要:上段と下段はそれぞれソースとターゲットの並列化構成です。中央の段は、ソースからターゲットへの翻訳の仲介ブロックとしての*UCP*を示しています。 + +図1は、*UCP*の変換プロセスとフォーマットの抽象レベルの概略図を示しています。変換は、DP、TP、PP、SPなどの任意の並列戦略形式のチェックポイントから始まります。訓練結果のモデルやオプティマイザ状態をネイティブ形式で保存することで、同期されたグローバルチェックポイントの保存に伴うオーバーヘッドを回避します。保存されたチェックポイント(以下、*ソース*と呼びます)を任意の並列戦略に簡単に変換してロードできるようにするために、中間ブロックとして原子チェックポイント (atomic checkpoint) 形式のアイデアを導入します。 + +原子チェックポイントの概念は、*UCP*の中心となるものです。これらは、各モデルパラメータの統合表現とオプティマイザ状態を含む細粒度のファイルです。原子チェックポイント形式は、次の3つの理由で有用です。まず、チェックポイントの原子表現は、分散チェックポイントと特定の並列技術およびハードウェア構成の依存関係を切り離します。そのため、*ソース*から*ターゲット*への個別のコンバータを実装する必要はありません。代わりに、*UCP*は異なる分散トレーニング技術間の共通交換形式として機能し、他の分散トレーニング戦略に簡単に変換できます(図2参照)。各モデルパラメータの統合表現を保持することで、*UCP*はモデル状態またはフラグメント状態をパラメータごとに異なるGPUに柔軟にマッピングし、大規模モデルチェックポイントを読み込むために必要な作業メモリを効果的に削減します。第二に、*UCP*の変換は遅延してオンデマンドで行われます。たとえば、トレーニングプロセスが並列技術とハードウェア構成の変更を検出したときです。つまり、既存の分散チェックポイント保存ロジックには変更が必要ありません。第三に、*UCP*の構造により、混合精度トレーニングなどの高度な技術を分散トレーニングで簡単に処理できます。実際には、研究者や実務者はfp16とbfloat16の混合精度トレーニングを切り替えることがあります。fp32の重み/オプティマイザの値を保持することで、トレーニングはfp16またはbfloat16のいずれかで再開できます。 + +## ユニバーサルチェックポイント言語 + + + +図2:*UCP*言語は、分散チェックポイントを*UCP*形式に変換し、新しいハードウェア構成とターゲットの並列技術に基づいて*UCP*チェックポイントを読み込みます。 + +*UCP*は異なる並列戦略に対する共通インターフェースを提供しますが、任意の分散チェックポイントから*UCP*への変換の開発には依然として高いエンジニアリングおよび実装コストがかかる場合があります。これは、分散トレーニングの各GPUが保存のためのメソッド(例:PyTorchのtorch.save())を呼び出して、所有するGPUモデル状態のチェックポイントファイルをディスクに保存し、各チェックポイントの正確な内容が異なる技術によって異なるためです。 + +この課題に取り組むために、*UCP*は*UCP*言語を提供します。これは、前述の共通形式にいくつかの種類の分散チェックポイントを変換するためのシンプルで強力な仕様言語です。*UCP*はこれを2つの方法で行います。まず、モデル状態の並列戦略の広範な範囲をカバーする事前定義された*パラメータパターン*を持つ宣言型システムを提供します。パラメータパターンには、パラメータがGPU間でどのように分割されているかについてのランタイム情報が含まれています。たとえば、*nopattern*は、パラメータがGPUランクに一意に関連付けられていることを意味し、これはZeRO-1/2やPPなどの技術で最も一般的に見られるパターンです(現在サポートされているパラメータパターンの完全なリストについては、技術レポートを参照してください)。第二に、*UCP*言語は、分散チェックポイントを統合された原子チェックポイントに変換するための一般的な演算子のセットを提供します。抽象的なレベルで見ると、図2に示すように、ターゲットへの移行後に新しい並列技術が必要な場合やハードウェア構成が変更された場合に、*UCP*言語が使用されます。最初に、分散チェックポイントを*UCP*形式に変換し、次にターゲットの並列技術と新しいハードウェア構成に基づいて*UCP*チェックポイントを読み込みます。 + +# 主要な結果 + +我々は、LLMの訓練に関する一連の実験を通じて*UCP*を評価します。デコーダーのみのトランスフォーマーに焦点を当てました。これは最先端のパフォーマンスを持つアーキテクチャです。いくつかの最大のモデルもデコーダーベースであるため、柔軟で効率的なチェックポイントは特に重要です。このブログでは、さまざまなモデルと並列戦略にわたる正確性の検証結果を紹介します。並列効率分析、詳細なシステムおよびモデルアーキテクチャ、および訓練のハイパーパラメータに関する詳細な結果については、上記の技術レポートを参照してください。 + +*UCP*は、異なるハードウェア構成を持つ異なる*ターゲットの*並列戦略に対する*ソース*並列戦略からの柔軟なチェックポイントを提供します。この能力を検証するために、2つの実験グループで*UCP*の正確さを確認しました。 + +## シングルソースから複数のターゲットへ + + + +図3:さまざまなGPU数と並列戦略で*ターゲット*に*UCP*チェックポイントをロードする訓練lossの曲線(イテレーション100で保存・ロード) + +*UCP*が異なる並列戦略とハードウェア構成での訓練再開を可能にするかどうかをテストするために、まずTP=2、PP=2、DP=2(ZeRO-1)、SP=1の構成でGPT-3モデルを訓練します。時間とリソースの制約のため、この実験は最初の200イテレーションに限定しました。100イテレーション目で保存されたチェックポイントを*UCP*チェックポイントに変換し、異なるGPU数と並列戦略を使用してこれらの*UCP*チェックポイントで訓練を再開します。各イテレーションのLM損失(データ並列グループ全体の平均損失)を記録しました。図3は、異なる*ターゲット*並列戦略を使用して*UCP*チェックポイントで訓練をシームレスに再開し、*ソース*戦略を継続して訓練する場合と一致する収束を達成することを示しています。 + +## 複数ソースからシングルターゲットへ + + + +図4:100イテレーション目で異なるソース並列戦略を*UCP*に変換し、異なるターゲットで*UCP*をロードする訓練lossの曲線 + +図4は、複数の*ソース*構成から単一の*ターゲット*へのlossの曲線を示しています。固定されたランダムシードを使用して、まずGPT-3モデルを異なる*ソース*構成で訓練します。次に、100イテレーション目で保存された分散チェックポイントを*UCP*チェックポイントに変換し、TP=2、PP=2、DP=1、SP=1の構成でトレーニングを再開します。結果は、異なる*ソース*構成にもかかわらず、そのチェックポイントはすべて*UCP*に変換され、異なる構成で訓練を再開できることを示しています。最も重要なのは、再開されたlossの曲線が、イテレーション101~200での*ソース*の曲線と一致することです。これらの結果は、訓練再開時に任意の構成を異なる構成に変換する*UCP*の効果を検証しています。 + +## 異なるモデルアーキテクチャへの対応 + +*UCP*はモデルアーキテクチャに依存しません。したがって、GPTモデルとの互換性だけでなく、さまざまなモデルアーキテクチャとサイズをサポートする柔軟性も備えています。図5、6、7は、新しい並列戦略で*UCP*から訓練を再開したときのLLaMA 7B、BLOOM 176B、およびMixtral-7x8B MoEを元にしたモデルのトレーニング収束を示しています。これらの図は、トレーニングが*UCP*でシームレスに再開され、これらの多様なモデル全体で訓練の初期フェーズと一致する収束を達成することを示しています。これらの結果は、さまざまなモデルアーキテクチャとサイズに対する*UCP*の柔軟性を示しています。 + +A graph of training step Description automatically generated + +図5:LLaMAモデルアーキテクチャの訓練lossの曲線。ソースはTP=PP=DP=2。訓練はイテレーション101で新しいターゲットTP=DP=2、PP=1およびTP=PP=2、DP=1で再開しました。 + +A graph with numbers and lines Description automatically generated + +図6:BLOOMモデルアーキテクチャの訓練lossの曲線。ソースはTP=2、PP=24、DP=8。訓練はイテレーション94767で新しいターゲットTP=2、DP=4、PP=24で再開しました。 + +A graph of training step Description automatically generated + +図7:Mixtral-MoEモデルアーキテクチャに基づくモデルの訓練lossの曲線。ソースはTP=1、PP=2、DP=4。訓練はイテレーション501で新しいターゲットTP=PP=DP=2で再開しました。 + +# DeepSpeed Universal Checkpointの一般公開 + +DeepSpeed Universal Checkpointは、リベースされたMegatron-DeepSpeedバージョンに完全に統合されており、DeepSpeedおよびMegatron-DeepSpeedのGitHubリポジトリを通じてアクセスできます。使用に関する詳細なチュートリアルは、[DeepSpeedチュートリアルページ](https://www.deepspeed.ai/tutorials/universal-checkpointing/)にあります。 + +DeepSpeedでは、広範なオープンソースコミュニティからの貢献とコラボレーションを受け入れています。DeepSpeed Universal Checkpointは、大規模AIトレーニングおよび推論のためのDeepSpeedエコシステムの一部です。すべてのDeepSpeed技術とイノベーションについての詳細は、[ウェブサイト](https://www.deepspeed.ai/)をご覧いただき、X(旧Twitter)での[英語](https://twitter.com/DeepSpeedAI)、[日本語](https://twitter.com/DeepSpeedAI_JP)、および[中国のZhihu](https://www.zhihu.com/people/deepspeed)をフォローしてください。 + +# 謝辞と貢献 + +University of Illinois at Urbana-Champaign、Statosphere、およびIntel Habanaとの協力に感謝します。 + +コントリビュータ: +Xinyu Lian $^1$, Sam Ade Jacobs $^2$, Lev Kurilenko $^2$, Masahiro Tanaka $^2$, Stas Bekman $^3$, Olatunji Ruwase $^2$, Minjia Zhang $^1$, Moshe Island $^4$ + +1: University of Illinois at Urbana-Champaign +2: Microsoft +3: StasoSphere +4: Intel Habana diff --git a/blogs/deepspeed-ucp/media/flowchart.png b/blogs/deepspeed-ucp/media/flowchart.png new file mode 100644 index 000000000000..d5198ca00e03 Binary files /dev/null and b/blogs/deepspeed-ucp/media/flowchart.png differ diff --git a/blogs/deepspeed-ucp/media/image1.png b/blogs/deepspeed-ucp/media/image1.png new file mode 100755 index 000000000000..c9663de91cc2 Binary files /dev/null and b/blogs/deepspeed-ucp/media/image1.png differ diff --git a/blogs/deepspeed-ucp/media/image2.png b/blogs/deepspeed-ucp/media/image2.png new file mode 100644 index 000000000000..4262aa26600f Binary files /dev/null and b/blogs/deepspeed-ucp/media/image2.png differ diff --git a/blogs/deepspeed-ucp/media/image3.png b/blogs/deepspeed-ucp/media/image3.png new file mode 100755 index 000000000000..101a19c86ae5 Binary files /dev/null and b/blogs/deepspeed-ucp/media/image3.png differ diff --git a/blogs/deepspeed-ucp/media/image4.png b/blogs/deepspeed-ucp/media/image4.png new file mode 100755 index 000000000000..b4f083e8eeba Binary files /dev/null and b/blogs/deepspeed-ucp/media/image4.png differ diff --git a/blogs/deepspeed-ucp/media/image5.png b/blogs/deepspeed-ucp/media/image5.png new file mode 100755 index 000000000000..f0195ebc8d11 Binary files /dev/null and b/blogs/deepspeed-ucp/media/image5.png differ diff --git a/blogs/deepspeed-ucp/media/image6.png b/blogs/deepspeed-ucp/media/image6.png new file mode 100644 index 000000000000..19405123e79a Binary files /dev/null and b/blogs/deepspeed-ucp/media/image6.png differ diff --git a/blogs/deepspeed-ucp/media/image7.png b/blogs/deepspeed-ucp/media/image7.png new file mode 100644 index 000000000000..c2d383110a59 Binary files /dev/null and b/blogs/deepspeed-ucp/media/image7.png differ diff --git a/blogs/deepspeed-ucp/media/image8.png b/blogs/deepspeed-ucp/media/image8.png new file mode 100644 index 000000000000..0014db8b688f Binary files /dev/null and b/blogs/deepspeed-ucp/media/image8.png differ diff --git a/blogs/deepspeed-ulysses/README.md b/blogs/deepspeed-ulysses/README.md new file mode 100644 index 000000000000..68f68f08d110 --- /dev/null +++ b/blogs/deepspeed-ulysses/README.md @@ -0,0 +1,370 @@ +
+ +# DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models + +
+ +
+ + + +
+ +To cite DeepSpeed-Ulysses, please cite our [arxiv report](https://arxiv.org/abs/2309.14509): + +``` +@article{jacobs2023deepspeed, + title={DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models}, + author={Sam Ade Jacobs and Masahiro Tanaka and Chengming Zhang and Minjia Zhang and Shuaiwen Leon Song and Samyam Rajbhandari and Yuxiong He}, + journal={arXiv preprint arXiv:2309.14509}, + year={2023}, +} +``` + +## Introduction + +Training large models with long sequences is becoming very important +across the board from generative AI to models for scientific discovery. +On generative AI side, conversational AI, long document summarization +and video generation require reasoning over long contexts in spatial and +temporal domains. For example, multimodal foundation models such as ones +that process speech, images and waveforms concurrently require long +context reasoning over high dimensional inputs with extremely large +sequences. Similarly, chapter and book level summarization (estimated at +tens and hundreds of thousands of words) are of great importance in +conversational AI and abstract summarization tasks. + +Long sequence length is equally critical for AI for science opening +doors for better understanding of structure biology, health care, +climate and weather forecasting and large molecular simulation. For +instance, by adapting large language models with gene sequences, we can +create language models that can learn the evolutionary patterns of +genomes using simple alphabets and extremely long sequences (the human +genome has 6.4 billion letters). In health care, diagnostic predictive +model conditioned on entire patient care record requires context of +extremely long sequence. + +Despite the emerging importance of long sequence length for both +generative AI and AI for science, existing large model training systems +and the underlying parallelism technologies (data, tensor, pipeline, +sequence parallelism) are limited in their ability to support the +efficient long sequence training. Two challenges with existing +parallelism approach come to the fore. First, existing parallelism +approach such as data, tensor and pipeline parallelism cannot address +the scaling along sequence dimension. Second, existing sequence +parallelism approaches are not effective because of memory-communication +inefficiencies. Furthermore, existing +approaches have limited usability requiring intrusive and error prone +code refactoring. + +In this release, we are proud to introduce *DeepSpeed-Ulysses (or +Ulysses, a very long novel)*, a simple, portable, and effective +methodology for enabling highly efficient and scalable LLM training with +extremely long sequence lengths. + +DeepSpeed-Ulysses partitions individual samples along the sequence +dimension among participating GPU. Then right before the attention +computation, it employs *all-to-all communication* collective on the +partitioned queries, keys and values such that each GPU receives the +full sequence but only for a non-overlapping subset of the attention +heads. This allows the participating GPUs to compute attention for +different attention heads in parallel. Finally, DeepSpeed-Ulysses +employs another all-to-all to gather the results along the attention +heads while re-partitioning along the sequence dimension. + +The key properties of DeepSpeed-Ulysses and its implementation released +with this blog are as follows: + +* ***4x larger sequence lengths*** than existing systems, while +enabling training with sequences with ***over a million tokens***. + +* Communication reduction of ***over 10x*** compared to existing +systems, resulting in throughput improvements of ***up to 2.5x***, and +sustained throughput of over 175 TFlops/GPU (over 54% of hardware peak). + +* Fully general and implementation agnostic attention: DeepSpeed +sequence parallelism supports dense as well as sparse +attention, and it works with efficient attention implementations such as +FlashAttention v2. + +* Support for massive model training: DeepSpeed sequence parallelism +works together with ZeRO-3 to not only support large sequence lengths +but also massive model sizes. + +* Easy-to-use and portable, requiring minimal code changes to the +existing training frameworks. + +In subsequent sections, we provide detailed discussion of DeepSpeed-Ulysses +core design, communication complexity analysis, +experimental evaluation and comparison with existing work and highlight +of usability and guide on usage. + +## Core Design of DeepSpeed-Ulysses + +
+ + +*Figure 1: DeepSpeed sequence parallelism (DeepSpeed-Ulysses) design* +
+ +Figure 1 shows the core design of DeepSpeed-Ulysses. As with the known +transformer architecture, the design consists of input sequences *N* +partitioned across *P* available devices. Each local *N/P* partition is +projected into queries (Q), keys (K) and values (V) embeddings. Next, +(QKV) embeddings are gathered into global QKV through highly optimized +all-to-all collectives between participating compute devices. Sequel to +all-to-all collective is the attention computation per head in the form: + +$$Output\ context = Softmax\ (\frac{QK^{T}}{\sqrt{d}})V$$ + +After the attention computation, another all-to-all collective +transforms *output context* tensor of attention computation to sequence +(*N/P*) parallel for subsequent operators (MLP MatMul, layer norm etc) +in the remaining modules of transformer layer block. + +### Significant Communication Volume Reduction + +What distinguishes DeepSpeed-Ulysses from the other existing +long-sequence approaches is our much smaller aggregate communication +volume and overall better scalability with increasing degree of sequence +parallelism compared to existing solutions, as demonstrated by the +communication volume analysis below: + +On modern clusters with intra-node NVSwitch interconnect and inter-node +fat tree IB topology, the communication volume transmitted per link for +an all-to-all for aggregate message of size *M* over *P* GPUs is *M/P*. +For a transformer model with hidden size h, sequence length of N, and +parallelism degree of P, DeepSpeed sequence parallelism performs all-to-all for the QKV +projections with an aggregate message size of *3Nh* before the attention +computation, and another all-to-all for output context projection with a +size *Nh* for each transformer layer. Therefore, DeepSpeed sequence +parallelism incurs an aggregate communication volume per link of +***4Nh/P (or with the complexity of O(N/P).*** Note that this +communication volume is constant when both N and P are increased +proportionally. + +In contrast, the existing approaches like Megatron-LM incur +communication volume that increases linearly with N regardless of P, +resulting in the ***communication complexity of O(N).*** For instance, +Megatron-LM performs two *all-gather* with the message volume of *Nh* +and two *reduce-scatter* with the volume of *Nh* for each transformer +layer. However, the cost of each all-gather and reduce-scatter of size M +remains M when *P \>\> 1*, instead of *M/P*. Therefore, Megatron-LM +sequence parallelism incurs a communication volume per link of ***4Nh*** +which is P times larger than that for DeepSpeed sequence parallelism. +This allows DeepSpeed sequence parallelism to enable training with +extremely long sequences while achieving significantly higher training +efficiency compared to the existing approaches. Our evaluation results +match this analysis. + +### Additional Highlights of DeepSpeed-Ulysses + +***An Attention Agnostic Solution*** + +DeepSpeed implementation of distributed attention module is general +enough to support any attention: e.g., self-attention, cross-attention, +causal attention in both their dense and sparse counterparts, and their +various optimized kernels that support long-sequence at local attention +level such as different versions of FlashAttention. + +The generality property of DeepSpeed-Ulysses stems from the modular +nature of its core design: an attention-centric sequence parallelism +design. Prior to attention computation is sequence parallelism of N/P +partition, attention computation is head parallelism with full attention +per head but just with fewer heads, thus attention computation can be +replaced with any type of attention mechanisms, e.g., dense attention +and various forms of sparse attention. + +***Training Bigger Models with Longer Sequences through ZeRO-3 Integration*** + +While DeepSpeed sequence parallelism reduces the activation memory when +training with longer sequences, it does not impact the memory consumed +by the model states. Therefore, to support large sequence length +training with large language model, DeepSpeed sequence parallelism is +integrated with ZeRO-3. + +[ZeRO Redundancy Optimizer Stage 3 (ZeRO-3)](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/) is a memory optimization technique for training large +models. Unlike the classic data parallel training of neural networks +where model states are replicated across data parallel ranks, ZeRO-3 +optimizes memory usage by partitioning model states across data parallel +ranks. However, with sequence parallelism, training data can be +considered in both batch (sample) and sequence dimensions and the +associated parallel groups combined to form a larger group for ZeRO +parallelism. + +Therefore, we extend ZeRO-3 partitioning to combination of data parallel +and sequence parallel ranks. In other words, in DeepSpeed sequence +parallelism, ZeRO partitions model states across both sequence and data +parallel group and collects per rank partitions (allgather) when they +are needed. Similarly, gradients are reduced across both data and +sequence parallel ranks for parameter update. ZeRO allows +for huge memory savings in both sequence and data dimensions and enables +scaling not just to large sequence lengths but also to large models. + +## Evaluation + +We evaluate DeepSpeed-Ulysses (Ulysses) on GPT, +a foundation model for many NLP tasks on up to 64 A100 GPUs with 40GB memory. Our +evaluations are four-fold: i) sequence length scalability, ii) +throughput for dense attention and comparison with existing system, and +iii) throughput with sparse attention and comparison with existing +system, iv) convergence study of DeepSpeed sequence parallelism. We discuss +and present evaluations from each of these categories next. + +### Sequence Length Scalability + +The first set of experiments is strong scaling of sequence length up to +1 million tokens on 1.2 billion parameter GPT model. Results of this +evaluation are shown in Figures 2. DeepSpeed sequence parallelism +allows increasing sequence length linearly with the +number of GPUs and +maintains similar computation throughput across different sequence +length at appropriate GPU count. + +
+ + +*Figure 2: DeepSpeed sequence parallelism strong scalability evaluation +at different sequence length and GPU count.* +
+ +### Dense Attention Evaluation + +Next, we evaluate Ulysses on 7 billion (7B) and 30 billion (30B) parameter +GPT dense attention models and compare against Megatron-LM's sequence +parallelism (Megatron LM) and Colossal AI sequence parallelism (ColAI-SP) on +32 and 64 A100 GPUs respectively. The results of these evaluations are shown +in Figures 3 and 4. + +We compare Ulysses with Megatron-LM and ColAI-SP for 7B and 30B models +running various sequence lengths. We chose the sequence parallelism +degree and micro-batch size that produced the best performance +(measured as TFLOPs) for the three methods, this we call optimal +(batch size-sequence length) configurations. For Ulysses, we always +use a ZeRO-3 parallelism degrees of 32 and 64 for 7B and 30B models +respectively. + + +Figures 3 and 4 show that Ulysses consistently outperforms Megatron-LM +and ColAI-SP for the sequence length that can be run with them. In addition, +Ulysses can run longer sequence than the two existing methods. Ulysses +performance advantages are two folds: (1) Ulysses in combination with ZeRO-3 +parameter sharding across both data and sequence parallel groups fits more +samples than Megatron-LM and ColAI-SP because of the memory optimization +leading to higher throughput (2) Ulysses benefits from efficient *all-to-all* +communication relative to *all-gather* *reduce-scatter* and *ring-style* P2P +communication as applied in Megatron-LM and ColAI-SP sequence parallelism. +However, for dense attention at long sequence length, the throughput is +primarily determined by local attention computation due to quadratic +computation complexity of attention, therefore performance gap between Ulysses +and the two existing methods closes for sequence length that can be run with them. + +
+ + +*Figure 3: Evaluation of Ulysses vs Megatron LM vs ColAI-SP on GPT-7B parameter + model with dense attention (32 GPUs).* +
+ +
+ + +*Figure 4: Evaluation of Ulysses vs Megatron LM vs ColAI-SP on GPT-30B parameter + model with dense attention (64 GPUs).* +
+ +### Sparse Attention Evaluation + +Similarly, we evaluate Ulysses on 7 billion and 30 billion parameter sparse +attention models and benchmark against Megatron-LM sequence parallelism. +There is no public implementation of block sparse attention for ColAI-SP, +therefore, evaluation of sparse attention is in comparison with Megatron-LM. +Results of our evaluation are shown in Figures 5 and 6. We observe similar +trends with sparse attention as dense attention experiments. We observe more +than 2x throughput performance of Ulysses compared to Megatron-LM. For memory +saving, Ulysses leveraging ZeRO-3 scales to 4x longer sequence lengths +than Megatron-LM. + +Ulysses outperforms Megatron-LM for sequence length that can be run with both. +In fact, the current Ulysses throughput is bottle-necked by the local sparse +attention implementation, and as a result Ulysses throughput decreases as +the sequence length increases. We expect this gap in performance between our +method and Megatron-LM to increase further for larger sequence lengths as we +improve the performance of the local sparse attention implementation in future. +A noteworthy observation is that the decreasing performance gap between Ulysses +and Megatron-LM observed in dense attention evaluation is less pronounced in +sparse attention evaluation, because the attention computation in sparse attention +is less dominant compared to dense attention. + +
+ + +*Figure 5: Evaluation of Ulysses and Megatron LM sequence parallelism on GPT-7B +parameter model with block sparse attention (32 GPUs).* +
+ +
+ + +*Figure 6: Evaluation of Ulysses and Megatron LM sequence parallelism on GPT-30B +parameter model with block sparse attention (64 GPUs).* +
+ +### Convergence Study + +Lastly, Figure 7 shows convergence of a 1.3 billion GPT model at 32K +sequence length on 8 A100 GPUs with sequence parallelism degree set at 4 +for both DeepSpeed and Megatron-LM sequence parallelism. For DeepSpeed +sequence parallelism, we evaluate convergence with different ZeRO +stages. DeepSpeed sequence parallelism is a purely system optimization +technique that enables training of long sequence Transformer model, thus +there is no (negative) impact on quality of trained models, this assertion is +validated through experiments and is shown in Figure 5. + +
+ + +*Figure 7: Convergence evaluation of DeepSpeed sequence parallelism with different +ZeRO memory optimization stages.* +
+ +## DeepSpeed-Ulysses Software Accessibility + +DeepSpeed-Ulysses can be easily integrated into your code with just a +few lines of simple code changes. Here is an example of how to enable +it: + +```python +from deepspeed.sequence.layer import DistributedAttention + +# Replace the original self-attention (attn) with DeepSpeed-Ulysses’s self-attention + +dist_attn = DistributedAttention(attn, get_sequence_parallel_group()) +``` + +Compared to other libraries that support sequence parallelism, such as +Megatron-LM, DeepSpeed-Ulysses does not require model refactoring. +DeepSpeed-Ulysses has been fully integrated and tested with the +Megatron-DeepSpeed code repository. This means that if you are already +using this repository for training large language models, you can +seamlessly benefit from DeepSpeed-Ulysses to train models with massive +sequence length. + +## Release: Try DeepSpeed-Ulysses Today + +We are excited to release DeepSpeed-Ulysses, accessible through +DeepSpeed GitHub. Detailed tutorial on usage is available on [DeepSpeed +tutorial page](https://www.deepspeed.ai/tutorials/ds-sequence/). + +We welcome contributions and collaboration as we together push forward +on what is possible when long context window is no longer a limitation. +DeepSpeed-Ulysses is part of the bigger DeepSpeed ecosystem of +large-scale AI training and inference. For more details on all DeepSpeed +technologies and innovations, please visit our [website]((https://www.deepspeed.ai/)) and follow us +on X, formerly Twitter, ([English](https://twitter.com/DeepSpeedAI), [Japanese](https://twitter.com/DeepSpeedAI_JP)) and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed). + +We are open to collaborations with universities, research labs, and +companies. For such requests (and other requests unsuitable for GitHub), +please directly email to . If you like +our work, please "Star" our [repo](https://github.com/deepspeedai/DeepSpeed). diff --git a/blogs/deepspeed-ulysses/chinese/README.md b/blogs/deepspeed-ulysses/chinese/README.md new file mode 100644 index 000000000000..cfd3b9664709 --- /dev/null +++ b/blogs/deepspeed-ulysses/chinese/README.md @@ -0,0 +1,155 @@ +
+ +# DeepSpeed Ulysses: 训练极长序列Transformer模型的系统优化 + +
+ +
+ + + +
+ +## 简介 + +从生成性AI到科研模型,长序列训练正在变得非常重要。 +在生成性AI领域,会话式AI、长文档摘要和视频生成等任务都需要在空间和时间层面对长上下文进行推理。 +例如,多模态基础模型,如同时处理语音、图像和波形的模型,需要对具有极长序列的高维输入进行长上下文推理。 +同样,章节和书籍级别的摘要(数万甚至数十万字)在会话式AI和摘要任务中也非常重要。 + +对于科学AI来说,长序列同样至关重要,它为更好地理解结构生物学、医疗保健、气候和天气预测以及大分子模拟打开了大门。 +例如,通过在基因序列上训练大型语言模型,我们可以创建可以使用极长序列(人类基因组有64亿个碱基对)学习基因组进化模式的语言模型。在医疗保健领域,以所有的患者护理记录为条件的诊断预测模型需要极长序列的上下文。 + +尽管对于生成性AI和科学AI来说,长序列长度的重要性逐渐增长,但现有的大型模型训练系统和底层的并行技术(数据、张量、流水线、序列并行)并不能支持高效的长序列训练。现有并行方法存在两个主要挑战。首先,现有的数据、张量和流水线等并行方法无法解决序列维度的扩展问题。其次,由于内存通信效率低下,现有的序列并行方法不够高效。此外,现有方法的易用性不足,需要进行侵入性和复杂易出错的代码重构。 + +为了解决这些问题,我们很高兴宣布推出*DeepSpeed-Ulysses(或称为Ulysses,一个非常长的小说)*,这是一种简单、易用且高效的方法,用于支持具有极长序列长度的高效可扩展LLM训练。 + +DeepSpeed-Ulysses将各个样本在序列维度上分割给参与的GPU。然后,在attention计算之前,它对已分割的查询(Q)、键(K)和值(V)执行*all-to-all通信*操作,以使每个GPU接收完整的序列,但仅用于注意力头的非重叠子集。这使得参与的GPU可以并行计算不同的注意力头。最后,DeepSpeed-Ulysses还使用另一个all-to-all来在注意力头上收集结果,同时重新在序列维度上进行分区。 + +DeepSpeed-Ulysses及其与此博客一起发布的实现的关键特性如下: + +* 与现有系统相比,序列长度增加了***4倍***,支持训练***超过百万个token***的序列。 + +* 与现有系统相比,通信减少了***超过10倍***,导致吞吐量提高了***高达2.5倍***,并且每个GPU的持续吞吐量超过175 TFlops(超过硬件峰值的54%)。 + +* 完全通用的attention:DeepSpeed序列并行支持密集和稀疏的注意力,并可与高效的注意力实现(如FlashAttention v2)一起工作。 + +* 支持大规模模型训练:DeepSpeed序列并行不仅支持大序列长度,还可以与ZeRO-3并用支持大模型尺寸。 + +* 易于使用和迁移,最小化对现有训练框架的代码更改要求。 + +在接下来的章节中,我们详细讨论DeepSpeed-Ulysses的核心设计、通信复杂度分析、实验评估以及与现有工作的比较,并展示其可用性和使用指南。 + +## DeepSpeed-Ulysses的核心设计 + +
+ + +*图1:DeepSpeed序列并行(DeepSpeed-Ulysses)设计* +
+ +图1显示了DeepSpeed-Ulysses的核心设计。与已知的Transformer架构一样,设计由*N*个输入序列在*P*个可用设备上分区组成。每个本地*N/P*分区都被投影到查询(Q)、键(K)和值(V)嵌入中。接下来,(QKV) 嵌入通过参与计算设备之间的高度优化的全对全集合(all-to-all collectives)进行全局的 QKV 收集。在全对全集合后,每个头的注意力计算形式为: + +$$Output\ context = Softmax\ (\frac{QK^{T}}{\sqrt{d}})V$$ + +注意力计算后,另一个全对全集合将注意力计算的输出上下文张量转换为序列(*N/P*)并行,用于Transformer模型层的剩余模块中的后续操作(MLP MatMul、层归一化等)。 + +### 显著的通信量减少 + +DeepSpeed-Ulysses与其他现有的长序列方法的区别在于其更小的累积通信量以及随着序列并行度增加而更好的可扩展性,如下所示: + +在具有节点内NVSwitch互连和节点间胖树IB拓扑的现代集群上,针对一个聚合消息大小为*M*的全对全传输,传输到*P*个GPU上的每个链接的通信量为*M/P*。 +对于隐藏层大小为h、序列长度为N且并行度为P的Transformer模型,DeepSpeed序列并行会在注意计算之前对QKV投影执行聚合消息大小为*3Nh*的全对全操作,并在注意计算之后对输出上下文投影执行大小为*Nh*的另一个全对全操作。因此,DeepSpeed序列并行每个链接的聚合通信量为***4Nh/P(或O(N/P)复杂度)***。值得注意的是,当N和P成比例增加时,这个通信量是恒定的。 + +相比之下,现有的方法,如Megatron-LM,在N线性增长的情况下会导致通信量线性增加,而与P无关,从而导致***O(N)的通信复杂度***。例如,Megatron-LM对每个Transformer模型层都执行两个大小为*Nh*的*all-gather*操作,以及两个大小为*Nh*的*reduce-scatter*操作。然而,当*P \>\> 1*时,大小为M的每个all-gather和reduce-scatter的成本仍然是M,而不是*M/P*。因此,Megatron-LM序列并行会导致每个链接的通信量为***4Nh***,这比DeepSpeed序列并行大P倍。这使得DeepSpeed序列并行可以在实现显著更高的训练效率的同时支持极长序列训练。我们的实验评估结果与此理论分析相符。 + +### DeepSpeed-Ulysses的其他亮点 + +***通用的注意力解决方案*** + +DeepSpeed分布式注意力模块的实现足够通用,以支持任何类型的注意力,例如自注意、交叉注意和因果注意,无论是它们的密集还是稀疏版本,以及支持局部注意层级上的长序列的各种优化内核,例如不同版本的FlashAttention。 + +DeepSpeed-Ulysses的通用性来自其核心设计的模块化性质:一个以注意力为中心的序列并行设计。在注意力计算之前,序列并行性是对N/P分区的,而注意力计算是对每个头的并行性,每个头的注意力全都保留,但头的数量较少,因此注意力计算可以用任何类型的注意力机制替代,例如密集注意力和各种形式的稀疏注意力。 + +***通过ZeRO-3集成实现更大的模型和更长的序列训练*** + +尽管DeepSpeed序列并行在使用更长的序列进行训练时减少了激活内存的使用,但并不影响模型状态的内存占用。因此,为了支持具有大序列长度的大语言模型训练,我们实现了DeepSpeed序列并行与ZeRO-3的集成。 + +[ZeRO Redundancy Optimizer Stage 3 (ZeRO-3)](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/) 是一种用于训练大模型的内存优化技术。与传统的神经网络数据并行训练中,模型状态在数据并行等级上进行复制不同,ZeRO-3通过将模型状态在数据并行等级之间进行分区来优化内存使用。然而,使用序列并行时,训练数据可以在批(样本)和序列维度上考虑,相关的并行群组可以组合成一个更大的群组以实现ZeRO并行。 + +因此,我们将ZeRO-3分区扩展到数据并行和序列并行等级的组合。换句话说,在DeepSpeed序列并行中,ZeRO将模型状态分区在序列和数据并行组之间,并在需要时收集每个等级分区(allgather)。类似地,梯度将在数据并行和序列并行等级之间进行减少,用于参数更新。ZeRO可以在序列和数据维度上实现巨大的内存节省,并且不仅可以扩展到大序列长度,还可以扩展到大模型。 + +## 评估 + +我们在GPT(用于许多NLP任务的基础模型)上使用最多64个A100 GPU(40GB显存)对DeepSpeed-Ulysses进行了评估。我们的评估分为四个方面:i) 序列长度可扩展性,ii) 密集注意力的吞吐量以及与现有系统的比较,iii) 稀疏注意力的吞吐量以及与现有系统的比较,iv) DeepSpeed序列并行的收敛性研究。接下来,我们将对每个类别讨论和展示评估结果。 + +### 序列长度可扩展性 + +第一组实验是在12亿参数的GPT模型上将序列长度扩展到100万token。这个评估的结果如图2所示。DeepSpeed序列并行允许随着GPU数量的增加线性增加序列长度,并且序列长度与GPU数量保持线性比例关系,适当的GPU数量下保持相似的计算吞吐量。 + +
+ + +*图2:DeepSpeed序列并行强化可扩展性评估,使用不同的序列长度和GPU数量。* +
+ +### 密集注意力评估 + +接下来,我们在300亿参数的密集注意力模型上对DeepSpeed序列并行进行了评估,并与Megatron序列并行在64个A100 GPU上进行了对比。这些评估的结果如图3所示。 + +我们将DeepSpeed序列并行与Megatron-LM在不同序列长度下的性能进行了比较。对于我们的评估,我们选择了能使DeepSpeed序列并行和Megatron-LM分别达到最佳性能(通过吞吐量或TFLOPs衡量)的序列长度-批大小组合,我们称之为最佳(批大小-序列长度)配置。对于DeepSpeed序列并行,我们始终使用64的ZeRO并行度。 + +图3显示,DeepSpeed序列并行在相同序列长度下始终优于Megatron-LM。此外,DeepSpeed序列并行可以运行比Megatron-LM更长的序列。DeepSpeed序列并行的性能优势在于两个方面:(1)DeepSpeed序列并行结合ZeRO-3的内存优化,可以容纳更多的样本,从而提高吞吐量;(2)相对于Megatron-LM序列并行中应用的*all-gather*通信,DeepSpeed序列并行使用更高效的全对全通信。 + +
+ + +*图3:DeepSpeed和Megatron LM序列并行在300亿参数模型上的密集注意力评估。* +
+ +### 稀疏注意力评估 + +类似地,我们在300亿参数的稀疏注意力模型上对DeepSpeed序列并行进行了评估,并与Megatron序列并行进行了对比。我们的评估结果如图4所示。稀疏注意力的实验结果与密集注意力实验类似。我们观察到DeepSpeed序列并行的吞吐量性能相对于Megatron-LM提高了2倍以上。通过节省内存,DeepSpeed序列并行结合ZeRO-3可以扩展到比Megatron-LM更长4倍的序列长度。 + +DeepSpeed序列并行在相同序列长度下始终优于Megatron-LM。事实上,当前的DeepSpeed吞吐量受到本地稀疏注意力实现的瓶颈,因此DeepSpeed吞吐量随着序列长度的增加而降低。我们预计,随着未来局部稀疏注意力实现性能的改善,DeepSpeed与Megatron之间的性能差距将在更大的序列长度下进一步增加。 + +
+ + +*图4:DeepSpeed和Megatron LM序列并行在300亿参数模型上的稀疏注意力评估。* +
+ +### 收敛性研究 + +最后,图5显示了1.3亿参数GPT模型在32K序列长度下,使用序列并行度设置为4的情况下,在8个A100 GPU上的收敛性。对于DeepSpeed序列并行,我们使用不同的ZeRO阶段进行了收敛性评估。DeepSpeed序列并行是一种纯系统优化技术,用于实现长序列Transformer模型的训练,因此在训练模型质量上没有(负面)影响,并通过实验得到了验证,如图5所示。 + +
+ + +*图5:使用不同ZeRO内存优化阶段的DeepSpeed序列并行的收敛性评估。* +
+ +## DeepSpeed-Ulysses软件可用性 + +DeepSpeed-Ulysses只需进行少量简单代码更改来集成到您的代码中。下面是一个启用它的示例: + +```python +from deepspeed.sequence.layer import DistributedAttention + +# 将原始的自注意(attn)替换为DeepSpeed-Ulysses的自注意 + +dist_attn = DistributedAttention(attn, get_sequence_parallel_group()) +``` + +与其他支持序列并行的库(如Megatron-LM)相比,DeepSpeed-Ulysses不需要进行模型重构。 +DeepSpeed-Ulysses已经完全与Megatron-DeepSpeed代码库集成并经过测试。这意味着如果您已经 +在使用这个代码库来训练大型语言模型,您可以无缝地使用DeepSpeed-Ulysses训练具有极长序列的模型。 + +## 立即尝试DeepSpeed-Ulysses + +我们很高兴宣布推出DeepSpeed-Ulysses,您可以通过DeepSpeed GitHub获取代码。详细的使用教程在[DeepSpeed教程页面](https://www.deepspeed.ai/tutorials/ds-sequence/)上提供。 + +我们欢迎各种形式的贡献和合作,以共同推动当长上下文窗口不再是限制时的各种创新。DeepSpeed-Ulysses是大规模AI训练和推理的更大DeepSpeed生态系统的一部分。有关所有DeepSpeed技术和创新的更多细节,请访问我们的[网站]((https://www.deepspeed.ai/)),并关注我们在X(Twitter)上的帐号([英文](https://twitter.com/DeepSpeedAI),[日文](https://twitter.com/DeepSpeedAI_JP))和[中文知乎](https://www.zhihu.com/people/deepspeed)。 + +我们愿意与大学、研究实验室和公司合作。有关此类请求(以及不适合在GitHub上提出的其他请求),请直接发送电子邮件至。 diff --git a/blogs/deepspeed-ulysses/japanese/README.md b/blogs/deepspeed-ulysses/japanese/README.md new file mode 100644 index 000000000000..8e4a10f66845 --- /dev/null +++ b/blogs/deepspeed-ulysses/japanese/README.md @@ -0,0 +1,158 @@ +
+ +# DeepSpeed Ulysses: Transformerモデルを非常に長いシーケンスで訓練するための最適化 + +
+ +
+ + + +
+ +## 概要 + +巨大モデルを長いシーケンスで訓練することは、生成AIから科学的発見のためのモデルに至るまで、あらゆる分野で非常に重要になっています。 +生成AIでは、会話型AI、長文の要約、ビデオ生成など、空間的・時間的な文脈での長いコンテキストの理解が求められます。 +たとえば、音声、画像、波形を同時に処理するマルチモーダルの基盤モデルは、非常に長いシーケンス長の高次元の入力から、長期のコンテキストを理解することが求められます。同様に、章や書籍単位での要約(数万から数十万語と想定される)は、会話AIや要約タスクにおいて非常に重要です。 + +長いシーケンスを扱えることは、科学におけるAIの利用にも重要であり、構造生物学、医療、気候および天気予報、大規模分子シミュレーションを進歩させる可能性を持っています。例えば、大規模な言語モデルを遺伝子のシーケンスに適応させることにより、単純なアルファベットからなる非常に長いシーケンスから、ゲノムの進化のパターンを学ぶ言語モデルを作成できます(ヒトゲノムには64億の文字があります)。また医療分野において、全体の患者ケア記録に基づいて条件付けされる診断予測モデルでは、非常に長いシーケンスで表現される文脈を扱う必要があります。 + +生成AIや科学分野において、長いシーケンスを扱う重要性が急速に増している一方で、既存の大規模モデルの訓練システムや基盤となる並列化技術(データ並列、テンソル並列、パイプライン並列、シーケンス並列)では、効率的に長いシーケンスを訓練することができませんでした。既存の並列化のアプローチには、2つの課題があります。第一に、データ並列、テンソル並列、パイプライン並列のような、既存の広く使用されている並列アプローチは、シーケンスの次元に沿ってスケールアップすることができません。第二に、既存のシーケンス並列のアプローチは、メモリ上のデータの通信が理由で、高い効率が得られません。さらに、既存のアプローチは、大規模なコードの変更が必要となり、既存のコードにエラーを発生させやすいという課題もあります。 + +このリリースは、LLM(大規模言語モデル)の訓練において、非常に長いシーケンスの処理を、効率的かつスケーラブルに実現する新たな手法である *DeepSpeed-Ulysses(またはUlysses、非常に長い小説にちなんで名づけられました)* を公開するものです。 + + +DeepSpeed-Ulyssesは、個々のサンプルを、シーケンスの次元で複数のGPUで分割します。そして、Transformerにおけるアテンション計算の直前に、 クエリ (Q)、キー (K)、および値 (V)について、*all-to-all* 通信を適用します。 +このall-to-all通信により、アテンションヘッドの単位で重複のないように複数のGPUに分割配置される一方で、シーケンス全体が一つのGPUに保持されるようになります。各GPUは、それぞれに異なるアテンションヘッドを計算するため、並列に計算が可能です。アテンションの計算後、もう一度 all-to-all 通信によって、計算結果をシーケンスの次元で再分割します。 + +このブログで紹介するDeepSpeed-Ulysses及びその実装の主な特長は以下の通りです。 + + +* 既存のシステムに比べて ***4倍長いシーケンス長*** (***100万トークン以上***)のシーケンスでの訓練が可能。 + +* 既存のシステムと比較して ***10倍以上の通信削減***。これにより、***最大2.5倍のスループット向上***と、175 TFlops/GPU(ハードウェアピークの54%以上)のスループットを実現。 + +* アテンションの実装に依存しない汎用性: Denseなアテンション計算のアルゴリズムだけでなく、Sparseなアルゴリズムも利用できます。また、FlashAttention v2のような効率的なアテンションの実装も容易に利用できます。 + +* 大規模モデルの訓練のサポート: ZeRO-3と連携して、長いシーケンスを処理できるだけでなく、巨大なモデルサイズもサポートします。 + +* 最小限のコード変更で、既存の訓練フレームワークに適用できます。 + +以降のセクションでは、DeepSpeed-Ulyssesの中心となる設計アイデア、通信コストの分析、実験的な評価と既存手法との比較を詳しく示した後、使用方法について説明します。 + + +## DeepSpeed-Ulyssesの設計 + +
+ + +*図1: DeepSpeed-Ulysses の設計* +
+ +図1はDeepSpeed-Ulyssesの中心となる設計を示しています。既知のTransformerアーキテクチャと同様に、入力シーケンス長 *N* が *P* の利用可能なデバイスに分割されて構成されます。各デバイスにおける、サイズ *N/P* の分割されたシーケンスから、クエリ (Q)、キー (K)、および値 (V) が計算されます。次に、各デバイス上のローカルな QKV から、all-to-all 集合通信によって、グローバルな QKV が構成されます。all-to-all 通信に続いて、ヘッドごとに以下のようにアテンションが計算されます。 + +$$Output\ context = Softmax\ (\frac{QK^{T}}{\sqrt{d}})V$$ + +アテンションの計算後、all-to-all 通信を再度実行し、Transformerレイヤーの残りのモジュール (MLP、layer norm など) の後続のオペレータを実行するため、シーケンス次元に沿って出力を分割します(各デバイス上での分割されたシーケンス長は、また *N/P* になります)。 + +### 通信量の大幅な削減 + +DeepSpeed-Ulyssesが、長いシーケンスのための既存の並列化手法と異なる点は、以降の通信量の分析に示すように、総通信量がはるかに少なく、それによって、シーケンスの並列度が増加した際の全体的なスケーラビリティが優れていることです。 + +ノード内通信にNVSwitch、ノード間通信にfat tree IBトポロジを備えるなどのモダンな計算クラスタでは、*P* 個のGPU上でall-to-all通信を行ったとき、合計メッセージのサイズ *M* に対して、リンクごとの通信量は *M/P* になります。隠れサイズ*h*、シーケンス長*N*、および並列度*P*のTransformerモデルに対して、アテンション計算の前に、QKVについてall-to-allを実行しますが、この合計メッセージサイズは *3Nh* になります。また、アテンションの出力に対しても、all-to-allを実行しますが、このメッセージサイズは *Nh* になります。したがって、Transformerレイヤごとに、リンクあたり合計通信量が ***4Nh/P*** となります (オーダーでは O(N/P)) 。この通信量は、NとPの両方が比例して増加する場合に一定です。 + +対照的に、Megatron-LMのシーケンス並列のような既存のアプローチは、*P* に関係なく *N* とともに通信量が線形に増加するため、通信量のオーダーは ***O(N)*** となります。例えば、Megatron-LMは、Transformerの各レイヤーに対して、通信量がNhの2つのall-gatherと、同じく通信量がNhの2つのreduce-scatterを実行します。しかし、サイズMの各all-gatherおよびreduce-scatterのコストは、 *P \>\> 1* の場合に(M/Pではなく)Mのままです。したがって、Megatron-LMシーケンス並列は、DeepSpeed-UlyssesのP倍大きな ***4Nh*** の通信ボリュームを発生させます。これにより、DeepSpeed-Ulyssesは、既存のアプローチと比較して、極端に長いシーケンスでの訓練を可能にし、訓練効率を大幅に向上させることができます。以降で示す評価結果は、この分析と一致しています。 + +### その他の特長 + +***アテンションの実装に非依存*** + +DeepSpeed-Ulyssesでは、アテンションの実装について、self-attention, cross-attention, Dense/Sparse等の異なるアルゴリズム、FlashAttentionのように、長いシーケンスをサポートするさまざまな最適化されたカーネルを用いた実装など、任意のアテンションと組み合わせて用いることができるような、一般化された構造になっています。 + +この一般性は、アテンション計算をモジュール化して用いることによって実現されています。アテンション計算の前では、シーケンス長NをN/Pに分割しますが、アテンション計算自体は、ヘッドごとに完全なアテンションを計算しており、ただデバイスあたりのヘッド数が少ないだけです。したがって、アテンション計算は、Denseなアルゴリズムやさまざまな種類のSparseなアテンションのためのアルゴリズムなど、任意の種類のアテンションのメカニズムと置き換えることができます。 + +***ZeRO3による大規模モデルの訓練*** + +DeepSpeed-Ulyssesによるシーケンスの分割と並列化は、長いシーケンスでの訓練時のアクティベーションメモリを削減しますが、モデル状態の保持に必要なメモリ量には影響しません。したがって、大きな言語モデルで長いシーケンス長の訓練をサポートするために、シーケンスの並列化はZeRO-3と統合されています。 + + +[ZeRO Redundancy Optimizer Stage 3 (ZeRO-3)](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/) は、大規模なモデルを訓練するためのメモリ最適化技術です。モデルの状態(パラメータ、勾配、Optimizer状態)を全てのGPUに複製する従来のデータ並列と異なり、ZeRO-3はGPUにモデルの状態を分割配置します。シーケンス並列を併用する場合、訓練データは、サンプルの次元と、シーケンスの次元の両方で分割されていることになります。 +そこで、データ並列およびシーケンス並列の両方のグループにまたがるプロセス群で、ZeRO-3におけるパラメータや勾配等の分割を行い、また必要な時にallgather通信によってそれらを収集します。同様に、勾配の集約(reduce)も、パラメータ更新のためにデータ並列とシーケンス並列の両方にまたがるプロセス群で実施されます。ZeROを使用することで、シーケンスとデータの両方の次元で大きなメモリ節約が可能となり、長いシーケンス長だけでなく、大きなモデルサイズにもスケーリングすることができます。 + +## 評価 + +多くのNLPタスクの基盤モデルとして用いられるGPTモデルの学習に、DeepSpeed-Ulyssesを適用し、最大64台のA100 GPU(40GBメモリ)を用いて評価を行いました。評価は以下の4つの観点で実施しました: i) シーケンス長のスケーラビリティ、ii) Denseなアテンションでのスループットおよび既存のシステムとの比較、iii) Sparseなアテンションのスループットおよび既存のシステムとの比較、iv) 収束性の検証。以降で、それぞれの評価結果を示します。 + +### シーケンス長のスケーラビリティ + + +最初の評価実験は、12億パラメータのGPTモデルでの、最大100万トークンまでのシーケンス長の強スケーリング(strong scaling)です。この評価の結果を、図2に示します。GPUの数に比例してシーケンス長を増加させた際に、それぞれのGPU数・シーケンス長で、ほぼ同等の計算スループットを維持しています。 + +
+ + +*図2: 異なるシーケンス長・GPU数での強スケーリング(strong scaling)* +
+ +### Denseなアテンションでの比較 + +次に、300億パラメータのDenseなアテンションを持つモデルで、64台のA100 GPU上でのMegatron-LMのシーケンス並列との比較を行ったベンチマーク結果を図3に示します。 + +ここでは、様々なシーケンス長で、DeepSpeed-UlyssesとMegatron-LMのシーケンス並列を比較しました。評価のために、それぞれのフレームワークが、最高の性能(スループットまたはTFLOPとして測定)を得られるシーケンス並列の並列度と、グローバルバッチサイズを選択しました。これを私たちは最適(バッチサイズ-シーケンス長)構成と呼びます。DeepSpeed-Ulyssesでは、常にZeRO-3を用い、64台のGPUにパラメータ・勾配・Optimizerの状態を分割配置しました。 + +図3に示すように、DeepSpeed-UlyssesとMegatron-LMの両方で処理できるシーケンス長では、DeepSpeed-Ulyssesが常にMegatron-LMよりも優れたパフォーマンスを示しました。さらに、DeepSpeed-Ulyssesは、Megatron-LMのシーケンス並列よりも、長いシーケンスを処理できます。DeepSpeed-Ulyssesの利点は2つあります:(1) ZeRO-3との組み合わせにより、メモリの必要量をより小さくできるため、Megatron-LMよりも大きなバッチサイズを処理できるようになり、スループットが高まる。 (2) DeepSpeed-Ulyssesは、Megatron-LMシーケンス並列処理で適用されるall-gather通信と比較して、より効率的なall-to-all通信のメリットを得られる。 + + +
+ + +*図3: 300億パラメータ・DenseなアテンションでのMegatron-LMとの比較* +
+ +### Sparseなアテンションでの比較 + +同様に、300億パラメータのSparseなアテンションを用いたモデルに、DeepSpeed-Ulyssesを適用し、Megatron-LMのシーケンス並列との比較を行いました。評価の結果を図4に示します。Sparseなアテンションに関しても、Denseなアテンションと同様の傾向が見られます。Megatron-LMに比べて、DeepSpeed-Ulyssesのスループット性能が2倍以上であることを確認しています。ZeRO-3を用いたメモリ使用量の削減によって、Megatron-LMよりも4倍長いシーケンス長を処理できています。 + +DeepSpeed-Ulyssesは、DeepSpeed-UlyssesとMegatron-LMの両方で実行できるシーケンス長において、Megatron-LMを上回っています。実際、現在のDeepSpeed-Ulyssesのスループットは、各GPU上でローカルに計算されるSparseなアテンションがボトルネックとなっており、その結果、シーケンス長が増加するにつれてスループットが減少します。将来、ローカルのSparseなアテンションの実装のパフォーマンスを向上させることで、DeepSpeed-UlyssesとMegatronの間の性能のギャップが、より大きなシーケンス長に対してさらに広がると予想しています。 + +
+ + +*図4: 300億パラメータ・SparseなアテンションでのMegatron-LMとの比較* +
+ +### 収束の検証 + +図5では、8台のA100 GPU上で32Kのシーケンス長を持つ13億パラメータのGPTモデルの収束を示しています。ここでは、DeepSpeed-UlyssesとMegatron-LMのシーケンス並列の両方で、シーケンス並列の並列度を4に設定しています。DeepSpeed-Ulyssesは、ZeROと併用可能なため、それぞれのZeROステージでの収束を評価しました。DeepSpeed-Ulyssesは、長いシーケンスのTransformerモデルの訓練を可能にするための、システム面での最適化技術であり、したがって訓練されたモデルの品質に対する(ネガティブな)影響はありません。このことは、図5の結果から確認できます。 + +
+ + +*図5: 異なるZeROのステージでの収束* +
+ +## DeepSpeed-Ulyssesの利用 + +DeepSpeed-Ulyssesは、数行の簡単なコードの変更だけで、既存のコードに簡単に統合することができます。以下は利用の例です: + +``` python +from deepspeed.sequence.layer import DistributedAttention + +# Replace the original self-attention (attn) with DeepSpeed-Ulysses’s self-attention +dist_attn = DistributedAttention(attn, get_sequence_parallel_group()) +``` + +シーケンス並列処理をサポートする他のライブラリ、例えばMegatron-LMと比較して、DeepSpeed-Ulyssesはモデルのリファクタリングを必要としません。 +また、DeepSpeed-UlyssesはMegatron-DeepSpeedコードリポジトリと統合され、テストされています。 +大規模な言語モデルの訓練のためにこのリポジトリをすでに使用している場合、巨大なシーケンス長のモデルを訓練するために、DeepSpeed-Ulyssesをすぐに活用できます。 + +## 早速試してみましょう! + +DeepSpeed-Ulyssesは、DeepSpeedのGitHubを通じてアクセス可能です。使用方法に関する詳しいチュートリアルは、[DeepSpeedのチュートリアルページ +](https://www.deepspeed.ai/tutorials/ds-sequence/)にあります。 + +長いコンテキストを扱う際の制約を取り除くことによって何が可能になるのか、ユーザの皆様と共に様々な可能性を探求するため、幅広い協力やコラボレーションを歓迎します。DeepSpeed-Ulyssesは、大規模なAIの訓練と推論のためのより大きなDeepSpeedエコシステムの一部です。DeepSpeedの多くの技術や革新的な機能の詳細については、[ウェブサイト](https://www.deepspeed.ai/)をご覧いただくか、X(以前のTwitter。[英語版](https://twitter.com/DeepSpeedAI)、[日本語版](https://twitter.com/DeepSpeedAI_JP))や、中国の[Zhihu](https://www.zhihu.com/people/deepspeed)でフォローしてください。 + +DeepSpeedは、皆様の開発への参加を歓迎しています。DeepSpeedのGitHubページで、バグ報告、Pull Request、ディスカッションへの参加が可能です。詳細は[ガイドライン](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)をご覧ください。また、大学、研究所、企業とのコラボレーションも行っています。こうしたコラボレーションについてのご要望(およびGitHubには適さないその他の話題)については まで直接メールをお送りください。 diff --git a/blogs/deepspeed-ulysses/media/convg.png b/blogs/deepspeed-ulysses/media/convg.png new file mode 100644 index 000000000000..b9586dc404e4 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/convg.png differ diff --git a/blogs/deepspeed-ulysses/media/convgZ.png b/blogs/deepspeed-ulysses/media/convgZ.png new file mode 100644 index 000000000000..324f47cd61bd Binary files /dev/null and b/blogs/deepspeed-ulysses/media/convgZ.png differ diff --git a/blogs/deepspeed-ulysses/media/dense1B1Mscale.png b/blogs/deepspeed-ulysses/media/dense1B1Mscale.png new file mode 100644 index 000000000000..eb886f879247 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/dense1B1Mscale.png differ diff --git a/blogs/deepspeed-ulysses/media/dense30B.png b/blogs/deepspeed-ulysses/media/dense30B.png new file mode 100644 index 000000000000..d2eef04b73cc Binary files /dev/null and b/blogs/deepspeed-ulysses/media/dense30B.png differ diff --git a/blogs/deepspeed-ulysses/media/dense7B.png b/blogs/deepspeed-ulysses/media/dense7B.png new file mode 100644 index 000000000000..042269276a6b Binary files /dev/null and b/blogs/deepspeed-ulysses/media/dense7B.png differ diff --git a/blogs/deepspeed-ulysses/media/fig2Ulysses.png b/blogs/deepspeed-ulysses/media/fig2Ulysses.png new file mode 100644 index 000000000000..39e8a8420bde Binary files /dev/null and b/blogs/deepspeed-ulysses/media/fig2Ulysses.png differ diff --git a/blogs/deepspeed-ulysses/media/fig3Ulysses.png b/blogs/deepspeed-ulysses/media/fig3Ulysses.png new file mode 100644 index 000000000000..fa1498096284 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/fig3Ulysses.png differ diff --git a/blogs/deepspeed-ulysses/media/fig4Ulysses.png b/blogs/deepspeed-ulysses/media/fig4Ulysses.png new file mode 100644 index 000000000000..f55838b36e78 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/fig4Ulysses.png differ diff --git a/blogs/deepspeed-ulysses/media/hero1.png b/blogs/deepspeed-ulysses/media/hero1.png new file mode 100644 index 000000000000..f0034ffdf8b9 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/hero1.png differ diff --git a/blogs/deepspeed-ulysses/media/hero2.png b/blogs/deepspeed-ulysses/media/hero2.png new file mode 100644 index 000000000000..323d3d419002 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/hero2.png differ diff --git a/blogs/deepspeed-ulysses/media/image3.png b/blogs/deepspeed-ulysses/media/image3.png new file mode 100644 index 000000000000..18be3c843370 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/image3.png differ diff --git a/blogs/deepspeed-ulysses/media/sparse30B.png b/blogs/deepspeed-ulysses/media/sparse30B.png new file mode 100644 index 000000000000..2637d353d0c6 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/sparse30B.png differ diff --git a/blogs/deepspeed-ulysses/media/sparse7B.png b/blogs/deepspeed-ulysses/media/sparse7B.png new file mode 100644 index 000000000000..2d9c9ad69420 Binary files /dev/null and b/blogs/deepspeed-ulysses/media/sparse7B.png differ diff --git a/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md b/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md new file mode 100644 index 000000000000..4a2c0c0948e7 --- /dev/null +++ b/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md @@ -0,0 +1,181 @@ + +
+ +# DeepSpeed-VisualChat:多轮图像+文字,为你展现不一样的AI聊天魅力 + +
+ +
+ +DeepSpeed-VisualChat! + +
+ +要引用 DeepSpeed-VisualChat,请引用我们的 [arxiv 报告](https://arxiv.org/abs/2309.14327): + + +``` +@article{yao2023deepspeed-visualchat, + title={{DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention}}, + author={Zhewei Yao and Xiaoxia Wu and Conglong Li and Minjia Zhang and Heyang Qin and Olatunji Ruwase and Ammar Ahmad Awan and Samyam Rajbhandari and Yuxiong He}, + journal={arXiv preprint arXiv:2309.14327}, + year={2023} +} +``` + +# 1. 概述 +大型语言模型 (LLMs),如 GPT 和 LLaMa,在各种文本生成和理解任务中都展现出了卓越的能力,特别是在经过零次/少次学习(zero-/few-shot learning)或微调(instructed fine-tuning)后。然而,要让 AI 模型为多样化的任务做好准备,需要加入的一个关键特性是多模态能力;例如,AI 模型应该能够读取图像、听到声音、观看视频等。这种能力在纯文本基础的 LLMs 中基本上是不存在的。 + +最近,大量的研究项目开始探索将视觉能力引入到 LLMs 中,特别是通过插入图片输入使 LLMs 来理解图片(简称为大型视觉语言模型或 LVLMs)。 + +大多数现有工作的主要缺点是: +* 重点主要放在与单一图像相关的任务上,如视觉问题回答和字幕,或处理需要同时输入的多个图像。两种方法都不太擅长管理交错的图像和文本输入。 +* 系统的可扩展性仅限于具有约 10B 参数的模型,这比最大的开源模型小了一个数量级。 + +然而,对于一个真正的 AI 聊天模型,输入内容可能是与文本交错的多个图像,这是目前的工作很少涉及的情况。此外,随着模型大小的增加,LLMs 的生成能力增长迅速。因此,将系统能力集中在约 10B 的模型上限制了对 LVLMs 潜力的进一步探索。 + +为了解决这些问题,我们推出了 DeepSpeed-VisualChat(请参阅 [arxiv 报告](https://arxiv.org/abs/2309.14327) 以获取更多详细信息),带有以下新特性: + +* ***全开源多轮多图框架与前所未有的可扩展性***:DeepSpeed-VisualChat,作为开创性的全开源框架之一,支持多轮和多图对话,容纳交错的文本和图像输入。我们利用 DeepSpeed 提高我们的训练效果,使用一个 2B 的视觉编码器和一个 70B 的 LLaMA-2 解码器模型,展示了我们框架的显著可扩展性。 +* ***多模态因果注意力 (MMCA)*** 我们为多模态模型设计了一个新的 MMCA 注意力机制,独立地计算各种模态的注意力权重。MMCA 达到了与传统交叉注意机制类似的目标,但为生成任务提供了增强的因果注意解释,消除了对额外模块或参数的需求。与标准的因果注意力相比,它还提供了更好的训练数据效率。 +* ***交错输入的数据混合*** 为了促进交错模态的对话,DeepSpeed-VisualChat 在现有数据集上采用了各种数据混合技术,克服了大多数现有开源数据集中交错文本和图像输入的短缺。 + +# 2. 模型架构概述 +
+ 模型结构 + + *图 1:DeepSpeed-VisualChat 的模型架构示意图。* +
+ +如 *图 1* 所示,DeepSpeed-VisualChat 的模型架构由三个部分组成:一个视觉编码器,如 CLIP;一个语言解码器,如 LLaMa-7B;和一个特征对齐线性投影层。模型的大部分都是冻结的,只有语言模型的嵌入和线性投影层是可训练的。因此,可训练参数的总数大约在 O(10M) (LLaMa-2-13B) 到 O(100M) (LLaMa-2-70B) 之间。 + +# 3. DeepSpeed 多模态因果注意力 + +用于在多模态模型中连接视觉和文本组件的两种常见注意机制是:因果注意力,如在 MiniGPT 和 QWen-VL 中使用的,以及交叉注意力,如在 Otter 和 Flamingo 中使用的。 + +
+ 不同的注意机制 + + *图 2:不同的注意机制:使用一个输入句子“用户:请描述这个图片。”和三个图像令牌(I-token1、I-token2、I-token3)来比较不同的注意机制。在左边,我们展示了标准的因果注意力,将图像令牌视为文本。在中间,我们展示了应用于图像的交叉注意力,同时保持文本令牌的标准因果注意力。在右边,我们展示了我们的创新 MMCA 注意力机制,其中图像令牌只执行自注意,文本令牌独立地注意文本/图像令牌,橙色为图像部分。这种机制由:softmax($`QK^T \odot M_1`$)+ softmax($`QK^T \odot M_2`$) 定义,其中 Q 和 K 分别为查询和密钥,$`M_1`$=[M==1],和 $`M_2`$=[M==2],其中 M $`\in`$ R10x10。* +
+ +因果注意力 (CA):基于 CA 的方法简单地将视觉特征(即来自最终视觉编码器层输出的特征)投影到文本特征,并将它们与文本嵌入层后的正常文本特征组合,以送入 LLMs。CA 的好处是它是 LLMs 原始注意机制的自然扩展,因此,它不引入任何额外的模块或参数。但是,直觉上这种方法会带来一些问题: + +* 每个视觉令牌会关注它之前的视觉和文本令牌。然而视觉令牌已经以双向方式完全编码,不需要进一步关注它之前的视觉和文本令牌。 +* 对于一个文本令牌,模型需要学习如何在其之前的文本和图像令牌之间分配其注意权重。由于这些问题,我们发现 LVLMs 中 CA 的数据效率通常是有问题的。为了解决这个问题,LLaVA 和 QWen-VL 需要视觉-语言预训练来完全对齐视觉特征和文本特征。 + +交叉注意力 (CrA):作为替代方案,交叉注意力 (CrA) 与 CA 的结合展示出更好的数据效率,但也带有一些缺点: + +* 它为模型引入了新的参数。例如,具有交叉注意力引入的新参数的 Otter 拥有超过 15 亿的可训练参数。和 LLaVA 的百万级可训练参数相比,这大大增加了训练成本和内存需求。 +* 如果在训练过程中中间引入了一个图像,需要仔细设计,因为先前的文本令牌不应该能够注意到图像。 + +多模态因果注意机制 (MMCA):为了克服这些问题,我们提出了一种新的多模态因果注意机制 (MMCA),它既有 CA 的参数效率,又有 CrA 的数据效率。总体思路如下: + +* 对于视觉令牌,它们只关注自己,因为视觉令牌是由视觉编码器编码的。 +* 对于文本令牌,它们关注所有以前的令牌。但是,对文本和图像令牌 MMCA 使用两个单独的注意权重矩阵。 + +MMCA 的第二点背后的直觉是,一个模态的注意权重可能会影响另一个模态。例如,文本令牌可能会比视觉信息更多地关注文本信息。因此,如果注意权重矩阵在两种模态之间进行归一化,那么视觉令牌的注意得分可能会非常小。请参考 *图 2* 以查看三种注意机制的可视化。 + +演示结果。我们首先通过几个例子展示在不同的注意机制下 DeepSpeed-VisualChat 的单图像视觉语言对话功能。在这些实验中,我们使用 LLaMA2-7B 语言模型和 QWen-VL 视觉编码器作为我们的视觉编码器。这两个模型通过一个简单的线性投影层连接在一起。这个模型在两个 LLaVa 数据集上进行了训练。正如 *图 3* 和 *图 4* 所示,当与 MMCA 配合使用时,DeepSpeed-VisualChat 有效地识别了图像中的视觉细节,对用户的问题提供了准确通顺的回答。 +此外,与其他注意机制(如使用因果注意力和交叉注意力的组合)相比,MMCA 表现出更全面和精确的图像细节把握。与 CrA 和 CA 的组合以及 MMCA 相比,仅使用 CA 可能会显示出稍微多一些的错误(*图 3*)或导致较低的理解能力(*图 4*)。 + +
+ 小猫咪 + + *图 3:示例视觉和语言输入,显示了(1)标准因果注意力 (CA) (2)与交叉注意力组合的标准因果注意力 (CA+ CrA) 和(3)DeepSpeed-VisualChat 中的特殊多模态因果注意力 (MMCA) 之间的输出比较。* +
+ +
+ 美丽的湖泊 + + *图 4:DeepSpeed-VisualChat 准确地识别了场景是一个美丽的湖泊,并提供了一组合理的建议。相比之下,其他的注意力机制误解了图像认为其包含“带船坡的码头”。* +
+ +# 4. 数据混合 +我们使用了 3 个来源的 9 个数据集,如我们的 [arxiv 报告](https://arxiv.org/abs/2309.14327) 所述。一个实现多轮和多图对话的关键缺失元素是没有足够的数据。我们找到的唯一的多轮多图数据来源是 SparklesDialogue 数据集,它只包含 6520 个样本。为了解决这个问题,我们采用了两种方法,从现有的单图或单轮数据中合成多轮多图数据:简单的数据连接和 LLaVA-Otter 数据混合。 + +## 4.1 简单数据连接 +对于 LLaVA 模型使用的 "llava" 和 "llava_dial" 数据集,每个样本包括单图像的单轮/多轮对话。为了模拟用户依次询问多个图像的情况,我们对这两个数据集进行了简单的数据后处理。具体来说,我们随机将不同数量的样本连接成一个样本。在 "llava" 的情况下,我们连接了 1 到 3 个样本,而在 "llava_dial" 的情况下,我们连接了 1 到 2 个样本。 + +## 4.2 LLaVA-Otter 数据混合 +我们注意到,LLaVA 模型使用的 llava 和 llava_dial 数据集以及 Otter 模型使用的 otter_mimicit_cgd 数据集都使用了 COCO train2017 图像。对于 llava 和 llava_dial 数据集,每个样本包括一个图像的单轮/多轮对话。对于 otter_mimicit_cgd 数据集,每个样本包括一对图像的单轮对话。这使我们能够构建一个合成的多轮多图数据 llava_otter_blend 作为更自然的混合:对于 otter_mimicit_cgd 数据集中的每个样本,我们寻找使用相同图像的 llava 和 llava_dial 样本,然后以 "llava/llava_dial 对话然后 otter_mimicit_cgd 对话" 的方式构建一个新样本。 + +
+ 朋友们 + + *图 5:经过 LLaVA-Otter 数据混合后的数据样本。灰色对话框来自 LLaVA 数据集,橙色对话框来自 Otter 数据集。* +
+ +# 5. 演示 +我们在几个开源数据集上训练了我们的 DeepSpeed-VisualChat-13B 模型,该模型使用一个 2B 的视觉编码器和 13B 的 LLaMA 模型。DeepSpeed-VisualChat-13B 展示了图像字幕功能(*图 6--8*),计数和文本阅读(*图 6*),名人识别(*图 7*),讲故事(*图 8*)等。 + +
+ 朋友们 + + *图 6:DeepSpeed-VisualChat 可以计算图像中的人数,并读取第一张图像中的文本。它还展示了跨图像的理解。* +
+ +
+ CEO + + *图 7:DeepSpeed-VisualChat 可以识别名人并将他们与其成就联系起来。* +
+ +
+ 疯狂动物城 + + *图 8:DeepSpeed-VisualChat 可以讲故事并识别电影。* +
+ +# 6. 如何开始使用 DeepSpeed-VisualChat +DeepSpeed-VisualChat 是一个易于使用的训练框架,具有很好的可扩展性,到目前为止已经在 LLaMa-2-70B 模型上进行了测试。我们为所有实验采用了统一的指令调优格式,模板如下所示。 +``` + % You are a powerful vision-language assistant. + +### Image 1: % some image, e.g., cat-1.png +### Question: % please describe the image. +### Answer: % It's a cute black cat. + +### Image 2: % some image, e.g., cat-2.png +### Image 3: % some image, e.g., cat-3.png +### Question: % What's the difference between the three cats? +### Answer: % The colors of the three cats are different. +... +``` + +使用 DeepSpeed-VisualChat 训练模型是简单和方便的。这里我们给出了基于 CLIP 视觉编码器和 LLaMa-7B 模型的一个例子: + +``` +git clone https://github.com/deepspeedai/DeepSpeedExamples.git +cd DeepSpeedExamples/applications/DeepSpeed-VisualChat/ +pip install -r requirements.txt +cd training +bash training_scripts/run_7b.sh +``` + +训练后的模型权重将自动保存为 Hugging Face 兼容版本,并且可以用于启动您自己的视觉聊天 API: +``` +cd ../chat +bash chat_scripts/run.sh # You need to change necessary variables, e.g, ckpt path +``` + +为了支持更大的模型推理,我们已经将 Hugging Face 大模型推理集成到我们的 DeepSpeed-VisualChat API 中。因此,用户可以根据 GPU 内存容量和模型大小选择不同数量的 GPU。 + +请参考我们的 [GitHub 主页](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat) 了解更多细节。 + +# 7. 发布:今天尝试 DeepSpeed-VisualChat! + +我们非常兴奋地分享 DeepSpeed-VisualChat 现已开源并供 AI 社区使用。 + +* 要开始使用,请访问我们的 DeepSpeed-VisualChat GitHub 页面:[GitHub 主页](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat) + +* 我们将继续在您的反馈和支持下改进 DeepSpeed-VisualChat。我们的 [路线图](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat/README.md#-deepspeed-visualchats-roadmap-) 显示了目前支持的功能以及未来计划支持的功能。 + +DeepSpeed-VisualChat 是更大的 DeepSpeed 生态系统的一部分,其中包括一系列深度学习系统和建模技术。要了解更多信息, + +* 请访问我们的 [网站](https://www.deepspeed.ai/) 了解详细的博客文章、教程和文档。 +* 在我们的 [英文 X(Twitter)](https://twitter.com/DeepSpeedAI)、[日语 X(Twitter)](https://twitter.com/DeepSpeedAI_JP) 和 [中文知乎](https://www.zhihu.com/people/deepspeed) 上关注我们,以获取 DeepSpeed 的最新消息。 + +我们欢迎您为 DeepSpeed 做出贡献!我们鼓励您报告问题、贡献 PRs、并在 [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) 页面上参加讨论。有关更多详细信息,请查看我们的 [贡献指南](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)。我们对与大学、研究实验室、公司等进行合作持开放态度,例如共同进行深度学习研究、应用 DeepSpeed 为现实世界的 AI 模型和应用提供支持等等。对于此类请求(以及其他不适合 GitHub 的请求),请直接发送电子邮件至 info@deepspeed.ai。 + +* 如果你喜欢我们的工作,请在 [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) 和 [DeepSpeedExamples GitHub](https://github.com/deepspeedai/DeepSpeedExamples/) 上为我们的仓库点“星”。 diff --git a/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md b/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md new file mode 100755 index 000000000000..a7c70c1ef158 --- /dev/null +++ b/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md @@ -0,0 +1,188 @@ +
+ +# DeepSpeed-VisualChat: 複数ラウンド・複数画像の入力が可能なAIチャット体験を実現 +
+ +
+ +DeepSpeed-VisualChat! + +
+ +DeepSpeed-VisualChatを引用する場合、[arxiv上のレポート](https://arxiv.org/abs/2309.14327)を引用してください。 + +``` +@article{yao2023deepspeed-visualchat, + title={{DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention}}, + author={Zhewei Yao and Xiaoxia Wu and Conglong Li and Minjia Zhang and Heyang Qin and Olatunji Ruwase and Ammar Ahmad Awan and Samyam Rajbhandari and Yuxiong He}, + journal={arXiv preprint arXiv:2309.14327}, + year={2023} +} +``` + +# 1. 概要 +GPTやLLaMaのような大規模言語モデル(LLM)は、テキスト生成やテキスト理解などの多くのタスクにおいて、Zero-shot/Few-shot学習、あるいはinstructed fine-tuningによって、非常に優れた能力を示してきました。しかし、AIエージェントをより多様なタスクに対応させるには、マルチモーダルを扱う能力が必要です。例えば、AIエージェントは画像を読んだり、音声を聞いたり、ビデオを見たりすることができる必要があります。こうした機能は、テキストベースのLLMにはほとんどありません。 + +近年、LLMに視覚的な能力を導入することは、研究・実践の両方において広く試みられています。特に、画像をそのまま与えて、LLMが理解できるようにする取り組みが行われています(大規模視覚言語モデル、略してLVLMなどと呼ばれる)。 + +こうした分野における、既存の研究の主な問題は以下の通りです: + +* 視覚に関する質問への回答やキャプション付けのように、単一の画像に関連するタスクや、同時に入力される複数の画像の処理に重点が置かれており、画像とテキストが交互に入力されるような状況には対応していない +* システムのスケーラビリティは、~10Bのパラメータを持つモデルに限定される + +しかし、本来はAIチャットエージェントには、複数のテキストと画像の両方が与えられる可能性があります。また、LLMの生成能力は、モデルサイズが大きくなるにつれて急速に向上することが知られており、~10Bのモデルではその能力が制限されてしまいます。 + +これらの問題を解決するために、我々は以下の新たな機能を備えたDeepSpeed-VisualChat(詳細は[arxivのレポート](https://arxiv.org/abs/2309.14327)を参照)を開発しました: + +* ***完全にオープンソース化され、前例のないスケーラビリティを備えた複数ラウンド・複数画像を処理できるフレームワーク***: DeepSpeed-VisualChatは、完全にオープンソース化された先進的なフレームワークの1つであり、複数ラウンドを通じて画像とテキストが両方与えられる対話を可能にします。また、DeepSpeedを利用することで、比類ないスケーラビリティを実現しており、実際に2Bのビジュアルエンコーダーと70BのLLaMA-2デコーダーモデルで訓練を行えます。 +* ***マルチモーダル因果的注意(MMCA)***: マルチモーダルモデルのための新しいアテンションMMCA(Multi-Modal Causal Attention)を考案し、異なるモダリティ間で独立にアテンションの重みを計算します。MMCAは、従来のcross attentionに類似したものですが、生成タスクのためのcausal attentionを強化しており、追加のモジュールやパラメータが不要になります。また、標準的なcausal attentionと比較して、優れた訓練データ効率を示します。 +* ***順次与えられる画像とテキストを扱うためのデータブレンディング***: DeepSpeed-VisualChatは、既存のデータセットに様々なデータブレンディング技術を採用しています。これにより、順次与えられるテキストと画像の不足という、利用可能なオープンソースデータセットのほとんどに当てはまる課題を克服しています。 + +# 2 モデルアーキテクチャの概要 +
+ model arch + + *図1: モデルアーキテクチャの概要* + +
+ +*図1*に示すように、DeepSpeed-VisualChatのモデルアーキテクチャは、CLIPのような視覚エンコーダー、LLaMa-7Bのような言語デコーダー、特徴アライメントを行う linear projectionレイヤの3つのコンポーネントで構成されています。モデルのほとんどのパラメータは固定されており、言語モデルのembeddingとlinear projectionレイヤのみが学習可能です。その結果、学習可能なパラメータの総数は O(10M) (LLaMa-2-13B) から O(100M) (LLaMa-2-70B) となります。 + +# 3. DeepSpeed マルチモーダル Causal Attention (MMCA) + +マルチモーダルモデルで、画像とテキストをつなぐ一般的なattentionの機構は二つあります。一つはMiniGPTやQWen-VLで使われているようなcausal attentionで、もう一つはOtterやFlamingoで使われているようなcross attentionです。 + + +
+ Different attention mechanisms + + *図2: 異なるアテンションの機構: 「ユーザー:画像を説明してください」という入力文と3つの画像トークン(I-token1、I-token2、I-token3)と組み合わせて与えた場合の、それぞれのattention機構の構成を示しています。左側では、標準的なcausal attentionによって、画像トークンをテキストとして扱う様子を示しています。中央は、テキストトークンに対する標準的なcausal attentionを維持しながら、画像に適用されるcross attentionを使用する様子を示しています。右側では、画像トークンはself attentionのみを行い、テキストトークンはテキスト/画像トークンへのアテンションを独立に計算するという、新しいマルチモーダルのためのアテンションの提案を、オレンジ色のマスクで強調して示しています。この仕組みは、Q, Kをクエリとキーとしたとき、 softmax($`QK^T \odot M_1`$)+ softmax($`QK^T \odot M_2`$)として定義されます。M $`\in`$ R10x10としたとき、$`M_1`$=[M==1], and $`M_2`$=[M==2] です。* +
+ +Causal Attention(CA):CAに基づく方法は、視覚的特徴(最終的な視覚エンコーダ層の出力からの特徴)を単純にテキストの特徴量に投影し、テキスト埋め込み層以降の通常のテキストの特徴量と組み合わせてLLMに送り込むというものです。CAの利点は、LLMにおける本来のアテンション機構の自然な拡張であり、そのため余分なモジュールやパラメータを導入しないことです。しかし、このアプローチにはいくつかの直感的な問題があります: + +* 視覚トークンはすでに双方向に特徴量に変換されており、本来他の視覚トークンやテキストトークンとのアテンションの必要はありませんが、実際には前の視覚またはテキストトークンとのアテンションがあります。。 +* テキストトークンの場合、モデルは前のテキストトークンと画像トークンとの間でどのようにアテンションの重みを配分するかを学習する必要があります。これらの問題により、LVLMにおけるCAのデータ効率にはしばしば問題があることが分かりました。この問題への対処として、LLaVAとQWen-VLは、視覚的特徴とテキストの特徴を完全に対応させるために、視覚言語の事前学習を必要とします。 + +Cross Attention (CrA):代替案であるCross Attention (CrA) と CAの組み合わせは、より優れたデータ効率を示しますが、いくつかの欠点もあります: + +* モデルに新しいパラメーターを導入する必要があります。例えば、Otterは、Cross Attentionによって導入された新しいパラメータがあるため、LLaVAが数百万個の学習可能なパラメータを持つのに対し、15億個以上のパラメータを必要とします。これにより、学習コストと必要メモリ量が大幅に増加します。 +* 訓練中に会話の途中で画像が与えられた場合、前のテキストトークンは与えられた画像とのアテンションを求められないので、慎重な設計が必要です。 + +マルチモーダル Causal Attention (MMCA):これらの問題を解決するために、我々は新しいマルチモーダルCausal Attention (MMCA) を提案します。この機構は、CAと同様のパラメータ効率と、CrAと同様のデータ効率の、両方の利点を持つものです。全体的なアイデアは以下の通りです: + +* 視覚トークンは視覚エンコーダによってエンコードされるため、視覚トークンは自分自身とのアテンションのみを利用する。 +* テキストトークンについては、その前のすべてのトークンに注目する。ただし、前のテキストトークンと画像トークンに対して、それぞれ別々のアテンションの重み行列を持つ。 + +MMCAの2つ目のポイントは、1つのモダリティに対するアテンションの重みが、もう1つのモダリティに影響を与える可能性があるということです。例えば、テキストトークンは、視覚情報よりもテキスト情報により大きなアテンションを持つかもしれません。そのため、アテンションの重み行列を両方のモダリティで正規化すると、視覚トークンのアテンションスコアが非常に小さくなる可能性があります。3つのアテンション機構の視覚化については、*図2*を参照してください。 + +出力例 まず、異なるアテンションの機構を採用した、画像を一つだけ用いた会話におけるDeepSpeed-VisualChatの能力を示す様々な例を紹介します。これらの実験では、LLaMA2-7B言語モデルとQWen-VL視覚エンコーダを視覚エンコーダとして併用します。これら2つのモデルはlinear projection layerを介して接続されています。このモデルは2つのLLaVaデータセットで学習を行いました。*図3*と*図4*で実証されているように、DeepSpeed-VisualChatはMMCAと組み合わされることで、画像内の視覚的な詳細を効果的に識別し、ユーザーのクエリに対して首尾一貫した応答を提供します。さらに、MMCAは、Causal AttentionとCross Attentionの両方から合成されたマスクを使用するような、別のアテンション機構と比べて、より包括的で正確な画像詳細の把握が可能です。また、CrAとCAの組み合わせやMMCAとは対照的に、CA単独では若干エラーが多く(*図3*)、推論能力の程度が低い(*図4*)可能性があることも明らかです。 + +
+ Small kitten + + *図3: (1) 標準的なcausal attention (CA) (2) cross attentionと組み合わせた標準的なcausal attention (CA+CrA) (3)DeepSpeed-VisualChatの特別なマルチモーダルCausal Attention (MMCA) の出力比較を示す視覚入力と言語入力の例。* +
+ +
+ Beautiful lake + + *図4:DeepSpeed-VisualChatは、示された場面を美しい湖として正確に識別し、妥当な提案のセットを提示する。対照的に、ベースラインは画像を「ボート乗り場のあるドック」と誤認識している。* +
+ +# 4. データブレンディング + +[arxivのレポート](https://arxiv.org/abs/2309.14327)に記載されているように、訓練には3つのソースから9つのデータセットを使用しました。複数ラウンド・複数画像の入力を可能にするために決定的に欠けている要素は、適切なデータがないことです。我々が見つけた複数ラウンド・複数画像の唯一のデータソースはSparklesDialogueデータセットで、そこにはわずか6520サンプルしか含まれていません。この制限に対処するため、既存の単一画像または単一ラウンドのデータから、複数ラウンド・複数画像のデータを合成するために、単純なデータ連結とLLaVA-Otterデータ混合という2つの方法を採用しました。 + +## 4.1 単純なデータ連結 +LLaVAモデルで利用する "llava" と "llava_dial" データセットでは、各サンプルは1つの画像に対する単一/複数ラウンドの会話で構成されています。ユーザーが複数の画像について逐次質問するシナリオをシミュレートするため、これら2つのデータセットに対して、簡単なデータ後処理を行いました。具体的には、ランダムな数のサンプルを1つのサンプルとして連結しました。 "llava" の場合は1~3個のサンプルを連結し、"llava_dial" の場合は1~2個のサンプルを連結しました。 + +## 4.2 LLaVAとOtterのデータブレンディング + +LLaVAモデルで使用されているllavaとllava_dialデータセット、およびOtterモデルで使用されているotter_mimicit_cgdデータセットは、すべてCOCO train2017画像を使用しています。llavaデータセットとllava_dialデータセットには、各サンプルに1つの画像に対する単発/複数回の会話が含まれます。otter_mimicit_cgdデータセットでは、各サンプルは画像のペアに対する1ラウンドの会話を含んでいます。そこで、otter_mimicit_cgdデータセットの各サンプルについて、同じ画像を使うllavaとllava_dialのサンプルを探し、「llava/llava_dial会話 -> otter_mimicit_cgd会話」という流れで新しいサンプルを構築しました。 + +
+ Friends + + *図5: LLaVA-Otterデータブレンド後のデータサンプル。灰色のダイアログボックスはLLaVAデータセットから、オレンジ色のダイアログボックスはOtterデータセットからのもの* +
+ +# 5. デモ +いくつかのオープンソースデータセットで2Bビジュアルエンコーダーと13B LLaMAモデルを使い、DeepSpeed-VisualChat-13Bモデルを訓練しました。DeepSpeed-VisualChat-13Bは、画像キャプション機能(*図6-8*)、計数とテキスト読み取り(*図6*)、著名人の認識(*図7*)、ストーリーテリング(*図8*)などを示しています。 + +
+ Friends + + *図6: DeepSpeed-VisualChatは、画像内の人数を数え、最初の画像のテキストを読み取ることができます。また、複数画像を横断的に理解することも可能です。* +
+ + +
+ CEO + + *図7: DeepSpeed-VisualChatは有名人を認識し、その人物の業績と関連付けることができます* +
+ + +
+ Zootopia + + *図8: DeepSpeed-VisualChatは、ストーリーを作ったり、映画を認識したりできます。* +
+ + +# 6. DeepSpeed-VisualChatを使い始めるには +DeepSpeed-VisualChatは使いやすく、かつ優れたスケーラビリティを持つ学習フレームワークで、これまでLLaMa-2-70Bモデルでテストされています。 +すべての実験で統一された命令チューニング形式を採用しており、そのテンプレートを以下に示します。 + +``` + % You are a powerful vision-language assistant. + +### Image 1: % some image, e.g., cat-1.png +### Question: % please describe the image. +### Answer: % It's a cute black cat. + +### Image 2: % some image, e.g., cat-2.png +### Image 3: % some image, e.g., cat-3.png +### Question: % What's the difference between the three cats? +### Answer: % The colors of the three cats are different. +... +``` + +DeepSpeed-VisualChatの訓練は簡単かつ便利に実行できます。ここではCLIPビジュアルエンコーダーとLLaMa-7Bモデルを使用する例を示します: + +``` +git clone https://github.com/deepspeedai/DeepSpeedExamples.git +cd DeepSpeedExamples/applications/DeepSpeed-VisualChat/ +pip install -r requirements.txt +cd training +bash training_scripts/run_7b.sh +``` + +訓練されたチェックポイントは自動的にHugging Faceと互換性のある形式で保存され、独自のビジュアルチャットAPIを提供するために使用できます: + +``` +cd ../chat +bash chat_scripts/run.sh # You need to change necessary variables, e.g, ckpt path +``` + +より大規模なモデル推論をサポートするために、我々はHugging Faceの大規模モデル推論をDeepSpeed-VisualChat APIに組み込みました。そのため、ユーザーはGPUメモリ容量とモデルサイズに基づいて、異なるGPU数を選択することができます。 + +詳細は[ランディングページ](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat)をご参照ください。 + +# 7. 早速使ってみましょう! + +DeepSpeed-VisualChatがオープンソース化され、AIコミュニティで利用できるようになったことを大変嬉しく思います。 + +* まずは、DeepSpeed-VisualChatのGitHubページをご覧ください: [GitHubランディングページ](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat) + +* DeepSpeed-VisualChatは、皆様からのフィードバックとサポートにより改良を続けていきます。私たちの[ロードマップ](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat/README.md#-deepspeed-visualchats-roadmap-)は、現在サポートされている機能と将来的に計画している機能を示しています。 + +DeepSpeed-VisualChatは、さまざまなDeep Learningシステムやモデリング技術を含む、より大きなDeepSpeedエコシステムの一部です。詳細については、以下をご覧ください。 + +* 私たちの[ウェブサイト](https://www.deepspeed.ai/)で、詳細なブログ記事、チュートリアル、役立つドキュメントを提供しています。 +* DeepSpeedの最新ニュースは、[English X(Twitter)](https://twitter.com/DeepSpeedAI)、[Japanese X(Twitter)](https://twitter.com/DeepSpeedAI_JP)、[Chinese Zhihu](https://www.zhihu.com/people/deepspeed)をフォローしてください。 + +DeepSpeedは、皆様の開発への参加を歓迎しています。DeepSpeedのGitHubページで、バグ報告、Pull Request、ディスカッションへの参加が可能です。詳細は[ガイドライン](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)をご覧ください。また、大学、研究所、企業とのコラボレーションも行っています。こうしたコラボレーションについてのご要望(およびGitHubには適さないその他の話題)については まで直接メールをお送りください。 + +* 私たちの[DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/)および[DeepSpeedExamples GitHub](https://github.com/deepspeedai/DeepSpeedExamples/)リポジトリが気に入ったら、ぜひスターをつけてください! diff --git a/blogs/deepspeed-visualchat/10-03-2023/README.md b/blogs/deepspeed-visualchat/10-03-2023/README.md new file mode 100755 index 000000000000..3c9fd25e4acd --- /dev/null +++ b/blogs/deepspeed-visualchat/10-03-2023/README.md @@ -0,0 +1,188 @@ +
+ +# DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs + +
+ +
+ +DeepSpeed-VisualChat! + +
+ +To cite DeepSpeed-VisualChat, please cite our [arxiv report](https://arxiv.org/abs/2309.14327): + +``` +@article{yao2023deepspeed-visualchat, + title={{DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention}}, + author={Zhewei Yao and Xiaoxia Wu and Conglong Li and Minjia Zhang and Heyang Qin and Olatunji Ruwase and Ammar Ahmad Awan and Samyam Rajbhandari and Yuxiong He}, + journal={arXiv preprint arXiv:2309.14327}, + year={2023} +} +``` +# 1. Overview +Large Language models (LLMs), such as GPT and LLaMa, have showcased exceptional prowess in a myriad of text generation and comprehension tasks, especially when subjected to zero-/few-shot learning, particularly after instructed fine-tuning. However, to equip AI agents for diverse tasks, one critical feature that needs to be incorporated is multi-modal capability; for instance, the AI agent should be able to read images, hear voices, watch videos, etc. This capability is largely absent in solely text-based LLMs. + +Recently, one of the research/practice mainstreams has begun exploring the incorporation of visual capability into LLMs, especially enabling LLMs to understand images by inserting raw pictures (referred to as large visual language models, or LVLMs in short). + +The main caveats of the majority of existing works are: +* The focus is predominantly on tasks related to a single image, such as visual question answering and captioning, or on handling multiple images that require concurrent input. Neither approach adeptly manages interleaved image-and-text input. +* The scalability of the system is limited to models with ~10B parameters, which is about an order of magnitude smaller than largest open-sourced models. + +However, for a genuine AI chat agent, the content of inputs could be multiple images interleaved with text, a situation rarely addressed by current works. Also, the generation capability of LLMs grows quickly as the model size increases. Therefore, focusing system capability on ~10B models limits further exploration of the potential of LVLMs. + +To resolve these issues, we are introducing DeepSpeed-VisualChat (see [arxiv report](https://arxiv.org/abs/2309.14327) for more details) with the following new features: + +* ***Fully Open-Sourced Multi-round Multi-image Framework with Unprecedented Scalability***: DeepSpeed-VisualChat, one of the pioneering fully open-sourced frameworks, enables multi-round and multi-image dialogues, accommodating interleaved text-and-image inputs. We leverage DeepSpeed to enhance our training with a 2B visual encoder and a 70B LLaMA-2 decoder model, illustrating the remarkable scalability of our framework. +* ***Multi-Modal Causal Attention (MMCA)*** +We devise a novel MMCA for multi-modal models that computes attention weights independently across various modalities. MMCA achieves objectives analogous to conventional cross-attention mechanisms but offers enhanced causal attention interpretations for generative tasks, eliminating the need for additional modules or parameters. It also presents superior training data efficiency compared to standard causal attention. +* ***Data Blending for Interleaved Inputs*** To facilitate conversations with interleaved modalities, DeepSpeed-VisualChat employs assorted data blending techniques on existing datasets, overcoming the shortage of interleaved text-and-image inputs in most available open-source datasets. + + + +# 2 Model architecture overview +
+ model arch + + *Figure 1: Model architecture illustration.* + +
+ +The model architecture of DeepSpeed-VisualChat, as depicted in *Figure 1*, is composed of three components: a visual encoder, such as CLIP; a language decoder, such as LLaMa-7B; and a feature alignment linear projection layer. Most parts of the model are frozen, with only the embedding of the language model and the linear projection layer being trainable. Consequently, the total number of trainable parameters ranges from approximately O(10M) (LLaMa-2-13B) to O(100M) (LLaMa-2-70B). + +# 3. DeepSpeed multi-modal causal attention + +There are two common attention mechanisms used to connect the visual and textual components in a multi-modal model: causal attention, as used in MiniGPT and QWen-VL, and cross attention, as used in Otter and Flamingo. + +
+ Different attention mechanisms + + *Figure 2: Different Attention Mechanisms: Examine the differing attention mechanisms using an input sentence "User: Please describe the image." coupled with three Image tokens (I-token1, I-token2, I-token3). On the left, we demonstrate standard causal attention, treating image tokens as text. In the middle, we present cross attention applied to images, while maintaining standard causal attention for text tokens. On the right, we illustrate our innovative multi-modal attention proposal where image tokens only perform self-attention, and text tokens attend to text/image tokens independently, highlighted with an orange mask. This mechanism is defined by: softmax($`QK^T \odot M_1`$)+ softmax($`QK^T \odot M_2`$) with Q and K as query and key, $`M_1`$=[M==1], and $`M_2`$=[M==2], with M $`\in`$ R10x10 in this case.* +
+ + +Causal Attention (CA): The CA-based method simply projects visual features (i.e., the features from the output of the final visual encoder layer) into textual features and combines them with the normal textual features after the textual embedding layer to feed into LLMs. The benefit of CA is that it's a natural extension of the original attention mechanism in LLMs, and as such, it doesn't introduce any extra modules or parameters. However, this approach raises some intuitive problems: + +* For a visual token, it attends to previous visual and textual tokens, even though visual tokens are already fully encoded in a bidirectional manner and do not need further attention to other visual tokens or previous textual tokens. +* For a textual token, the model needs to learn how to distribute its attention weights between its previous textual and image tokens. Due to these issues, we found that the data efficiency of CA in LVLMs is often problematic. To address this, LLaVA and QWen-VL require visual-language pretraining to fully align visual features with textual features. + +Cross Attention (CrA): The alternative, cross attention (CrA), along with CA, exhibits better data efficiency but also comes with a few drawbacks: + +* It introduces new parameters to the model. For example, Otter has more than 1.5 billion trained parameters compared to the millions of trained parameters in LLaVA due to the new parameters introduced by cross attention. This significantly increases the training cost and memory requirements. +* It requires careful design if an image is introduced in the middle of a conversation during training, as previous text tokens should not be able to attend to the image. + +Multi-Modal Causal Attention Mechanism (MMCA): To overcome these issues, we propose a new multi-modal causal attention mechanism (MMCA), which has both benefits, i.e., similar parameter efficiency as CA and similar data efficiency as CrA. The overall idea is as follows: + +* For visual tokens, they only attend to themselves, as visual tokens are encoded by the visual encoder. +* For textual tokens, they attend to all their previous tokens. However, they have two separate attention weight matrices for their previous textual tokens and image tokens. + +The intuition behind the second point of MMCA is that the attention weight for one modality may affect the other modality. For instance, a textual token may pay more attention to textual information than visual information. Therefore, if the attention weight matrix is normalized across both modalities, the attention score for visual tokens might be very small. Refer to *Figure 2* for a visualization of the three attention mechanisms. + + +Demo Results. We begin by showcasing various examples that highlight the capabilities of DeepSpeed-VisualChat in single-image visual language conversations, employing different attention mechanisms. In these experiments, we employ the LLaMA2-7B language model in conjunction with the QWen-VL visual-encoder as our visual encoder. These two models are connected via a straightforward linear projection layer. Our model underwent training on two LLaVa datasets. As demonstrated in *Figure 3* and *Figure 4*, DeepSpeed-VisualChat, when coupled with MMCA, effectively discerns visual details in images and furnishes coherent responses to user queries. +Furthermore, DeepSpeed-VisualChat exhibits a more comprehensive and precise grasp of image details compared to alternative attention mechanisms, such as the use of combined masks from both causal attention and cross attention. It is also evident that, in contrast to the combination of CrA and CA, as well as MMCA, CA alone may exhibit slightly more errors (*Figure 3*) and capture a lower degree of reasoning capability (*Figure 4*). + +
+ Small kitten + + *Figure 3: Example visual and language inputs that demonstrate the output comparison between (1) the standard causal attention (CA) (2) the standard causal attention combined with cross-attention (CA+ CrA) and (3) the special multi-modal causal attention (MMCA) in DeepSpeed-VisualChat.* + +
+ +
+ Beautiful lake + + *Figure 4: DeepSpeed-VisualChat accurately identifies the scene as a beautiful lake and offers a set of plausible suggestions. In contrast, the baseline misinterprets the image as containing “dock with a boat ramp”.* + +
+ +# 4. Data blending +We used 9 datasets from 3 sources as described in our [arxiv report](https://arxiv.org/abs/2309.14327). A critical missing element for enabling multi-round and multi-image conversations is the absence of adequate data. The sole source of multi-round multi-image data we located is the SparklesDialogue dataset, which contains a mere 6520 samples. To address this limitation, we employed two methods to synthesize multi-round multi-image data from existing single-image or single-round data: simple data concatenation and LLaVA-Otter data blending. + +## 4.1 Simple data concatenation +For the "llava" and "llava_dial" datasets utilized by the LLaVA model, each sample comprises single/multi-round conversations for a single image. To simulate scenarios where a user sequentially asks questions about multiple images, we conducted straightforward data post-processing for these two datasets. Specifically, we randomly concatenated different numbers of samples into a single sample. In the case of "llava," we concatenated 1 to 3 samples, while for "llava_dial," we concatenated 1 to 2 samples. + +## 4.2 LLaVA-Otter data blending +We noticed that the llava and llava_dial datasets used by LLaVA model and the otter_mimicit_cgd dataset used by the Otter model all use the COCO train2017 images. For the llava and llava_dial datasets, each sample includes a single/multi-round conversations for a single image. For the otter_mimicit_cgd dataset, each sample includes a single-round conversation for a pair of images. This enables us to build a synthesized multi-round multi-image data llava_otter_blend as a more natural blending: for each sample in the otter_mimicit_cgd dataset, we look for llava and llava_dial samples that use the same image, and then build a new sample in a "llava/llava_dial conversations then otter_mimicit_cgd conversation" fashion. + +
+ Friends + + *Figure 5: A data sample after LLaVA-Otter data blending. Gray dialog boxes are from LLaVA datasets, and orange ones are from Otter dataset.* +
+ +# 5. Demonstration +We trained our DeepSpeed-VisualChat-13B model with a 2B visual encoder and the 13B LLaMA model on several open-sourced datasets. DeepSpeed-VisualChat-13B shows image captioning capabilities (*Figure 6--8*), counting and text reading (*Figure 6*), celebrity recognition (*Figure 7*), storytelling (*Figure 8*), etc. + +
+ Friends + + *Figure 6: DeepSpeed-VisualChat can count the number of people in the image and read the text in the first image. It also demonstrates cross-image understanding.* +
+ + +
+ CEO + + *Figure 7: DeepSpeed-VisualChat can recognize celebrities and associate them with their achievements.* +
+ + +
+ Zootopia + + *Figure 8: DeepSpeed-VisualChat can tell stories and recognize movies.* +
+ + +# 6. How to begin with DeepSpeed-VisualChat +DeepSpeed-VisualChat is an easy-to-use training framework with great scalability, having been tested up to LLaMa-2-70B models so far. We adopt a unified instruction tuning format for all experiments, and the template is shown below. +``` + % You are a powerful vision-language assistant. + +### Image 1: % some image, e.g., cat-1.png +### Question: % please describe the image. +### Answer: % It's a cute black cat. + +### Image 2: % some image, e.g., cat-2.png +### Image 3: % some image, e.g., cat-3.png +### Question: % What's the difference between the three cats? +### Answer: % The colors of the three cats are different. +... +``` + +The training experience of DeepSpeed-VisualChat is straightforward and convenient. Here we give an example based on the CLIP visual encoder and the LLaMa-7B model: +``` +git clone https://github.com/deepspeedai/DeepSpeedExamples.git +cd DeepSpeedExamples/applications/DeepSpeed-VisualChat/ +pip install -r requirements.txt +cd training +bash training_scripts/run_7b.sh +``` + +The trained checkpoint will be automatically saved in a Hugging Face-compatible version and can be used to launch your own visual chat API: +``` +cd ../chat +bash chat_scripts/run.sh # You need to change necessary variables, e.g, ckpt path +``` +To support larger model inference, we have incorporated Hugging Face large model inference into our DeepSpeed-VisualChat API. Therefore, users can choose a different number of GPUs based on the GPU memory capacity and the model size. + +Please refer to our [GitHub Landing Page](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat) for more details. + +# 7. Release: Try DeepSpeed-VisualChat today! + +We are very excited to share that DeepSpeed-VisualChat is now open-sourced and available to the AI community. + +* To get started, please visit our GitHub page for DeepSpeed-VisualChat: [GitHub Landing Page](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat) + +* We will continue to improve DeepSpeed-VisualChat with your feedback and support. Our [roadmap](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/applications/DeepSpeed-VisualChat/README.md#-deepspeed-visualchats-roadmap-) shows currently supported features as well as ones that are planned for the future. + + +DeepSpeed-VisualChat is a component of the larger DeepSpeed ecosystem, which includes a range of Deep Learning systems and modeling technologies. To learn more, + +* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation. +* Follow us on our [English X(Twitter)](https://twitter.com/DeepSpeedAI), [Japanese X(Twitter)](https://twitter.com/DeepSpeedAI_JP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for latest news on DeepSpeed. + +We welcome your contributions to DeepSpeed! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) page. Please see our [contributing guide](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please directly email to info@deepspeed.ai. + +* "Star" our [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/) and [DeepSpeedExamples GitHub](https://github.com/deepspeedai/DeepSpeedExamples/) repositories if you like our work! diff --git a/blogs/deepspeed-visualchat/assets/images/attention.png b/blogs/deepspeed-visualchat/assets/images/attention.png new file mode 100644 index 000000000000..b01d8f8027ce Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/attention.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/cat-chat.png b/blogs/deepspeed-visualchat/assets/images/cat-chat.png new file mode 100755 index 000000000000..5a5c27381f65 Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/cat-chat.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/ceos.png b/blogs/deepspeed-visualchat/assets/images/ceos.png new file mode 100644 index 000000000000..e148f545a44b Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/ceos.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/data-blending.png b/blogs/deepspeed-visualchat/assets/images/data-blending.png new file mode 100644 index 000000000000..a8afb5144fb1 Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/data-blending.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/friends.png b/blogs/deepspeed-visualchat/assets/images/friends.png new file mode 100644 index 000000000000..2689d8d4bb1c Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/friends.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/hero-figure.png b/blogs/deepspeed-visualchat/assets/images/hero-figure.png new file mode 100644 index 000000000000..ca79b2c6239f Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/hero-figure.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/lake-chat.png b/blogs/deepspeed-visualchat/assets/images/lake-chat.png new file mode 100755 index 000000000000..c47199737d54 Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/lake-chat.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/model.png b/blogs/deepspeed-visualchat/assets/images/model.png new file mode 100644 index 000000000000..dbd1f05c484b Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/model.png differ diff --git a/blogs/deepspeed-visualchat/assets/images/zootopia.png b/blogs/deepspeed-visualchat/assets/images/zootopia.png new file mode 100644 index 000000000000..c9e3783ed198 Binary files /dev/null and b/blogs/deepspeed-visualchat/assets/images/zootopia.png differ diff --git a/blogs/deepspeed-zenflow/README.md b/blogs/deepspeed-zenflow/README.md new file mode 100644 index 000000000000..fddef9ddb933 --- /dev/null +++ b/blogs/deepspeed-zenflow/README.md @@ -0,0 +1,183 @@ +

+ zenflow logo +

+
+ +# ZenFlow: Stall-Free Offloading Engine for LLM Training + +
+ + +
+ + +*Figure 1: ZenFlow is DeepSpeed’s stall-free offloading engine for LLM training. It decouples GPU and CPU updates by prioritizing important gradients for immediate GPU updates and deferring the rest for asynchronous CPU-side accumulation. By fully overlapping CPU work and PCIe transfers with GPU computation, ZenFlow eliminates stalls and achieves high hardware utilization across both single-GPU and multi-GPUs settings.* + +## Table of Content + +- [ZenFlow: Stall-Free Offloading Engine for LLM Training](#zenflow-stall-free-offloading-engine-for-llm-training) + - [Table of Content](#table-of-content) + - [Introduction](#introduction) + - [ZenFlow at a Glance](#zenflow-at-a-glance) + - [ZenFlow Highlights](#zenflow-highlights) + - [Design Motivation](#design-motivation) + - [ZenFlow Design](#zenflow-design) + - [Getting Started: Try out DeepSpeed-ZenFlow](#getting-started-try-out-deepspeed-zenflow) + - [Citation](#citation) + - [Acknowledgements](#acknowledgements) + +--- + +## Introduction + +
+ + +
+ +*Figure 2: ZeRO-Offload causes repeated GPU stalls due to blocking CPU updates and PCIe transfers, leading to >60% idle time per step when training Llama 2-7B on 4× A100s.* + +Offloading has become a standard approach to scale fine-tuning of large language models (LLMs) beyond GPU memory limits. Frameworks like ZeRO-Offload reduce GPU memory usage by pushing gradients and optimizer states to the CPU. However, they also create a new bottleneck: expensive GPUs often sit idle, waiting on slow CPU updates and PCIe data transfers. In practice, enabling offloading when training Llama 2-7B on 4× A100 GPUs can inflate each step from 0.5s to over 7s—a 14× slowdown. + +
+ + +
+ +*Figure 3: In ZeRO-Offload, CPU-side optimizer updates and PCIe transfers dominate iteration time, leaving the GPU idle for over 5 seconds.* + +**ZenFlow** addresses this bottleneck with a stall-free training pipeline. It prioritizes high-impact gradients for immediate GPU updates, while offloading the rest to the CPU and applying them asynchronously. These deferred CPU updates are fully overlapped with GPU compute, eliminating stalls and significantly improving throughput. Best of all, ZenFlow maintains the same model accuracy and integrates seamlessly with DeepSpeed. + +--- + +## ZenFlow at a Glance + +- **Zero GPU stalls:** Top-k important gradients are updated immediately on GPU; low-priority gradients are asynchronously processed on CPU—no GPU wait time. +- **Asynchronous and bounded:** ZenFlow decouples CPU and GPU execution with a bounded-staleness strategy that preserves convergence. +- **Auto-tuned:** ZenFlow adapts update intervals at runtime based on gradient dynamics—no need to tune manually. + +--- + +## ZenFlow Highlights + +ZenFlow is the **first offloading framework** to offer a **bounded-asynchronous** update scheme that preserves convergence while delivering **up to 5× end-to-end speed-up** over ZeRO-Offload. + +### Performance + +| Feature | Benefit | +|--------|---------| +| Up to **5×** end-to-end speed-up over ZeRO-Offload and **6.3×** over ZeRO-Infinity | Faster time-to-convergence | +| **> 85% reduction in GPU stalls** on A100 / H100 nodes | Keeps GPUs busy, higher utilization | +| **≈ 2× lower PCIe traffic** (1.13× model size per step vs. 2× in ZeRO) | Less bandwidth pressure on clusters | +| **Maintains or improves accuracy** on GLUE (OPT-350M → Llama-13B) | No accuracy loss | +| **Lightweight gradient selection** (6000× cheaper than full AllGather) | Scales to multi-GPU settings without memory footprint spikes | +| **Auto-tuning (Zen-auto)** automatically adapts update interval on-the-fly | No manual knob tuning | + +For more detailed performance results, please refer to our [arXiv paper](https://arxiv.org/abs/2505.12242). + +--- + +## Design Motivation + +Training large models with offloading can save GPU memory, but often at the cost of *performance*. In this section, we briefly discuss three topics. **First**, we explain why coupling CPU-side optimizer updates with GPU compute leads to severe GPU stalls during LLM fine-tuning. **Next**, we quantify how full-gradient offloading saturates the limited PCIe bandwidth on A100/H100 servers, inflating iteration time. **Finally**, we reveal the highly skewed importance distribution of gradients, showing that uniformly updating all parameters in GPUs at the same time is wasteful and unnecessary. + +### Offloading-Induced GPU Stalls + + +
+ + +
+ +*Figure 4: CPU updates dominate step time, causing >60% GPU idle due to poor overlap with compute.* + +Synchronous offloading frameworks (e.g., ZeRO-Offload) keep the GPU idle while the CPU performs a full optimizer step and transfers updated parameters back to GPU. For Llama-2-7B with 4× A100, the CPU path can take **longer than 4s** while the backward pass takes **approximately 2s**, so **over 60% of each iteration is pure GPU wait time**. Eliminating this serialization is essential for achieving high GPU utilization. + +### Bandwidth Bottlenecks + +A single training step moves a full copy of the model gradients from GPU to CPU and a full copy of the model parameters back, i.e., **2× model size of PCIe traffic per step**. Even on PCIe 4.0 (≈ 32 GB/s), Llama-2-13B pushes ~40 GB per iteration, adding **> 1s** of transfer latency. +### Unequal Gradient Importance + +Not all gradients matter equally. Our analysis shows that **the top 1% of gradient channels contribute over 90% of the ℓ²-norm energy** during fine-tuning. In other words, most updates have little impact on model learning, yet still incur disproportionately high compute and I/O costs in traditional offloading pipelines. + +This skew in gradient importance opens the door to a better design: update critical gradients on GPU right away, and defer the rest for asynchronously batched, lower-priority updates on CPU. ZenFlow turns this idea into a principled, efficient training engine. + +
+ + +
+ +*Figure 5: Top 1% of gradients may contribute over 85% of gradient norms.* + +--- + +## ZenFlow Design + +ZenFlow is designed around three key ideas that separate critical and non-critical gradient updates while minimizing communication bottlenecks. Here's how we break the tight coupling between GPU and CPU computation to create a **stall-free** pipeline. + +### Idea 1: Importance-Aware Top-k Gradient Update + +Not all gradients are equally impactful for training. ZenFlow introduces an **importance-aware** design that prioritizes updates for the top-k most significant gradients. These gradients are updated directly on the GPU, using its high compute bandwidth. This approach allows us to **reduce the size of the per-step gradient update** by nearly **50%**, cutting down the communication load by around 2×. + +For the rest of the gradients, which contribute less to the model's learning, ZenFlow batches them and performs asynchronous updates on the CPU. These updates are **deferred** until they are sufficiently accumulated, thereby reducing the impact on training speed. + +### Idea 2: Bounded-Asynchronous CPU Accumulation + +ZenFlow’s **asynchronous accumulation** allows the CPU to stay busy while the GPU performs other computations. We apply an **accumulation window** for the non-critical gradients, allowing them to accumulate over several iterations before updating. This gives ZenFlow the ability to process **multiple rounds of gradient updates** concurrently, eliminating idle time typically spent waiting for the CPU optimizer. + +By carefully coordinating CPU updates with GPU execution, ZenFlow **fully hides CPU execution** behind GPU computation—ensuring that GPUs remain actively utilized, avoiding stalls, and **maximizing hardware efficiency**. + +### Idea 3: Lightweight Gradient Selection + +A key challenge in distributed training is **selecting important gradients** without introducing prohibitive communication and GPU memory costs. Traditional systems rely on global synchronization (via `AllGather`) to gather full gradients, which can become a major bottleneck in multi-GPU settings. + +ZenFlow solves this with a **lightweight gradient proxy**: instead of transferring full gradients, ZenFlow uses a **per-column gradient norm** to approximate the importance of each gradient. By computing a compact summary of per-column gradients (e.g., squared norms), ZenFlow reduces communication volume by more than **4,000×**—with nearly no loss in accuracy. + +This approach allows ZenFlow to **scale efficiently across GPUs**, without high memory or communication overhead, and it supports **dynamic gradient selection** as the model evolves. + +### Putting It All Together: ZenFlow’s Zero-Stall Pipeline + +
+ + +
+ + +*Figure 6: ZenFlow’s stall-free pipeline overlaps CPU updates and transfers with multi-steps GPU compute.* + +1. **Forward/Backward Pass on GPU:** ZenFlow processes the forward and backward passes on the GPU, immediately updating the **top-k gradients** on the GPU without waiting for the CPU. + +2. **Gradient Transfer to CPU:** While the GPU is busy, gradients from the current iteration (or previous ones) are transferred to the CPU over a dedicated PCIe stream. This is done in parallel with GPU computation, without causing any GPU wait time. + +3. **CPU Update:** Once a batch of non-critical gradients has accumulated, the CPU performs the update asynchronously. This update typically spans multiple GPU iterations, but is hidden behind GPU work, making it virtually invisible to the overall pipeline. + +4. **Double Buffering:** ZenFlow uses **double buffering** to manage the newly updated gradients. When the CPU update is complete, the new parameters are transferred back to the GPU. The swap is as fast as a pointer flip—no need to reload the entire model or re-launch the kernel. + +By constantly **overlapping GPU computation with CPU-side work**, ZenFlow transforms the traditional compute → wait → update cycle into a continuous, **stall-free pipeline**. + +--- + +## Getting Started: Try out DeepSpeed-ZenFlow + +To try out DeepSpeed-ZenFlow, please refer to the [ZenFlow tutorial](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/training/DeepSpeed-ZenFlow/README.md) in our DeepSpeedExamples repo. + +--- + +## Citation + +```bibtex +@article{lan2025zenflow, + title = {ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates}, + author = {Tingfeng Lan and Yusen Wu and Bin Ma and Zhaoyuan Su and Rui Yang and Tekin Bicer and Masahiro Tanaka and Olatunji Ruwase and Dong Li and Yue Cheng}, + journal = {arXiv preprint arXiv:2505.12242}, + year = {2025} +} +``` + +--- + +## Acknowledgements + +This work is the result of a close collaboration between University of Virginia (UVA), University of California, Merced (UC Merced), Argonne National Laboratory (ANL) and DeepSpeed team. + +The contributors include [Tingfeng Lan](https://antlera.github.io/), [Yusen Wu](https://joshwoo2003.github.io/), [Zhaoyuan Su](https://alexsssu.github.io/), [Rui Yang](https://ruiyang00.github.io/), and [Yue Cheng](https://tddg.github.io/) from UVA; [Bin Ma](https://www.linkedin.com/in/bin-ma-ba665b182/) and [Dong Li](https://faculty.ucmerced.edu/dong-li/) from UC Merced; [Tekin Bicer](https://www.anl.gov/profile/tekin-bicer) from ANL; [Olatunji Ruwase](https://www.linkedin.com/in/tunji-ruwase-088952/) and [Masahiro Tanaka](https://www.linkedin.com/in/masahiro-tanaka-77482926/) from the DeepSpeed team. We especially thank [Olatunji Ruwase](https://www.linkedin.com/in/tunji-ruwase-088952/) and [Masahiro Tanaka](https://www.linkedin.com/in/masahiro-tanaka-77482926/) for their early feedback and insightful discussions and also for open-source community support. diff --git a/blogs/deepspeed-zenflow/images/zenflow-example.png b/blogs/deepspeed-zenflow/images/zenflow-example.png new file mode 100644 index 000000000000..316e8123eccf Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zenflow-example.png differ diff --git a/blogs/deepspeed-zenflow/images/zenflow-gradients.png b/blogs/deepspeed-zenflow/images/zenflow-gradients.png new file mode 100644 index 000000000000..017d5e7ba0a7 Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zenflow-gradients.png differ diff --git a/blogs/deepspeed-zenflow/images/zenflow-logo.png b/blogs/deepspeed-zenflow/images/zenflow-logo.png new file mode 100644 index 000000000000..1e6021d36e98 Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zenflow-logo.png differ diff --git a/blogs/deepspeed-zenflow/images/zenflow-no-overlap.png b/blogs/deepspeed-zenflow/images/zenflow-no-overlap.png new file mode 100644 index 000000000000..7995d8d4daa0 Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zenflow-no-overlap.png differ diff --git a/blogs/deepspeed-zenflow/images/zenflow-overview.png b/blogs/deepspeed-zenflow/images/zenflow-overview.png new file mode 100644 index 000000000000..c6d4e41132a8 Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zenflow-overview.png differ diff --git a/blogs/deepspeed-zenflow/images/zenflow-workflow.png b/blogs/deepspeed-zenflow/images/zenflow-workflow.png new file mode 100644 index 000000000000..6f704f7a48ec Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zenflow-workflow.png differ diff --git a/blogs/deepspeed-zenflow/images/zero-offload-stall.png b/blogs/deepspeed-zenflow/images/zero-offload-stall.png new file mode 100644 index 000000000000..f68f4421af33 Binary files /dev/null and b/blogs/deepspeed-zenflow/images/zero-offload-stall.png differ diff --git a/blogs/deepspeed4science/README.md b/blogs/deepspeed4science/README.md new file mode 100644 index 000000000000..a318490329a5 --- /dev/null +++ b/blogs/deepspeed4science/README.md @@ -0,0 +1,18 @@ +
+ +# Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies + +
+ +[https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/) + +To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610): + +``` +@article{song2023deepspeed4science, + title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies}, + author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others}, + journal={arXiv preprint arXiv:2310.04610}, + year={2023} +} +``` diff --git a/blogs/deepspeed4science/chinese/README.md b/blogs/deepspeed4science/chinese/README.md new file mode 100644 index 000000000000..e1672c044b96 --- /dev/null +++ b/blogs/deepspeed4science/chinese/README.md @@ -0,0 +1,156 @@ +
+ +# DeepSpeed4Science:利用先进的AI系统优化技术实现科学发现 + +
+ +*此博客为英文博客[Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)的官方翻译* + +
+ + +*图1:DeepSpeed4Science方法概述:专为加速科学发现和应对其复杂性而量身定制的AI系统技术开发。* +
+ +如需引用 DeepSpeed4Science,请引用我们的[white paper](https://arxiv.org/abs/2310.04610): + +``` +@article{song2023deepspeed4science, + title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies}, + author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others}, + journal={arXiv preprint arXiv:2310.04610}, + year={2023} +} +``` + +## 简介 + +在接下来的十年中,深度学习可能会彻底改变自然科学,增强我们对自然现象进行建模和预测的能力。这可能预示着科学探索的新时代,为从药物开发到可再生能源的各个领域带来重大进展。为了响应这一机会以及微软“予力全球每一人、每一组织,成就不凡”的使命,[微软DeepSpeed团队](https://www.deepspeed.ai/)启动了一个名为[DeepSpeed4Science](https://deepspeed4science.ai/)的新计划,旨在通过AI系统技术创新帮助领域专家解锁当今最大的科学之谜。 + +[DeepSpeed](https://www.deepspeed.ai/)系统是由微软开发的业界领先的开源AI系统框架,它为各种AI硬件上的深度学习训练和推理提供了前所未有的规模和速度。图1展示了我们对DeepSpeed4Science这一新计划的基本方法。通过利用DeepSpeed当前的技术方案(训练、推理和压缩)作为基础技术推动器,DeepSpeed4Science将创建一套专为加速科学发现而量身定制的AI系统技术,以应对其独特的复杂性,超越用于加速通用大型语言模型(LLMs)的常见技术方法。我们与拥有科学AI模型的内部和外部团队紧密合作,以发现和解决领域特定AI系统的挑战。这包括气候科学、药物设计、生物学理解、分子动力学模拟、癌症诊断和监测、催化剂/材料发现、和其他领域。 + +我们的长期愿景是将DeepSpeed4Science发展成一个用于分享支持科学发现的先进AI技术的软件平台和统一代码仓库。DeepSpeed4Science的设计旨在包容性,呼应微软的[“AI for Good”承诺](https://www.microsoft.com/en-us/ai/ai-for-good)。这体现在该计划对一系列标志性科学模型的支持上,他们代表了一些最关键的AI4Science应用场景。在这篇博客中,我们展示了DeepSpeed4Science如何帮助解决结构生物学研究中的两个关键AI系统挑战:(1) 解决了以Evoformer为中心的蛋白质结构预测模型中的内存爆炸问题,以及(2)为更好地理解引发大流行的病毒的进化提供AI模型长序列支持。 + +## 我们的初期主要合作者 + +DeepSpeed4Science的新系统技术可以用于很多推动科学边界的标志性模型,赋能AI驱动的科学发现。目前,DeepSpeed4Science很荣幸地支持来自[微软研究院AI4Science](https://www.microsoft.com/en-us/research/lab/microsoft-research-ai4science/)、[微软WebXT/Bing](https://www.msn.com/en-us/weather/forecast/)、[美国能源部国家实验室](https://www.energy.gov/national-laboratories)和多所大学的几个关键科学模型。 + +### 微软内部合作伙伴 + +#### 科学基础模型(Scientific Foundation Model,SFM),微软研究院AI4Science + +
+ + + +*图2:科学基础模型(Scientific Foundation Model,SFM)及其当前探索:Distributional Graphormer。* +
+ +科学基础模型(SFM)旨在创建一个统一的大规模基础模型,以支持自然科学发现,支持多种输入、多个科学领域(例如,药物、材料、生物学、健康等)和计算任务。DeepSpeed4Science合作伙伴关系将为SFM团队提供新的训练和推理技术,以支持他们的新生成AI方法(例如[Distributional Graphormer](https://www.microsoft.com/en-us/research/blog/distributional-graphormer-toward-equilibrium-distribution-prediction-for-molecular-systems/))这样的项目进行持续研究。 + +#### ClimaX,微软研究院AI4Science + +
+ + +*图3:ClimaX是第一个设计用于执行各种天气和气候建模任务的基础模型。* +
+ +我们的气候正在发生变化,导致极端天气事件的频率增加。为了减轻负面影响,预测这些事件将发生的地方变得越来越重要。[ClimaX](https://www.microsoft.com/en-us/research/group/autonomous-systems-group-robotics/articles/introducing-climax-the-first-foundation-model-for-weather-and-climate/)是第一个设计用于执行各种天气和气候建模任务的基础模型。它可以吸收许多具有不同变量和分辨率的数据集以提高天气预报的准确性。DeepSpeed4Science正在为ClimaX创建新的系统支持和加速策略,以高效地预训练/微调更大的基础模型,同时处理非常大的高分辨率图像数据(例如,数十到数百PB)和长序列。 + +#### 分子动力学和机器学习力场(Molecular Dynamics and Machine Learning Force Field),微软研究院AI4Science + +
+ + +*图4:一百万步的分子动力学模拟:RBD-蛋白(RBD-protein)与蛋白抑制剂(protein inhibitor)相互作用。* +
+ +这个项目模拟了使用[AI驱动的力场模型](https://www.microsoft.com/en-us/research/publication/ai2bmd-efficient-characterization-of-protein-dynamics-with-ab-initio-accuracy/)进行近似第一性原理计算精度的大型(百万原子)分子系统的动态模拟,同时保持了经典分子动力学的效率和可扩展性。这些模拟足够高效,可以生成足够长的轨迹来观察化学上有意义的事件。通常,这个过程需要数百万甚至数十亿的推理步骤。这对优化图神经网络(GNN)+ LLM模型的推理速度提出了重大挑战,DeepSpeed4Science将为此提供新的加速策略。 + +#### 微软天气,微软WebXT/Bing + +
+ + +*图5:微软降水预报(每4分钟一次对接下来4小时进行预测)。* +
+ +[微软天气](https://www.msn.com/en-us/weather/forecast/)提供精确的天气信息,[帮助用户为他们的生活方式、健康、工作和活动做出更好的决策](https://blogs.windows.com/windowsexperience/2022/08/31/microsoft-joins-noaas-weather-ready-nation-ambassador-initiative-to-help-improve-americas-readiness-and-response-to-weather-events/)——包括每小时多次更新的准确的10天全球天气预报。此前,微软天气受益于DeepSpeed技术,加速了他们的多GPU训练环境。目前,DeepSpeed4Science正在与微软WebXT天气预报团队合作,进一步增强微软天气预报服务的最新功能和改进。 + +### 外部合作者 + +DeepSpeed4Science的旅程始于两个开创性的基于LLM的结构生物学研究AI模型:来自哥伦比亚大学的[OpenFold](https://openfold.io/),一个开源的高保真蛋白质结构预测模型;以及来自[阿贡国家实验室](https://www.anl.gov/)的[GenSLMs](https://github.com/ramanathanlab/genslm),一个获得[ACM戈登贝尔奖](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)的用于学习SARS-CoV-2(COVID-19)基因组的进化的语言模型。作为此次发布的特色展示,它们代表了当今AI驱动的结构生物学研究面临的两个常见AI系统挑战。我们将在下一节中讨论DeepSpeed4Science如何赋能这些科学研究。 + +此外,DeepSpeed4Science最近扩大了其范围,以支持更多样的科学模型。例如,在我们与阿贡国家实验室合作训练[Aurora Exascale系统](https://www.anl.gov/aurora)上的万亿参数科学模型的工作中,DeepSpeed4Science技术将帮助他们达到这一关键任务所需的性能要求和可扩展性。此外,通过与[橡树岭国家实验室](https://ai-roadmap.ornl.gov/)和[国家癌症研究所(NCI)](https://www.cancer.gov/)合作进行癌症监测,DeepSpeed4Science将帮助从非结构化的临床文本中高保真地提取和分类信息,以供[MOSSAIC项目](https://www.olcf.ornl.gov/tag/mossaic/)使用。[Brookhaven国家实验室](https://www.bnl.gov/world/)还将采用DeepSpeed4Science技术,支持使用LLMs开发大型数字双胞胎模型,以便为清洁能源研究产生更真实的模拟数据。您可以在[deepspeed4science.ai](https://deepspeed4science.ai/)上找到有关我们外部合作者及其科学任务的更多详细信息。 + +## 合作展示 + +### 展示(I):DeepSpeed4Science通过DS4Sci_EvoformerAttention消除以Evoformer为中心的结构生物学模型的内存爆炸问题 + +
+ + + +*图6:在训练过程中OpenFold对PDB链7B3A_A的预测。* +
+ +[OpenFold](https://github.com/aqlaboratory/openfold)是DeepMind的[AlphaFold2](https://alphafold.com/)的开源社区再现,使其可以在新数据集上训练或微调AlphaFold2。研究人员已经使用它从头开始重新训练AlphaFold2,生成新的模型参数集,研究AlphaFold2的早期训练阶段(图6),并开发新的蛋白质折叠系统。 + +
+ + +*图7:在OpenFold中,对多序列比对(MSA)Attention内核(包含偏差)变体的训练峰值内存需求。 (左) 使用在AlphaFold2中的EvoformerAttention的原始OpenFold实现。对于这些类型的蛋白质结构预测模型,在训练/推理中的内存爆炸问题是常见的。最先进的FlashAttention无法有效支持这样的Attention变体。 (右) DeepSpeed4Science的一种新解决方案DS4Sci_EvoformerAttention在不影响模型品质的条件下显著地减少了OpenFold的训练峰值内存需求(最多13倍)。* +
+ +尽管OpenFold有使用最先进的系统技术进行性能和内存优化,但从头开始训练AlphaFold2仍然在计算上很昂贵。目前阶段的模型参数很小,只有9300万个参数,但它包含了几个需要非常大的中间内存的特殊Attention变体。在标准AlphaFold2训练的“微调”阶段,只是这些变体中的其中一个在半精度下就生成了超过12GB的张量,使其峰值内存要求远远超过了相同大小的语言模型。即使使用像activation checkpointing和DeepSpeed ZeRO优化这样的技术,这种内存爆炸问题仍然严重限制了可训练模型的序列长度和MSA深度。此外,近似策略可能会显著影响模型的准确性和收敛性,同时仍然导致内存爆炸,如图7左侧(橙色)所示。 + +为了应对结构生物学研究(例如,蛋白质结构预测和平衡分布预测)中的这一常见系统挑战,DeepSpeed4Science通过为这类科学模型中广泛出现的注意力变体(即EvoformerAttention)设计定制的精确注意力内核来解决这一内存效率问题。具体来说,我们设计了一套由复杂的融合/矩阵分块策略和动态内存减少方法而组成的高内存效率DS4Sci_EvoformerAttention内核,作为高质量机器学习模块供更广泛的生物学研究社区使用。通过整合到OpenFold中,这些定制内核在训练期间提供了显著的加速,并显著减少了模型的训练和推理的峰值内存需求。这使得OpenFold可以用更大、更复杂的模型,使用更长的序列在更广泛的硬件上进行实验。关于这项技术的详细信息可以在[这里](https://deepspeed4science.ai/2023/09/18/model-showcase-openfold/)找到。 + +### 展示(II):DeepSpeed4Science通过系统和算法方法为基因组基础模型(例如,GenSLMs)提供长序列支持 + +
+ + +*图8:GenSLMs:获2022年ACM 戈登贝尔奖的COVID基因组模型(基于GPT-NeoX的25B/33B模型)。它用于学习描述SARS-CoV-2基因组生物学意义的潜在空间。这个GIF展示了一个重要的蛋白质家族苹果酸脱氢酶(malate dehydrogenase)的根据重要特征(如序列长度和GC含量(核酸鸟嘌呤和胞嘧啶的含量与腺嘌呤和胸腺嘧啶的比率。它测量DNA链抵抗热的能力))着色的潜在空间的投影。* +
+ +[GenSLMs](https://github.com/ramanathanlab/genslm),一个来自阿贡国家实验室的[2022年ACM 戈登贝尔奖获奖](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)的基因组模型,可以通过大型语言模型(LLMs)的基因组数据训练来学习SARS-CoV-2(COVID-19)基因组的进化。它旨在改变如何识别和分类引发大流行的病毒(特别是SARS-CoV-2)的新变种。GenSLMs代表了第一批可以泛化到其他预测任务的基因组基础模型。对潜在空间的良好理解可以帮助GenSLMs处理超出仅仅是病毒序列的新领域,并扩展它们模拟细菌病原体甚至真核生物的能力(例如,理解功能、途径成员资格和进化关系等事物)。为了实现这一科学目标,GenSLMs和类似的模型需要非常长的序列支持用于训练和推理,这超出了像[FlashAttention](https://arxiv.org/abs/2307.08691)这样的通用LLM的长序列策略。通过DeepSpeed4Science的新设计,科学家现在可以构建和训练具有显著更长的上下文窗口的模型,允许他们探索以前无法访问的关系。 + +
+ + +*图9:由不同框架在不同规模下支持的两个GenSLMs模型的最大序列长度。使用NVIDIA DGX,每个节点有八个40G A100 GPU。* +
+ +具体在系统层面,我们发布了包括[长序列支持和其他新优化](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support)的最新的[Megatron-DeepSpeed框架](https://github.com/deepspeedai/Megatron-DeepSpeed)。科学家现在可以通过我们新添加的内存优化技术(如注意力掩码异步处理和位置码分割)、张量并行、流水线并行、序列并行、基于ZeRO的数据并行和模型状态异步处理等技术的协同组合,用更长的序列训练他们的GenSLMs等大型科学模型。图9展示了我们的新版本使GenSLMs的25B和33B模型的最长序列长度分别比之前的Megatron-DeepSpeed版本增加了12倍和14倍。在支持的序列长度方面,这个新Megatron-DeepSpeed框架也显著地超过了NVIDIA的Megatron-LM(对于25B和33B模型分别高达9.8倍和9.1倍)。例如,阿贡实验室团队的GenSLMs 25B模型在64个GPU上的原始序列长度为42K,而现在可以用512K的核苷酸序列进行训练。这在不损失准确性的条件下大大提高了模型质量和科学发现的范围。对于那些更喜欢相对位置编码技术这样的算法策略的领域科学家,这个[新版本](https://deepspeed4science.ai/2023/09/18/model-showcase-genslms/)也进行了集成。 + +## 总结和路线图 + +我们非常自豪和兴奋地宣布DeepSpeed4Science计划以及几个研发亮点和成果。从今天开始,我们将在[deepspeed4science.ai](https://deepspeed4science.ai/)上介绍我们的新计划,包括关于我们的外部合作者的信息,以及当前和未来的DeepSpeed4Science技术发布。我们的一个高层次目标是推广广泛解决大规模科学发现的主要系统痛点的AI系统技术。我们希望全球的科学家们能够从DeepSpeed4Science通过开源软件解锁的新功能中受益。我们期待更好地了解阻碍您的科学发现的AI系统设计挑战。我们真诚地欢迎您的参与,帮助构建一个更有前途的AI4Science未来。请给我们发送电子邮件至。我们鼓励您在我们的[DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/)上报告问题、贡献PR、参与讨论。 + +## 致谢 + +**Core DeepSpeed4Science Team:** + +Shuaiwen Leon Song (DeepSpeed4Science lead), Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Xiaoxia (Shirley) Wu, Masahiro Tanaka, Martin Cai, Adam Graham, Charlie Zhou, Yuxiong He (DeepSpeed team lead) + +**Our Founding Collaborators (in alphabetical order):** + +**Argonne National Lab team:** Rick Stevens, Cristina Negri, Rao Kotamarthi, Venkatram Vishwanath, Arvind Ramanathan, Sam Foreman, Kyle Hippe, Troy Arcomano, Romit Maulik, Maxim Zvyagin, Alexander Brace, Yuntian Deng, Bin Zhang, Cindy Orozco Bohorquez, Austin Clyde, Bharat Kale, Danilo Perez-Rivera, Heng Ma, Carla M. Mann, Michael Irvin, J. Gregory Pauloski, Logan Ward, Valerie Hayot, Murali Emani, Zhen Xie, Diangen Lin, Maulik Shukla, Weili Nie, Josh Romero, Christian Dallago, Arash Vahdat, Chaowei Xiao, Thomas Gibbs, Ian Foster, James J. Davis, Michael E. Papka, Thomas Brettin, Anima Anandkumar + +**AMD:** Ivo Bolsen, Micheal Schulte, Bo Begole, Angela Dalton, Steve Reinhart, Ashwin Aji, Jalal Mahmud, Mahesh Balashibramanian + +**Brookhaven National Lab team:** Adolfy Hoisie, Shinjae Yoo, Yihui Ren. + +**Columbia University OpenFold team:** Mohammed AlQuraishi, Gustaf Ahdritz + +**Microsoft Research AI4Science team:** Christopher Bishop, Bonnie Kruft, Max Welling, Tie-Yan Liu, Christian Bodnar, Johannes Brandsetter, Wessel Bruinsma, Chan Cao, Yuan-Jyue Chen, Peggy Dai, Patrick Garvan, Liang He, Elizabeth Heider, PiPi Hu, Peiran Jin, Fusong Ju, Yatao Li, Chang Liu, Renqian Luo, Qi Meng, Frank Noe, Tao Qin, Janwei Zhu, Bin Shao, Yu Shi, Wenlei Shi, Gregor Simm, Megan Stanley, Lixin Sun, Yue Wang, Tong Wang, Zun Wang, Lijun Wu, Yingce Xia, Leo Xia, Shufang Xie, Shuxin Zheng, Jianwei Zhu + +**Oakridge National Lab team:** Prassana Balaprakash, Georgia Tourass + +**Princeton University:** William Tang, Kyle Felker, Alexey Svyatkovskiy (Microsoft liaison) + +**Rutgers University:** Hang Liu + +**WebXT Weather team:** Pete Luferenko, Divya Kumar, Jonathan Weyn, Ruixiong Zhang, Sylwester Klocek, Volodymyr Vragov diff --git a/blogs/deepspeed4science/japanese/README.md b/blogs/deepspeed4science/japanese/README.md new file mode 100644 index 000000000000..2599289a86df --- /dev/null +++ b/blogs/deepspeed4science/japanese/README.md @@ -0,0 +1,156 @@ +
+ +# DeepSpeed4Scienceイニシアティブ: 洗練されたAIシステムのテクノロジーにより大規模な科学的発見を可能に + +
+ +*こちらは英語ブログ[Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)の公式の翻訳です* + +
+ + +*図1:DeepSpeed4Scienceのアプローチ: 汎用の言語モデルのサポートを超え、科学的発見とその複雑さの解決に特化したAI技術を開発* +
+ +DeepSpeed4Science を引用するには、こちらの[white paper](https://arxiv.org/abs/2310.04610)を引用してください: + +``` +@article{song2023deepspeed4science, + title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies}, + author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others}, + journal={arXiv preprint arXiv:2310.04610}, + year={2023} +} +``` + +## はじめに + +自然の出来事をモデル化し予測する深層学習の能力は急速に高まっており、次の10年間に、自然科学に革命を起こすかも知れません。薬の開発から再生可能エネルギーまでの各セクターで、大きな進展をもたらす新しい科学的探求の時代が到来するでしょう。「地球上のすべての人と組織がもっと多くのことを成し遂げられるようにする」というMicrosoftのミッションに従い、この機会に、[DeepSpeedチーム](https://www.deepspeed.ai/)では[DeepSpeed4Science](https://deepspeed4science.ai/)という新しいイニシアティブを立ち上げました。これは、AIシステム技術のイノベーションを通じて他に類を見ない技術を構築し、様々な分野の専門家が、科学分野における大きな謎を解き明かす手助けをすることを目指しています。 + +[DeepSpeed](https://www.deepspeed.ai/)システムは、Microsoftが開発した、AI分野をリードするオープンソースのAIシステムのフレームワークであり、多様なAIハードウェア上での深層学習の訓練と推論において、前例のない規模と速度を実現します。図1は、この新しいDeepSpeed4Scienceイニシアティブでの基本的なアプローチを示しています。DeepSpeedの現在の柱となる技術(訓練、推論、圧縮)を基盤として活用しつつ、DeepSpeed4Scienceでは、大規模言語モデル(LLM)を加速するための汎用の技術的アプローチを超え、科学的発見を加速する目的で新たに構築された、一連のAIシステム技術を提供します。私たちは、重要な科学的ミッションを推進している、代表的な科学分野向けAIモデルを所有する内外のチームと連携し、ドメイン固有のAIシステムの課題を特定し、解決していきます。これには、気候科学、薬物設計、生物学的理解、分子動力学シミュレーション、がんの診断と監視、触媒/材料の発見、およびその他の分野が含まれます。 + +私たちの長期的なビジョンは、DeepSpeed4Scienceを、科学的発見をサポートする先進的なAIシステム技術を共有するための新しいソフトウェアプラットフォームおよび統一的なリポジトリに発展させることです。DeepSpeed4Scienceは、Microsoftの[AI for Good](https://www.microsoft.com/en-us/ai/ai-for-good)のコミットメントを反映して、包括的に設計されています。このことは、AI4Scienceへのもっとも重要な投資の成果として構築された、様々な代表的モデルへの、DeepSpeed4Scienceイニシアティブによるサポートに現れています。このブログでは、DeepSpeed4Scienceが、構造生物学の研究における2つの重要なシステムの課題にどのように対処するかを紹介します:(1) Evoformer中心のタンパク質構造予測モデルをスケールアップする際に極めて大きなメモリが必要となる問題を解決し、(2) パンデミックを引き起こすウイルスの進化の様子をよりよく理解するための非常に長いシーケンスのサポートを可能にします。 + +## 主要な初期コラボレータ + +DeepSpeed4Scienceによる新しいシステム技術はAI駆動の幅広い科学研究を強化するものです。現在、DeepSpeed4Scienceは、[Microsoft Research AI4Science](https://www.microsoft.com/en-us/research/lab/microsoft-research-ai4science/)、[Microsoft WebXT/Bing](https://www.msn.com/en-us/weather/forecast/)、[U.S. DoE National Labs](https://www.energy.gov/national-laboratories)、および複数の大学のいくつかの重要な科学モデルをサポートしています。 + +### Microsoft内のパートナーシップ + +#### 科学基盤モデル (Scientific Foundation Model, SFM), Microsoft Research AI4Science + +
+ + + +*図2: 科学基盤モデル (Scientific foundation model, SFM) とその探索: Distributional Graphormer* +
+ +科学的基盤モデル(SFM)は、多様なインプット、複数の科学領域(薬物、材料、生物学、健康など)、および計算タスクをサポートする、自然科学的発見を強化するための統一された大規模基盤モデルを作成することを目的としています。DeepSpeed4Scienceパートナーシップは、[Distributional Graphormer](https://www.microsoft.com/en-us/research/blog/distributional-graphormer-toward-equilibrium-distribution-prediction-for-molecular-systems/)などのMicrosoftの新しい生成AI手法などのプロジェクトに関する、SFMチームの継続的な研究を強化するための新しい訓練および推論テクノロジーを提供します。 + +#### ClimaX, Microsoft Research AI4Science + +
+ + +*図3: 天気・気候の多様なモデリングタスクのための最初の基盤モデルClimaX* +
+ +気候の変化は、より頻繁な異常気象を引き起こしています。悪影響を軽減するため、これらのイベントが発生する場所を予測することがますます重要になっています。[ClimaX](https://www.microsoft.com/en-us/research/group/autonomous-systems-group-robotics/articles/introducing-climax-the-first-foundation-model-for-weather-and-climate/)は、さまざまな気象および気候モデリングタスクを実行するために設計された最初の基盤モデルです。さまざまな変数と解像度を持つ多くの異なるデータセットを扱えるため、天気予報の精度が向上する可能性があります。DeepSpeed4Scienceは、非常に大きな高解像度画像データ(数十から数百ペタバイトなど)を長いシーケンスで処理しながら、より大きな基盤モデルを効率的に事前訓練/ファインチューニングするためのClimaXの新しいシステムサポートを提供しています。 + +#### 分子動力学と機械学習型力場(Molecular Dynamics and Machine Learning Force Field),Microsoft Research AI4Science + +
+ + +*図4: 100万ステップの分子動力学シミュレーション: RBD-proteinとprotein inhibitorの相互作用* +
+ +このプロジェクトは、古典的な分子動力学の効率とスケーラビリティを維持しながら、[AIを利用した力場モデル](https://www.microsoft.com/en-us/research/publication/ai2bmd-efficient-characterization-of-protein-dynamics-with-ab-initio-accuracy/)を使用して、原理に基づく精度(ab initio accuracy)に近い精度で大規模(原子数で100万規模)な分子システムの力学をシミュレートします。このシミュレーションは、化学的に重要なイベントを観察するのに十分な長さの軌道を生成できる効率を実現しています。通常、このプロセスには数百万から数十億の推論ステップが必要です。これは、グラフニューラルネットワーク(GNN)+ LLMモデルの推論速度を最適化する上で大きな課題となります。DeepSpeed4Scienceは、この課題に対して、新しいシステムサポートを提供します。 + +#### 天気 from Microsoft Start, Microsoft WebXT/Bing + +
+ + +*図5: Microsoft Startにおける降水予想 (次の4時間について4分ごと)* +
+ +[天気 from Microsoft Start](https://www.msn.com/en-us/weather/forecast/)は、[ユーザーがライフスタイル、健康、仕事、活動についてより適切な決定を下せるよう](https://blogs.windows.com/windowsexperience/2022/08/31/microsoft-joins-noaas-weather-ready-nation-ambassador-initiative-to-help-improve-americas-readiness-and-response-to-weather-events/)、正確な気象情報を提供します。 (1 時間ごとに複数回更新される、10 日間に渡る正確かつグローバルな天気予報など)。 以前にも、この天気予報は、DeepSpeedの技術を使用して、マルチ GPU を用いた訓練を高速化していました。現在、DeepSpeed4ScienceはMicrosoft WebXT気象チームと協力して、最先端の機能と更なる改善により、マイクロソフトの気象サービスをさらに強化しています。 + +### 外部のコラボレータ + +DeepSpeed4Scienceは、構造生物学研究のための2つの先駆的なLLMベースのAIモデルを扱うことから始まりました: オープンソースのハイフィデリティタンパク質構造予測モデルであるコロンビア大学の[OpenFold](https://openfold.io/)と、SARS-CoV-2(COVID-19)ゲノムの進化を学習する、[Gordon Bell Special Prize](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)を受賞したゲノム用言語モデルである[アルゴンヌ国立研究所](https://www.anl.gov/)の[GenSLMs](https://github.com/ramanathanlab/genslm)です。次のセクションでは、今日のAI主導の構造生物学研究が直面している2つの一般的なAIシステムの課題を紹介し、DeepSpeed4Scienceが科学研究をどのように強化したかについて説明します。 + +またDeepSpeed4Scienceは最近、より多様な科学モデルをサポートするために、その対象を拡大しました。たとえば、[Aurora Exascaleシステム](https://www.anl.gov/aurora)で、1兆パラメータの科学モデルを訓練するアルゴンヌ国立研究所との協力にあたって、DeepSpeed4Scienceテクノロジーは、求められるパフォーマンス要件とスケーラビリティを実現するのに重要な役割を果たします。さらに、DeepSpeed4Scienceは、がんの調査に関して、[オークリッジ国立研究所](https://ai-roadmap.ornl.gov/)および[国立がん研究所(NCI)](https://www.cancer.gov/)と協力することにより、[MOSSAICプロジェクト](https://www.olcf.ornl.gov/tag/mossaic/)の非構造化臨床テキストからの情報の高信頼度抽出と分類にも用いられます。さらに、DeepSpeed4Scienceのテクノロジーは、[ブルックヘブン国立研究所](https://www.bnl.gov/world/)にも採用され、LLMを使用してより現実的なシミュレーションデータを生成することにより、クリーンエネルギー研究用の大規模なデジタルツインモデルの開発をサポートします。外部のコラボレータとその科学ミッションに関するより詳細な情報は、[deepspeed4science.ai](https://deepspeed4science.ai/)に掲載しています。 + +## パートナーシップの事例 + +### 事例(I): DeepSpeed4ScienceのDS4Sci_EvoformerAttentionにより、Evoformerで構成された生物学モデルをスケールアップする際のメモリ問題を解決 + +
+ + + +*図6: モデル学習の進行に伴うPDB chain 7B3A_AについてのOpenFoldの予測* +
+ +[OpenFold](https://github.com/aqlaboratory/openfold)は、DeepMindによる[AlphaFold2](https://alphafold.com/)をオープンソースで再現したものであり、新しいデータセットでAlphaFold2を訓練またはファインチューニングすることを可能にします。研究者は、これを使用して、AlphaFold2をゼロから再訓練して新しいモデルパラメータを作成し、AlphaFold2の初期訓練フェーズを研究し(図6)、新しいタンパク質フォールディングシステムを開発しました。 + +
+ + +*図7: OpenFoldで可能な最大の訓練サンプル次元を持つ多重配列アライメント(MSA)アテンションカーネル(バイアス付き)のバリエーションを訓練するために必要なピークメモリ。(左)AlphaFold2で使用されているEvoformerAttentionを用いたオリジナルのOpenFold実装。この種のタンパク質構造予測モデルの訓練/推論では、極めて多くのメモリが必要とされることは一般的な課題となっている。特に、最新技術として広く知られるFlashAttentionでも、このような科学研究のためのアテンションのバリエーションを効果的にサポートできない。(右)DS4Sci_EvoformerAttentionと呼ばれるDeepSpeed4Scienceの新しい技術は、精度を落とすことなく、OpenFoldモデルの訓練に必要なピークメモリを1/13に大幅に削減する。* +
+ +OpenFoldには、最先端のシステムテクノロジーを使用したパフォーマンスとメモリの最適化が含まれていますが、AlphaFold2をゼロから訓練することは依然として大きな計算コストがかかります。現段階でのモデルは、パラメータ数の絶対値は小さい(9,300万個)のですが、極めて大きなアクティベーションを持つアテンションのバリエーションが含まれています。標準的なAlphaFold2訓練のファインチューニングフェーズでは、これらのバリエーションのうちのの1つが生成したロジットテンソル(入力としてモデルに供給されるディープタンパク質MSAに対応するように設計されたもの)は、半精度浮動小数で12GBを超え、同等のサイズの言語モデルが使用するメモリを大幅に上回ります。Activation checkpointingや、DeepSpeed ZeRO 最適化などの手法を使用しても、非常に多くのメモリが必要とされるため、モデルを訓練できるシーケンスの長さと MSA の深さが大幅に制限されます。さらに、近似解を与えるような戦略を用いると、モデルの精度と収束に大きな影響を与える可能性があり、それでもメモリが爆発的に増加します(図7の左側のバー(オレンジ色))。 + +DeepSpeed4Scienceは、構造生物学研究(タンパク質構造予測や平衡分布予測など)におけるこの一般的なシステムの課題に対処するために、このカテゴリの科学モデルに広く見られるアテンションのバリエーション(つまりEvoformerAttention)用にカスタマイズされた正確なアテンションのカーネルを設計することにより、このメモリの非効率性の問題に対処しています。具体的には、高度なフュージョン/タイリング戦略とオンザフライのメモリ削減方法によって可能になるメモリ効率の高いDS4Sci_EvoformerAttentionカーネルのセットを、高品質の機械学習プリミティブとして、より広いコミュニティ向けに作成しました。これらをOpenFoldに組み込むことで、訓練中の速度が大幅に向上し、訓練と推論のためのモデルのピークメモリが大幅に削減されます。これにより、OpenFoldはより大きく、より複雑なモデル、より長いシーケンスで実験し、より幅広いハードウェアで訓練することができます。この技術の詳細については、[こちら](https://deepspeed4science.ai/2023/09/18/model-showcase-openfold/)をご覧ください。 + +### 事例(II): DeepSpeed4Scienceのシステムとアルゴリズムの両方からのアプローチにより、ゲノム基盤モデルでの非常に長い系列の使用をサポート + +
+ + +*図8: GenSLMs:2022年ACM Gordon Bell Special Prize受賞COVIDゲノム用モデル(GPT-NeoXに基づく25B/33Bモデル)。SARS-CoV-2ゲノムの生物学的に意味のある特性を記述する潜在空間を学習するために使用される。このGIFは、重要なタンパク質ファミリーであるリンゴ酸デヒドロゲナーゼ(malate dehydrogenase)を可視化し、配列の長さやGC含量(アデニンとチミンと比較した核酸グアニンとシトシンの含量の比率。これはDNA鎖が熱に耐える能力を測るものである。)などの重要な特徴で色付けされた潜在空間の投影を表示している。* +
+ +アルゴンヌ国立研究所が開発し、[2022年ACM Gordon Bell Special Prize](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)を受賞したゲノム用言語モデルである[GenSLMs](https://github.com/ramanathanlab/genslm)は、ゲノムデータに大規模言語モデル(LLM)を適用することにより、SARS-CoV-2(COVID-19)ゲノムの進化を学習します。これは、パンデミックを引き起こすウイルス、特にSARS-CoV-2の新たに出現する亜種を特定し、分類する方法を変えるように設計されています。GenSLMsは、他の予測タスクに一般化できる最初のゲノム基盤モデルの1つです。潜在空間をうまく表現することにより、GenSLMsはウイルス配列だけでなく新しいドメインに適用し、細菌性病原体や真核生物をモデル化する能力を拡大し、機能、経路のメンバーシップ、進化的関係などを理解することができます。この科学的目標を達成するために、GenSLMsおよび同様のモデルは、[FlashAttention](https://arxiv.org/abs/2307.08691)のように、長いシーケンスのための一般的な戦略では扱うことが困難なレベルの、非常に長いシーケンスサポートを、訓練と推論の両方に対して必要とします。DeepSpeed4Scienceの新しい設計により、科学者はより長いシーケンスでモデルを構築および訓練できるようになり、以前は扱えなかった科学探索が可能になりました。 + +
+ + +*図9: 異なるスケールで異なるフレームワークがサポートする2つのGenSLMsモデルの最大シーケンス長。1ノードあたり8個の40G A100 GPUを搭載したNVIDIA DGXノードを使用。* +
+ +システムレベルでは、非常に長いシーケンスをサポートするための最新の[Megatron-DeepSpeedフレームワーク](https://github.com/deepspeedai/Megatron-DeepSpeed)を、[他の新しい最適化とともにリリースします](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support)。科学者は、(アテンションマスクと位置の埋め込みに関する)新しく追加されたメモリ最適化手法、テンソル並列処理、パイプライン並列処理、シーケンス並列処理、ZeROスタイルのデータ並列処理、モデル状態のオフロードなどの技術を相乗的な組み合わせにより、GenSLMsのような大規模な科学モデルをはるかに長いシーケンスで訓練できるようになりました。図9は、新しいリリースにより、GenSLMsの25Bおよび33Bモデルで、以前のMegatron-DeepSpeedよりもそれぞれ最大12倍および14倍の最長シーケンス長を処理できることを示しています。サポートされているシーケンス長に関しては、この新しいMegatron-DeepSpeedは、25Bモデルと33Bモデルでそれぞれ最大9.8倍と9.1倍でNVIDIAのMegatron-LMを大幅に上回っています。たとえば、GenSLMsの25Bモデルは、64個のGPUでのアルゴンヌチームの元の42Kシーケンス長と比較して、512Kのヌクレオチド配列で訓練できるようになりました。これにより、精度を損なうことなく、モデルの品質と科学的発見の範囲が大幅に向上します。Relative position embeddingなどのアルゴリズム戦略を必要とする科学者向けの追加サポートも、[このリリース](https://deepspeed4science.ai/2023/09/18/model-showcase-genslms/)に統合されています。 + +## まとめとロードマップ + +DeepSpeed4Scienceイニシアティブを、いくつかのR&Dのハイライトや成果と共に発表できることを嬉しく思います。本日から、外部の協力者に関する情報や、現在および将来のDeepSpeed4Scienceテクノロジーリリースなど、新しいイニシアティブでの活動を[deepspeed4science.ai](https://deepspeed4science.ai/)上で進めていきます。私たちの高レベルな目標の1つは、大規模な科学的発見のための主要なシステムの問題点に広く対処するAIシステムテクノロジーを一般化することです。世界中の科学者によって、オープンソースのソフトウェアを通じてDeepSpeed4Scienceによって利用可能になる新機能が活用されることを願っています。科学的発見の障害となるAIシステム設計の課題を解決していくことを楽しみにしています。AI4Scienceの有望な未来を築くために、皆様の参加を歓迎します。お問い合わせはまでお願いします。問題の報告や、PRを通じての貢献、ディスカッションへの参加は、[DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed/)でお願いします。 + +## 謝辞 + +**Core DeepSpeed4Science Team:** + +Shuaiwen Leon Song (DeepSpeed4Science lead), Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Xiaoxia (Shirley) Wu, Masahiro Tanaka, Martin Cai, Adam Graham, Charlie Zhou, Yuxiong He (DeepSpeed team lead) + +**Our Founding Collaborators (in alphabetical order):** + +**Argonne National Lab team:** Rick Stevens, Cristina Negri, Rao Kotamarthi, Venkatram Vishwanath, Arvind Ramanathan, Sam Foreman, Kyle Hippe, Troy Arcomano, Romit Maulik, Maxim Zvyagin, Alexander Brace, Yuntian Deng, Bin Zhang, Cindy Orozco Bohorquez, Austin Clyde, Bharat Kale, Danilo Perez-Rivera, Heng Ma, Carla M. Mann, Michael Irvin, J. Gregory Pauloski, Logan Ward, Valerie Hayot, Murali Emani, Zhen Xie, Diangen Lin, Maulik Shukla, Weili Nie, Josh Romero, Christian Dallago, Arash Vahdat, Chaowei Xiao, Thomas Gibbs, Ian Foster, James J. Davis, Michael E. Papka, Thomas Brettin, Anima Anandkumar + +**AMD:** Ivo Bolsen, Micheal Schulte, Bo Begole, Angela Dalton, Steve Reinhart, Ashwin Aji, Jalal Mahmud, Mahesh Balashibramanian + +**Brookhaven National Lab team:** Adolfy Hoisie, Shinjae Yoo, Yihui Ren. + +**Columbia University OpenFold team:** Mohammed AlQuraishi, Gustaf Ahdritz + +**Microsoft Research AI4Science team:** Christopher Bishop, Bonnie Kruft, Max Welling, Tie-Yan Liu, Christian Bodnar, Johannes Brandsetter, Wessel Bruinsma, Chan Cao, Yuan-Jyue Chen, Peggy Dai, Patrick Garvan, Liang He, Elizabeth Heider, PiPi Hu, Peiran Jin, Fusong Ju, Yatao Li, Chang Liu, Renqian Luo, Qi Meng, Frank Noe, Tao Qin, Janwei Zhu, Bin Shao, Yu Shi, Wenlei Shi, Gregor Simm, Megan Stanley, Lixin Sun, Yue Wang, Tong Wang, Zun Wang, Lijun Wu, Yingce Xia, Leo Xia, Shufang Xie, Shuxin Zheng, Jianwei Zhu + +**Oakridge National Lab team:** Prassana Balaprakash, Georgia Tourass + +**Princeton University:** William Tang, Kyle Felker, Alexey Svyatkovskiy (Microsoft liaison) + +**Rutgers University:** Hang Liu + +**WebXT Weather team:** Pete Luferenko, Divya Kumar, Jonathan Weyn, Ruixiong Zhang, Sylwester Klocek, Volodymyr Vragov diff --git a/blogs/deepspeed4science/media/Figure1.png b/blogs/deepspeed4science/media/Figure1.png new file mode 100644 index 000000000000..614c4b40d6a1 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure1.png differ diff --git a/blogs/deepspeed4science/media/Figure2-1.png b/blogs/deepspeed4science/media/Figure2-1.png new file mode 100644 index 000000000000..bb0b8d9206d1 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure2-1.png differ diff --git a/blogs/deepspeed4science/media/Figure2-2.gif b/blogs/deepspeed4science/media/Figure2-2.gif new file mode 100644 index 000000000000..0890be7d7e31 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure2-2.gif differ diff --git a/blogs/deepspeed4science/media/Figure3.png b/blogs/deepspeed4science/media/Figure3.png new file mode 100644 index 000000000000..465e80e15a25 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure3.png differ diff --git a/blogs/deepspeed4science/media/Figure4.gif b/blogs/deepspeed4science/media/Figure4.gif new file mode 100644 index 000000000000..b45a5f28fd36 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure4.gif differ diff --git a/blogs/deepspeed4science/media/Figure5.gif b/blogs/deepspeed4science/media/Figure5.gif new file mode 100644 index 000000000000..a26c20103269 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure5.gif differ diff --git a/blogs/deepspeed4science/media/Figure6-1.png b/blogs/deepspeed4science/media/Figure6-1.png new file mode 100644 index 000000000000..65f7f9309f71 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure6-1.png differ diff --git a/blogs/deepspeed4science/media/Figure6-2.gif b/blogs/deepspeed4science/media/Figure6-2.gif new file mode 100644 index 000000000000..b50588c227d7 Binary files /dev/null and b/blogs/deepspeed4science/media/Figure6-2.gif differ diff --git a/blogs/deepspeed4science/media/Figure7.jpg b/blogs/deepspeed4science/media/Figure7.jpg new file mode 100644 index 000000000000..eaa92007268b Binary files /dev/null and b/blogs/deepspeed4science/media/Figure7.jpg differ diff --git a/blogs/deepspeed4science/media/Figure8.gif b/blogs/deepspeed4science/media/Figure8.gif new file mode 100644 index 000000000000..624384910f2a Binary files /dev/null and b/blogs/deepspeed4science/media/Figure8.gif differ diff --git a/blogs/deepspeed4science/media/Figure9.png b/blogs/deepspeed4science/media/Figure9.png new file mode 100644 index 000000000000..f00fd9b6917f Binary files /dev/null and b/blogs/deepspeed4science/media/Figure9.png differ diff --git a/blogs/huggingface-tp/README.md b/blogs/huggingface-tp/README.md new file mode 100644 index 000000000000..bb6b62047202 --- /dev/null +++ b/blogs/huggingface-tp/README.md @@ -0,0 +1,241 @@ +
+ +# Automatic Tensor Parallel (AutoTP) Training of Hugging Face models + +
+ + +# Introduction + +Tensor parallelism (TP) is an important memory optimization for training large-scale deep learning models. Despite the popularity of training Hugging Face (HF) [models](https://huggingface.co/models), the model scaling options for **[HF trainer](https://huggingface.co/docs/transformers/main_classes/trainer)** was previously limited to sharded data parallelism through [ZeRO](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)/[FSDP](https://huggingface.co/docs/accelerate/usage_guides/fsdp). While ZeRO3 offers superior memory efficiency, it incurs significant communication costs. ZeRO (1/2) has lower communication overhead, but in the case of very large models, it cannot be used directly due to memory limitations. Therefore, combining TP with ZeRO (1/2) offers more balanced options for memory and performance. Moreover, through TP, we can alleviate the batch scaling limitations imposed by ZeRO/FSDP. + +We are pleased to announce that DeepSpeed now provides native automatic tensor parallel training for Hugging Face (HF) transformers. This new feature builds on DeepSpeed's [AutoTP](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) mechanism, which was previously restricted to inference. AutoTP training can be combined with ZeRO to unlock unprecented efficiency benefits for HF model post-training, including: + +**1**. Model scaling with lower communication costs than FSDP/ZeRO3 (e.g., use AutoTP + ZeRO1 to achieve ZeRO3 memory savings). + +**2**. Batch size scaling for faster training and increased throughput. + +**3**. Context length scaling to enable new application scenarios. + +We have integrated AutoTP training with ZeRO1 & ZeRO2, with ZeRO3 integration on the way. AutoTP training is available in DeepSpeed versions >= 0.16.4 + +# Batch Scaling with AutoTP Training + ZeRO +The following is a batch scaling experiment of Llama3 8B training conducted on [Gaudi2 Accelerator](https://www.intel.com/content/www/us/en/products/details/processors/ai-accelerators/gaudi.html). + + +
+ + + + +*Figure 1. Batch scaling experiment on Gaudi2, showing throughput performance improvements from 2 to 4 cards by combining AutoTP and ZeRO. The used mbs is the max possible value with the given config. A higher speedup indicates better performance.* + +
+ + + +
+ + + + +*Figure 2. Model training with AutoTP + ZeRO* + +
+ + +Figure 2 illustrates the basic flowchart, The division of TP and ZeRO is implemented through the AutoTP parser and ZeRO Wrapper in [Accelerate](https://github.com/huggingface/accelerate.git). Besides, The TP-based dataloader and save mechanism are both supported in DeepSpeed and Accelerate. + +# Usage + + + +Although we evaluated AutoTP training with Llama2 & Llama3 models in this blog, we expect compatibility with other Hugging Face models, especially [those](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) previously validated with AutoTP inference. + +**Requirements** +- `deepspeed >= 0.16.4` +- `transformers >= 4.50.1` +- `accelerate >= 1.6.0` + + **Enable TP training** + +Similar to ZeRO, AutoTP training is enabled using the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/) by specifying ```[tensor_parallel][autotp_size]```. +``` + "ZeRO_optimization": { + "stage": 1, + "gather_16bit_weights_on_model_save": true, + ... + }, + "tensor_parallel":{ + "autotp_size": 4 + }, +``` + +The parallel configuration follows this logic: + + +``` +tp_size = auto_tp_size +dp_size = num_gpus / tp_size +``` + +Note that the global_batch_size (gbs) changes with different TP settings: +``` +gbs (only dp) = per_device_batch_size * n_gpus * gradient_accumulation_steps + +gbs (dp with tp) = per_device_batch_size * n_gpus / tp_size * gradient_accumulation_steps +``` + + + + + + + + **Save Model** + + + + +Saving checkpoints and model files is fully compatible with HF transformers. The [trainer.save_model()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_model) method saves the original model. Ensure ```gather_16bit_weights_on_model_save``` is set to ```true```in the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/). +```gather_16bit_weights_on_model_save=true in config. + "ZeRO_optimization": { + ... + "gather_16bit_weights_on_model_save": true, + }, +``` + +``` +trainer.save_model(your_saved_path) +``` +Models saved this way can be directly used for HF format inference without intermediate transformations. + + + + **Saving Checkpoints and Resuming** + + + +Saving Checkpoints remains compatible with HF transformers. Use [trainer.save_state()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_state) or set the save interval for automatic saving, which can be used to resume training. +``` +trainer.train(resume_from_checkpoint="your_saved_path/checkpoint-1200") +``` + +# Example +We validated AutoTP training using supervised finetune training (SFT) task: [stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca). The original benchmark model used in this project is Llama2-7B. The example code is also available [here](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/tensor_parallel) + + +**Training Loss curve** + + + +The following loss curves depict SFT training, where gbs is uniformly set to 32, and other configurations match the default experiment settings from ([stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca)). The loss curves are largely consistent across the following setups: + + - ZeRO3 + - TP + disable ZeRO + - ZeRO1 and ZeRO1 + AutoTP + - ZeRO2 and ZeRO2 + AutoTP + + + + + +
+ + + + +*Figure 3. Loss curve of ZeRO3 stage training (gbs=32, dp=8)* + +
+
+ + + +*Figure 4. Loss curve of AutoTP training (gbs=32, tp=8)* +
+ +
+ + + +*Figure 5. Loss curve of AutoTP + ZeRO1 training (gbs=32, dp=2, tp=4)* +
+ + +
+ + + +*Figure 6. Loss curve of AutoTP + ZeRO2 training (gbs=32, dp=2, tp=4)* + + +
+ + + **Resuming Training** + + + We tested recovery training curves from step 1200 in AutoTP + ZeRO1 and AutoTP + ZeRO2, which align with the original training curves. + +
+ + + +*Figure 7. AutoTP + ZeRO1 resuming training* + + + +*Figure 8. AutoTP + ZeRO2 resuming training* + +
+ + + + **Model Evaluation** + + + We conducted inference evaluations for the [MMLU task](https://github.com/EleutherAI/lm-evaluation-harness). + In MMLU, the scores for AutoTP + ZeRO1 and ZeRO1, as well as AutoTP + ZeRO2 and ZeRO2, are consistent, showing a fixed improvement over the pre-training model before SFT. + + +
+ + +| Groups | Version | Filter | n-shot | Metric | Model before SFT | ZeRO1 DP8 training | ZeRO1 TP4 DP2 training | ZeRO2 DP8 training | ZeRO2 TP4DP2 training | +|--------|---------|--------|--------|--------|-----------------------|--------------------|------------------------|--------------------|------------------------| +| mmlu | 2 | none | | acc | 0.4185 ± 0.0041 | 0.4472 ± 0.0041 | 0.4444 ± 0.0041 | 0.4543 ± 0.0041 | 0.4529 ± 0.0041 | +| - humanities | 2 | none | | acc | 0.3979 ± 0.0069 | 0.4185 ± 0.0070 | 0.4145 ± 0.0069 | 0.4274 ± 0.0070 | 0.4272 ± 0.0070 | +| - other | 2 | none | | acc | 0.4712 ± 0.0089 | 0.5249 ± 0.0087 | 0.5182 ± 0.0088 | 0.5282 ± 0.0087 | 0.5269 ± 0.0087 | +| - social sciences | 2 | none | | acc | 0.4742 ± 0.0089 | 0.5070 ± 0.0089 | 0.5083 ± 0.0088 | 0.5151 ± 0.0088 | 0.5115 ± 0.0089 | +| - stem | 2 | none | | acc | 0.3428 ± 0.0084 | 0.3549 ± 0.0084 | 0.3539 ± 0.0084 | 0.3622 ± 0.0084 | 0.3609 ± 0.0084 | + +*Table 1. MMLU score with Llama2-7B inference* + +
+ + + + + +# Miscellaneous + +If users define their own dataloader, please ensure data consistency within ```deepspeed.utils.groups.get_tensor_model_parallel_group()```. DeepSpeed provides basic validation functions to assist with this. + +Furthermore, if users are not using transformers library, you can replace the ```TensorParallel_Layer``` layer and its subclasses as needed. See ```prepare_tp_model``` function in ```unit/model_parallelism/test_autotp_training.py```. Users can also define different shard and gather for subclasses of ```TensorParallel_Layer.``` + + + + + +# Ongoing Work +- **Optimization**: Communication/Activation optimization. +- **Usability**: Support the [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple the AutoTP parser, and expand model testing. + - [UPDATE] We now support [custom partitioning](https://deepspeed.readthedocs.io/en/latest/training.html#custom-layer-specs) in the same spirit as HF's partitioning plan, and will build Transformers TP plan support on top of that ([PR](http://github.com/deepspeedai/DeepSpeed/pull/7806)). + - [UPDATE] DeepSpeed now automatically detects and uses HuggingFace's built-in `base_model_tp_plan` (e.g. Llama, Qwen, Gemma2). When a model provides a `tp_plan`, AutoTP uses it directly without requiring `preset_model` or `partition_config`. Currently `colwise` and `rowwise` partition types are supported. See the [AutoTP training tutorial](https://deepspeed.readthedocs.io/en/latest/training.html#huggingface-tp-plan) for details ([PR](https://github.com/deepspeedai/DeepSpeed/pull/7901)). + +Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending. +Welcome bug reports, enhancement, and additional model training examples. + +# Contributors +This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft. diff --git a/blogs/huggingface-tp/media/batchscale.png b/blogs/huggingface-tp/media/batchscale.png new file mode 100644 index 000000000000..37a6eeeade9e Binary files /dev/null and b/blogs/huggingface-tp/media/batchscale.png differ diff --git a/blogs/huggingface-tp/media/flowchart.png b/blogs/huggingface-tp/media/flowchart.png new file mode 100644 index 000000000000..b7115df8c213 Binary files /dev/null and b/blogs/huggingface-tp/media/flowchart.png differ diff --git a/blogs/huggingface-tp/media/tp8.png b/blogs/huggingface-tp/media/tp8.png new file mode 100644 index 000000000000..0ae6e925eef1 Binary files /dev/null and b/blogs/huggingface-tp/media/tp8.png differ diff --git a/blogs/huggingface-tp/media/tpzero1.png b/blogs/huggingface-tp/media/tpzero1.png new file mode 100644 index 000000000000..b7f21a9e2a5f Binary files /dev/null and b/blogs/huggingface-tp/media/tpzero1.png differ diff --git a/blogs/huggingface-tp/media/tpzero2.png b/blogs/huggingface-tp/media/tpzero2.png new file mode 100644 index 000000000000..7a3333990835 Binary files /dev/null and b/blogs/huggingface-tp/media/tpzero2.png differ diff --git a/blogs/huggingface-tp/media/zero1tpload.png b/blogs/huggingface-tp/media/zero1tpload.png new file mode 100644 index 000000000000..9af5622f908d Binary files /dev/null and b/blogs/huggingface-tp/media/zero1tpload.png differ diff --git a/blogs/huggingface-tp/media/zero2tpload.png b/blogs/huggingface-tp/media/zero2tpload.png new file mode 100644 index 000000000000..69f002abf474 Binary files /dev/null and b/blogs/huggingface-tp/media/zero2tpload.png differ diff --git a/blogs/huggingface-tp/media/zero3.png b/blogs/huggingface-tp/media/zero3.png new file mode 100644 index 000000000000..62e6eb712151 Binary files /dev/null and b/blogs/huggingface-tp/media/zero3.png differ diff --git a/blogs/intel-inference/README.md b/blogs/intel-inference/README.md new file mode 100644 index 000000000000..917117523abc --- /dev/null +++ b/blogs/intel-inference/README.md @@ -0,0 +1,143 @@ + +# Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed + +> This blog is co-published with Intel and can also be viewed on [Intel's website](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html). + +## Introduction + +Transformer models have revolutionized natural language processing with their ability to capture complex semantic and syntactic relationships. However, these models also pose significant challenges for efficient inference, especially for large language models (LLMs) that have billions of parameters. For example, running half-precision inference of Megatron-Turing 530B would require 40 A100-40GB GPUs [1]. To address challenges associated with the inference of large-scale transformer models, the DeepSpeed team at Microsoft* developed DeepSpeed Inference [2]. It provides high-performance multi-GPU inferencing capabilities and introduces several features to efficiently serve transformer-based PyTorch models using GPU. Today, we are very excited to share that DeepSpeed Inference has been implemented for the 4th Gen Intel® Xeon® scalable processor. + +## 4th Gen Intel Xeon Processor + +Intel launched the 4th gen Intel Xeon processor in January 2023. This CPU has built-in accelerators for AI, data analytics, networking, storage and HPC. Tile Matrix Multiplication (TMUL) is the built-in AI accelerator. It executes the Intel® Advanced Matrix Extensions (Intel®AMX). Intel AMX can significantly speed up deep learning (DL) applications, both in inference and training. Other notable new features in 4th gen Intel Xeon processors that can speed up DL applications include PCI Express Gen5 (PCIe 5.0) and DDR5. PCIe 5.0 doubles the I/O bandwidth from PCIe 4.0, increasing the bandwidth between CPU and connected devices. DDR5 offers up to 1.5x bandwidth increase over DDR4 [3]. + +4th gen Intel Xeon with Intel AMX sped up training of BERT-large by 4x compared to 3rd gen Intel Xeon [4]. TMUL executes Intel AMX instructions on data loaded in 2D registers, hence the name tiles. These instructions operate on 8-bit integer (INT8) or 16-bit bfloat (BF16) datatype. 4th gen Intel Xeon with Intel AMX can attain 2048 INT8 operations per cycle compared to 256 INT8 operations per cycle in 3rd gen Intel Xeon with Intel Advanced Vector Extensions 512 Neural Network Instructions (Intel AVX-512 VNNI). Its BF16 performance is 1024 operations per cycle compared to its FP32 performance of 64 operations per cycle. Therefore, Intel AMX can significantly speed up DL applications when INT8 or BF16 datatype is used for matrix multiplication or convolution computations, the common operations in transformer or convolution-based models. + +## DeepSpeed enabled for 4th Gen Intel Xeon + +DeepSpeed is a DL optimization software for scaling and speeding up DL training and inference. DeepSpeed Inference refers to the feature set in DeepSpeed implemented to speed up inference of transformer models [2]. It initially supported only CUDA GPU. We recently added support for CPU, specifically 4th gen Intel Xeon. Features currently implemented for 4th gen Intel Xeon include automatic tensor parallelism (AutoTP), BF16 and INT8 datatype support, and binding cores to rank. + +DeepSpeed builds on top of PyTorch, which has been highly optimized for CPU inference and training. Intel® Extension for PyTorch* adds state-of-the-art optimizations for popular LLMs architectures, including highly efficient matrix multiplication kernels to speed-up linear layers and customized operators to reduce the memory footprint [5]. The runtime software components for DeepSpeed Inference on CPU are shown below in Figure 1. Intel® oneAPI Deep Neural Network Library (oneDNN) uses Intel AVX-512 VNNI and Intel AMX optimizations [6]. Intel® oneAPI Collective Communications Library (oneCCL) is a library that implements the communication patterns in DL [7]. Intel® Neural Compressor (INC) was used to convert the LLMs from FP32 datatype to BF16 or INT8 datatype [8]. + + +
+
+Figure 1. Software components for DeepSpeed Inference on CPU +
+ +## Technologies Introduced + +To accelerate running LLMs with DeepSpeed on 4th-generation Intel Xeon, we introduced technologies into both DeepSpeed and Intel Extension for PyTorch. + +1. Extend DeepSpeed Accelerator Abstraction Interface to provide CPU support [9]. We implemented CPU as a DeepSpeed Accelerator which allows CPU support to be plugged into DeepSpeed in a device-agnostic manner. Device-agnostic DeepSpeed model scripts which use DeepSpeed Accelerator Abstraction Interface can run on CPU devices without modification. +2. Fine-grain core binding. We introduced two new DeepSpeed command line arguments: `--bind_cores_to_rank` and `--bind_core_list` to allow core binding with DeepSpeed AutoTP [10] on a node with multiple sockets or on a single socket with multiple sub-NUMA nodes (SNC). Using `numactl`` for each tensor parallel worker, we can bind workers to cores and NUMA memory. This reduces interference between workers and uses memory bandwidth and core more effectively. +3. Optimized shared memory (SHM) based on AllReduce communication primitives for a single CPU node. We implemented a low latency SHM based AllReduce primitive which utilizes the shared memory of a single-node CPU system. +4. Optimizations in Intel Extension for PyTorch + + a. oneDNN, Tensor Processing Primitives (TPP) and customized linear kernels for weight only quantization. + + b. Indirect Access KV Cache reduces memory reorder overhead when using KV cache. + + c. Subgraph fusion to reduce memory footprint. + + d. Fusion of AllReduce between multi-head attention and multilayer perceptron in transformer layer when there is no dependency between them. + +## How to run DeepSpeed on CPU + +Software required for DeepSpeed Inference on CPU (Specific details can be found in the configuration.) +* PyTorch +* Intel Extension for PyTorch [6] +* oneCCL binding for PyTorch [11] +* oneCCL [7] +* DeepSpeed [12] + +After installing the required software, we can run inference for a model on CPU. Device agnostic interfaces are used to load and run the model. These device agnostic interfaces are accessed through deepspeed.accelerator.get_accelerator() as shown below in Listing 1. Refer to the DeepSpeed tutorial on DeepSpeed accelerator interfaces [13] for further details. + +```python +# Listing 1. An example of using device agnostic interface to get the accelerator device and load and run a model. +import deepspeed +from deepspeed.accelerator import get_accelerator +... +# load model checkpoint into model +model = model.eval().to(get_accelerator().device_name()) + +ds_world_size = int(os.getenv('WORLD_SIZE', '0')) + +engine = deepspeed.init_inference(model=model, mp_size=ds_world_size, \ + dtype=torch.bfloat16, replace_method="auto", \ + replace_with_kernel_inject=False) + +model = engine.module +... +# evaluate model +``` + +Execute the inference code with DeepSpeed using the following command: + +```bash +deepspeed --bind_cores_to_rank +``` + +This command detects the number of sockets on host and launches as many inference workers as the number of sockets. The LLM workload runs in parallel on the inference workers with DeepSpeed AutoTP [10]. AutoTP distributes inference computation among workers and reduces inference latency. For example, if the host has two sockets, this command will launch two inference workers to inference the input sample in parallel. The argument --bind_cores_to_rank instructs DeepSpeed to split the CPU cores and distribute them to each rank evenly. This ensures that each inference worker uses an exclusive set of CPU cores to avoid interfering with one another. If this argument is not specified, it will defer to the operating system to schedule the workers to the CPU cores, which may not be optimal. + +Intel Extension for PyTorch is compatible with DeepSpeed AutoTP and can therefore be used to further optimize AutoTP models generated by DeepSpeed. + +```python +# Use Intel Extension for PyTorch to optimize model +... +model = engine.module +import intel_extension_for_pytorch as ipex +model = ipex.optimize_transformers(model.eval(), dtype=torch.bfloat16, inplace=True) +... +``` +Examples of LLM optimizations for DeepSpeed AutoTP models with Intel Extension for PyTorch are available at [14]. + +## Results + +DeepSpeed enables optimal distribution of LLM inference on two 4th gen Intel Xeon sockets. Intel AMX on 4th gen Intel Xeon can be used to accelerate BF16 matrix multiplication operations. Support for Intel AMX is available through Intel Extension for PyTorch. Performance speedups in GPT-J-6B and LLaMA2-13B from DeepSpeed AutoTP on 2 sockets are shown in Figure 2 below. GPT-J-6B has 6 billion parameters, requiring 12 GB of memory for its weights. Llama-2-13B has 13 billion parameters, requiring 26 GB of memory for the weights. Latency improvement is the metric used. Prompt latency and per token latency improved as shown by the speedups in the plot. + +
+
+Figure 2. Performance speedups from 1-socket to 2-socket 4th gen Intel Xeon with DeepSpeed AutoTP. Higher speedup represents higher performance. Per token latency is per token latency for 2nd and subsequent tokens. in/out refers to the input token size and output token size. Beam search size was 4. See backup for configurations, results may vary. +
+ +## Summary +DeepSpeed Inference has been enabled for 4th gen Intel Xeon with Intel AMX to accelerate matrix multiplications common in DL workloads. DeepSpeed Inference leverages 4th Gen Intel Xeon to speed up the inferences of GPT-J-6B and Llama-2-13B. We will continue to improve it for new devices and new LLMs. Intel Data Center GPU Max is a new GPU designed for AI for which DeepSpeed will also be enabled [15]. + +## Contributors +This work was made possible through deep collaboration between software engineers and researchers at Intel and Microsoft. The contributors of this work include Guokai Ma, Kiefer Kuah, Yejing Lai, Liangang Zhang, Xiaofei Feng, Xu Deng, Mengfei Li, Jianan Gu, Haihao Shen, and Fan Zhao from Intel; Olatunji Ruwase, Martin Cai, and Yuxiong He from Microsoft. + +## Configuration +1-node, 2x Intel® Xeon® Platinum 8480+, 56 cores, HT On, Turbo On, 1024 GB (16x64GB DDR5 4800 MT/s [4800 MT/s]) , BIOS version Intel Corporation SE5C7411.86B.9525.D13.2302071333, 02/07/2023, ucode version 0x2b000190, Red Hat Enterprise Linux 8.6, kernel version 4.18.0-372.9.1.el8.x86_64, gcc 11.2.1, PyTorch 2.1.0.dev20230618+cpu, DeepSpeed 0.9.5+3f5e4931, ipex 2.1.0+git31b7cd6, GPT-J-6B, LLaMA-2-13B. + +## References + +[1] Microsoft, "ZeRO-Inference: Democratizing massive model inference," 9 September 2022. [Online]. Available: https://www.deepspeed.ai/2022/09/09/zero-inference.html. [Accessed 12 April 2023]. + +[2] R. Y. Aminabadi, S. Rajbhandari, M. Zhang, A. A. Awan, C. Li, D. Li, E. Zheng, J. Rasley, S. Smith, O. Ruwase, Y. H. Y. Aminabadi, S. Rajbhandari, M. Zhang, A. A. Awan, C. Li, D. Li and El, "DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale," 30 6 2022. [Online]. Available: https://arxiv.org/abs/2207.00032. + +[3] Intel, "4th Gen Intel(r) Xeon(r) Scalable Processors," [Online]. Available: https://www.intel.com/content/www/us/en/products/docs/processors/xeon-accelerated/4th-gen-xeon-scalable-processors-product-brief.html. [Accessed 12 4 2023]. + +[4] Intel, "Accelerate AI Workloads with Intel® AMX," [Online]. Available: https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/ai-solution-brief.html. [Accessed 12 4 2023]. + +[5] Intel, "Large Language Models (LLM) Optimizations Overview," [Online]. Available: https://intel.github.io/intel-extension-for-pytorch/cpu/2.1.0+cpu/tutorials/llm.html. + +[6] Intel, "Intel® Extension for PyTorch," [Online]. Available: https://github.com/intel/intel-extension-for-pytorch. + +[7] Intel, "oneAPI Collective Communications Library (oneCCL)," [Online]. Available: https://github.com/oneapi-src/oneCCL. + +[8] Intel, "Intel® Neural Compressor," [Online]. Available: https://github.com/intel/neural-compressor. + +[9] Microsoft, "DeepSpeed Accelerator Abstraction Interface," [Online]. Available: https://github.com/deepspeedai/DeepSpeed/blob/master/docs/_tutorials/accelerator-abstraction-interface.md. + +[10] Microsoft, "Automatic Tensor Parallelism for HuggingFace Models," [Online]. Available: https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism. + +[11] Intel, "Intel® oneCCL Bindings for PyTorch," [Online]. Available: https://github.com/intel/torch-ccl. + +[12] Microsoft, "deepspeed," [Online]. Available: https://github.com/deepspeedai/deepspeed. + +[13] Intel, "DeepSpeed Accelerator Abstraction Interface," [Online]. Available: https://github.com/deepspeedai/DeepSpeed/pull/3184. + +[14] Intel, "Intel® Extension for PyTorch large language model example," [Online]. Available: https://github.com/intel/intel-extension-for-pytorch/tree/llm_feature_branch/examples/cpu/inference/python/llm. + +[15] Intel, "Intel® Data Center GPU Max Series," [Online]. Available: https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html. diff --git a/blogs/intel-inference/assets/intel-results.png b/blogs/intel-inference/assets/intel-results.png new file mode 100755 index 000000000000..e65aae3d4a8c Binary files /dev/null and b/blogs/intel-inference/assets/intel-results.png differ diff --git a/blogs/intel-inference/assets/software-arch.png b/blogs/intel-inference/assets/software-arch.png new file mode 100755 index 000000000000..da147f4b9672 Binary files /dev/null and b/blogs/intel-inference/assets/software-arch.png differ diff --git a/blogs/muon-optimizer/README.md b/blogs/muon-optimizer/README.md new file mode 100644 index 000000000000..3fead48ba03e --- /dev/null +++ b/blogs/muon-optimizer/README.md @@ -0,0 +1,83 @@ +# Using Muon Optimizer with DeepSpeed +## TL;DR +Muon optimizer has gained great momentum with significant adoption from frontier AI Labs. For example, Moonshot AI adopted Muon Optimizer to train their Large Foundation Model like Kimi-K2-Thinking. We are thrilled to announce that DeepSpeed now supports Muon optimizer. + +## What is Muon optimizer? +Muon is an optimizer designed for hidden 2D weights of a neural network. It takes gradient of the weight, computes its momentum, and applies Newton-Schulz iterations to orthogonalize the momentum matrix, then uses this orthogonalized matrix to update the weight [[1]](https://kellerjordan.github.io/posts/muon/). Because Muon only maintains one momentum buffer (versus Adam’s two), it uses less memory for optimizer states. + +The orthogonalization step is key to Muon’s convergence advantage in pretraining. In practice, gradient updates for 2D weights in transformers tend to have very high condition numbers — they are nearly low-rank, dominated by a few large singular directions. By orthogonalizing the momentum matrix, Muon equalizes all singular values, effectively amplifying rare but important update directions that would otherwise be overshadowed. This leads to better sample efficiency: in NanoGPT speedrunning benchmarks [[2]](https://github.com/KellerJordan/modded-nanogpt), Muon improved training speed by 35% over AdamW, and at 1.5B parameter scale it reached GPT-2 XL level performance approximately 25% faster than AdamW [[1]](https://kellerjordan.github.io/posts/muon/). + +Unlike Adam optimizer that requires two momentum buffer for each parameter, Muon optimizer only requires one momentum buffer. This means that for parameters using Muon optimizer, we only need to allocate one buffer for momentum, which can save memory compared to Adam. + +Muon is used by Keller Jordan’s mod of NanoGPT [[2]](https://github.com/KellerJordan/modded-nanogpt), Andrej Karpathy’s nanochat [[3]](https://github.com/karpathy/nanochat), and a variant of Muon (MuonClip) is also used by the production-level LLM Kimi-K2 from MoonShot [[4]](https://arxiv.org/pdf/2507.20534). More recently, Zhipu AI’s GLM-5 (744B parameters) confirmed the use of Muon optimizer in both GLM-4.5 and GLM-5 pretraining, along with a “Muon Split” technique that splits MLA up-projection matrices by attention head and orthogonalizes each head independently, addressing a performance gap between MLA and GQA when using Muon [[5]](https://arxiv.org/abs/2602.15763) DeepSeek-V4 (1.6T parameters) also employs the Muon optimizer for faster convergence and greater training stability [[6]](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro). + +## Muon Optimizer support in DeepSpeed +One of the challenges of applying Muon optimizer to DeepSpeed is that previous optimizers (SGD, Adam) look at gradients as flattened buffers. Thus it is hard to swap in Muon optimizer in the same place because the gradient buffers are already flattened. We move the Muon update to `get_flat_partition` function of stage 1 and 2 `DeepSpeedZeroOptimizer` in which per parameter gradients are still in unflattened stages, thus we can easily apply the Muon updates. + +Muon optimizer works on 2D weight matrices (attention and MLP weights). It applies Newton-Schulz orthogonalization to the momentum matrix, which requires the weight to be 2D. Non-2D parameters (embeddings, layer norms, biases, lm_head) fall back to AdamW. We apply a parse in model engine initializer to tag the model parameter with `use_muon`, if and only if the model parameter is 2D and belongs to hidden layers. When Muon optimizer is used, any parameter tagged `use_muon` will use Muon optimizer to update weight. + +Note that Muon is a hybrid optimizer: it uses Muon updates only for 2D hidden weights and falls back to Adam for all other parameters (embeddings, layer norms, biases, lm_head). The DeepSpeed config supports separate learning rates via `muon_lr` (for Muon parameters) and `adam_lr` (for Adam parameters). + +## Running DeepSpeed finetune with Muon optimizer +Deepspeed finetune demo [[7]](https://github.com/delock/deepspeed_finetune_demo) is a demo to use different DeepSpeed training features and compare their performance in a single place. You can use it to test finetune LLM models with Muon optimizer: +``` +git clone https://github.com/delock/deepspeed_finetune_demo +cd deepspeed_finetune_demo +./finetune.sh z2_muon.json +``` + +## Muon Optimizer Convergence Experiment Result + +We tested Muon optimizer by finetuning Moonlight-16B-A3B (a Mixture-of-Experts model with 16B total and 3B active parameters), and evaluated on code generation (MBPP/MBPP+), general knowledge (MMLU), and mathematical reasoning (GSM8K) benchmarks. Each benchmark uses its own domain-specific training set. + +**Training Configuration:** +- Model: Moonlight-16B-A3B (MoE, 16B total / 3B active) +- Training datasets: sahil2801/CodeAlpaca-20k for MBPP/MBPP+, cais/mmlu (auxiliary_train, ~95k examples) for MMLU, meta-math/MetaMathQA (sample_rate=0.1, ~39.5k examples) for GSM8K +- ZeRO Stage 2, bf16, Expert Parallelism (autoep_size=4) +- Batch size: 16, gradient accumulation: 2, 4 GPUs +- 1 epoch, gradient clipping: 1.0 + +### Evaluation Results + +| Optimizer | Learning Rate | adam_lr (for Muon) | MBPP | MBPP+ | MMLU | GSM8K | +|-----------|--------------|-------------------|-------|-------|--------|--------| +| baseline (pre-finetune) | — | — | 0.495 | 0.431 | 0.401 | 0.526 | +| AdamW | 2e-6 | — | 0.661 | 0.534 | 0.660 | 0.805 | +| Muon | 1e-4 | 2e-6 | 0.646 | 0.548 | 0.678 | 0.810 | + +Muon outperforms AdamW on 3 out of 4 metrics: MBPP+ (0.548 vs 0.534, +1.4pp), MMLU (0.678 vs 0.660, +1.8pp), and GSM8K (0.810 vs 0.805, +0.5pp). On MBPP base tests, AdamW edges out Muon (0.661 vs 0.646, -1.5pp), though Muon achieves a higher score on the more rigorous MBPP+ with extra test cases (0.548 vs 0.534), suggesting better generalization. + +## Muon Optimizer Memory Savings +Muon optimizer uses less memory for optimizer states than Adam, because it maintains one momentum buffer per parameter instead of two (first and second moment). + +### Memory Usage Comparison +Note that Muon is a hybrid optimizer: 2D hidden weights use Muon (1 buffer), while remaining parameters (embeddings, layer norms, lm_head) still use Adam (2 buffers). The actual memory savings depend on the fraction of parameters that are 2D hidden weights. For typical transformer models, approximately 90% of parameters are 2D hidden weights, so optimizer state memory is reduced by roughly 45%. However, because total GPU memory also includes model weights, gradients, and activations, the end-to-end memory reduction is smaller (see measured results below). + +| Optimizer | State Buffers per Param | Memory per Parameter | +|-----------|------------------------|---------------------| +| Adam | 2 (m, v) | 8 bytes | +| Muon | 1 (momentum) | 4 bytes | + +### Measured GPU Memory: Qwen2.5-3B Finetuning +We measured peak GPU memory during finetuning Qwen2.5-3B on tatsu-lab/alpaca using the same 8xA100 (40GB) configuration described above (batch size 32, ZeRO Stage 2, bf16). + +| Optimizer | Peak Memory per GPU | Savings vs AdamW | +|-----------|---------------------|------------------| +| AdamW | 34.5 GiB | — | +| Muon | 31.4 GiB | 9% | + +Muon reduces per-GPU memory by approximately 3 GiB (9%) compared to AdamW. The savings come entirely from optimizer states: Muon parameters store one momentum buffer (4 bytes) instead of Adam's two (8 bytes). However, because optimizer states are only one component of total GPU memory (alongside model weights, gradients, and activations), the end-to-end reduction is modest. For larger models or tighter memory budgets, this 9% savings could make the difference between fitting a workload on-device versus requiring CPU offloading. + +## What’s Next +Muon is rapidly gaining traction in the community, and production-level adoption by Kimi-K2 (1T parameters) and GLM-5 (744B parameters) signals that it is a serious contender to replace Adam as the default optimizer for large-scale training. We are actively building out full Muon support in DeepSpeed, with a series of improvements already in flight: + +- [x] **ZeRO Stage 2 support** — merged +- [x] **ZeRO Stage 3 support** — merged +- [x] **Gram-Schmidt based Newton-Schulz iteration** — a faster orthogonalization kernel, in review +- [ ] **CPU Offloading** — in progress +- [ ] **MuonClip** — the variant used by Kimi-K2, planned + +We welcome any thoughts, feedback and contributions related to Muon optimizer support on DeepSpeed, welcome to start an issue for discussion or submit a PR to DeepSpeed. Let’s make Muon rock solid and lightning fast in DeepSpeed! + +## Contributors +This work is contributed from Wang, Zhipeng (@PKUWZP); Peng Du (@pengdurice); Chi McIsaac(@qimcis) and Ma, Guokai (@delock) diff --git a/blogs/ulysses-offload/README.md b/blogs/ulysses-offload/README.md new file mode 100644 index 000000000000..5dfd63ccdc32 --- /dev/null +++ b/blogs/ulysses-offload/README.md @@ -0,0 +1,263 @@ +# Ulysses-Offload: Democratizing Long Context LLM Training + + + +Figure 1: Ulysses-Offload supports 16x longer sequence lengths at 55% +Model FLOPs Utilization (MFU) than NVIDIA Megatron-SP and DeepSpeed Ulysses. + + +To cite and for more technical in depth of this release, please see +our [arxiv report](https://arxiv.org/abs/2408.16978): + +@article{yao2024ulysses, + +title={ Training Ultra Long Context Language Model with Fully Pipelined +Distributed Transformer}, + +author={Jinghan Yao and Sam Ade Jacobs and Masahiro Tanaka and Olatunji +Ruwase and Aamir Shafi and Hari Subramoni and Dhabaleswar K. (DK) Panda +}, + +journal={https://arxiv.org/abs/2408.16978}, + +year={2024} + +} + +## Introduction + +In the rapidly evolving field of generative AI and scientific ML, the +ability to train large (language) models with ultra-long context +capabilities is becoming increasingly important. These models are +essential for a variety of complex tasks, such as understanding +lengthy documents, generating images and videos, and processing extensive +sequences in computational biology. However, training such models +efficiently poses significant challenges due to the enormous GPU +memory required. + +Building DeepSpeed Ulysses, our previous project, which developed +system optimizations for training extremely long sequence transformer +models, we are excited to present Ulysses-Offload, in this release. Ulysses-Offload +is an innovative, resource-efficient technique that offers comparable +benefits to DeepSpeed Ulysses and other previous long-context +optimization methods, but with a lower hardware budget. Ulysses-Offload makes +ultra long-context large language models (LLM) training and finetuning +accessible to everyone, including those with limited GPU resources. Ulysses-Offload enables +training with context lengths of up to 2 million tokens using just 4 +NVIDIA A100-40GB GPUs. Ulysses-Offload supports 16x longer sequence lengths at 55% +Model FLOPs Utilization (MFU) than NVIDIA Megatron-SP and DeepSpeed Ulysses +(see Figure 1). The next section highlights the key innovations of Ulysses-Offload, +and subsequent sections provide additional details on the design and +usability of Ulysses-Offload, followed by experimental results. + +## Key Innovations + +### 1. Fully Pipelined Distributed Transformer (FPDT) + +The core innovation of our work is the Fully Pipelined Distributed +Transformer (FPDT). This approach leverages a pipelined sequence +chunking, which allows for the training of LLMs with sequence lengths up +to 2 million tokens on just 4 A100-40GB GPUs. By breaking down the +sequence into manageable chunks and processing them in a pipelined +manner, Ulysses-Offload significantly reduces the memory footprint while +maintaining high computational efficiency. This method ensures that the +GPUs are utilized effectively, even when dealing with extremely long +sequences. + +### 2. Memory Optimization + +One of the critical aspects of our approach is the comprehensive +analysis and optimization of the memory footprint during LLM training. +We target the reduction of redundant intermediate buffers in both the +forward and backward passes of the training process. By optimizing the +use of GPU and host CPU memory, we can train larger models with longer +sequences without running into GPU memory limitations. This optimization +is crucial for enabling the training of ultra-long context models on a +limited number of GPUs. It is worth noting that Ulysses-Offload memory optimization +is orthogonal and complementary to model- parameter-focused memory +optimization techniques used by DeepSpeed ZeRO and PyTorch FSDP. Ulysses-Offload optimizes memory footprint of activations associated with long sequences while ZeRO and FSDP optimize memory footprint of model parameters. + +### 3. Compatibility and Flexibility + +Ulysses-Offload is designed to be agnostic to existing training techniques and +works efficiently across different LLM models, including popular +architecture like GPT and Llama. This flexibility ensures that our +approach can be easily integrated into various training workflows. +Additionally, Ulysses-Offload is compatible with advanced memory optimization +techniques such as DeepSpeed ZeRO and PyTorch FSDP, further enhancing +its usability and performance. + +## Core Design of Ulysses-Offload + +Figure 2 illustrates the core structure of Ulysses-Offload. Ulysses-Offload leverages multiple +memory hierarchies in modern GPU clusters, thus boosting hardware +efficiency and cost-effectiveness while achieving very high model FLOP +utilization (MFU). The design of Ulysses-Offload centers around pipelining, +scheduling, and memory management. These well-known optimization +techniques are essential for scaling LLM context length to a million +scale with a few GPUs and will be discussed in the subsequent +subsections. + + + +Figure 2: Core design + +### + +### Pipelining and Scheduling + +Ulysses-Offload employs sequence chunking and pipelined computation design to manage the memory +and computational load efficiently. In traditional Transformer model, +input (hidden state) tensor is projected to q, k, v tensors. Each of these tensors can be denoted *\[B, S, H, D\]*, where *B* is batch +size, *S* is sequence length, *H* is number of heads and *D* is hidden +dimension per head. With sequence parallelism such as DeepSpeed Ulysses, +input tensor is partitioned along sequence dimension across sequence +parallel group P, that is *\[B, S/P, H, D\]* prior to alltoall collective +communication. The alltoall collective communication gathers partitioned tensors +along sequence dimension and scatter them along head dimension essentially +transforming tensor from *\[B, S/P, H, D\]* to *\[B, S, H/P, D\]*. Post attention computation, a second alltoall communication transforms *\[B, S, H/P, D\]* back to *\[B, S/P, H, D\]* + +In our Ulysses-Offload design, input sequence are partitioned at a much finer granularity than DeepSpeed Ulysses. In other words, we made changes to sequence partitioning such that we further subdivide per GPU *S/P* sequence into smaller *u* +chunks. Thus, the input tensors are now represented as \[*B, S/uP, H, +D*\]. We denote these chunks as *Ti*, +where$\ i\ \in \ 0,1,\ldots,\ u - 1.$ As shown in Figure 1, +*Ti* is projected to query *qi*, key +*ki*, and value *vi*. Then, similar to DeepSpeed Ulysses, an alltoall collective communication gathers partitioned tensor +along sequence dimension and scatter them along head dimension. In our chunk +design, the sequence length for each chunk is reduced by a factor of *u* +compared to Ulysses. Please note that our Ulysses-Offload chunking procedure is generally applicable to other sequence parallelism techniques. + + + +Figure 3: Core design with offload description + +Figure 3 gives an example of how to perform the computation of chunk +*Tm*. After the alltoall collective communication, +*GPUj* receives +$\widehat{q}m,\ \widehat{k}m,\ and\ \widehat{v}m$*.* We then fetch the +previous sequence chunk by chunk from the host memory to +GPUj, and perform online attention with the current +$\widehat{q}m$ and update the output chunk accordingly. Note that, in a +strict manner, at any given time, only one set of chunks +$\widehat{k}i,\ and\ \widehat{v}i$ is placed on GPU's HBM, reducing the +memory footprint to $\frac{1}{u}$ compared to the non-offloading version +without double buffering. With double buffering, memory footprint is +reduced by *2/u*. + +### Memory Management + +Ulysses-Offload optimizes memory usage by carefully managing the allocation and +deallocation of buffers during training. This involves: + +1. Double Buffering: + + - Two sets of buffers are maintained to overlap computation with + data transfer. + + - While one set of buffers is used for computation, the other set is + preloaded with the next chunk of data. + +2. Hierarchical Memory Utilization: + + - GPU High Bandwidth Memory (HBM) is used for active computation. + + - Host memory is used to store intermediate results that are not + immediately needed, reducing the pressure on GPU memory. + +## Integration with Existing Frameworks + +Ulysses-Offload is designed to integrate seamlessly with popular deep learning +frameworks such as PyTorch. Ulysses-Offload provides user-friendly APIs that +abstract the complexities of pipelined training and memory management. +Users can adopt Ulysses-Offload with minimal changes to existing codebases. + +## Experimental Results + + + +Figure 4: Supported sequence lengths and corresponding Model FLOPs +Utilization (MFU) using Megatron-SP, Ulysses, and our proposed Ulysses-Offload (FPDT). OOM +denotes the point where increasing sequence length will cause memory +issues. We show Ulysses-Offload's performance when the sequence length is larger +than 128K, as shorter sequences can be properly handled by existing +strategies. + +### Extended Sequence Lengths + +In our experimental setup, we compare Ulysses-Offload with two existing methods: +Microsoft DeepSpeed Ulysses and NVIDIA Megatron-SP. Both DeepSpeed +Ulysses and Megatron-SP employ similar approaches to sequence +parallelism but differ in the collective communication used for +gathering sequences before the attention block. The former utilizes +alltoall communication, whereas the latter employs allgather. Ulysses-Offload +builds upon the DeepSpeed Ulysses approach. The primary advantage of +Ulysses-Offload is its capability to support the training of large language models +(LLMs) with ultra-long sequence lengths using fewer GPUs. As shown in +Figure 4, our method enables the training of 8B parameter models with +sequence lengths of 2 million tokens using only 4 GPUs. For even larger +models, such as GPT-30B and Llama-70B parameter models, Ulysses-Offload supports +sequence lengths up to 3 million and 4 million tokens using 16 GPUs and +32 GPUs respectively. This represents a 16x increase in sequence length +compared to current state-of-the-art solutions (see Figure 5), making +Ulysses-Offload a game-changer for tasks that require processing long sequences. + +### High Hardware Efficiency + +As shown in Figure 4 with different model sizes ranging from GPT-2.7B to +Llama-80B parameters, Ulysses-Offload achieves over 55% Model FLOPs Utilization +(MFU), ensuring that the hardware resources are utilized effectively. +This high level of efficiency is maintained even when dealing with +extremely long sequences (up to 4 million context length), making Ulysses-Offload +an ideal solution for training large-scale LLMs. By maximizing the use +of available hardware, Ulysses-Offload reduces the overall cost and complexity of +training long-context models. Our [technical report](https://arxiv.org/abs/2408.16978) offers +further insights into optimizing sequence chunks to balance the +trade-off between memory usage and MFU. + + + +Figure 5: A comprehensive analysis on long-context LLM training with +different training techniques: tensor parallelism (TP), activation +checkpoint (AC), activation checkpoint with CPU offloading (OC), Ulysses +(UL), and our approach Ulysses-Offload (FPDT). + +## Implementation and Usability + +Ulysses-Offload is designed to be easily integrated with popular deep learning +frameworks such as DeepSpeed, Megatron-DeepSpeed and PyTorch. Users can +adopt our approach with minimal changes to their existing training +pipeline, making it accessible to a broad audience. The integration +process involves setting up the sequence chunk pipeline and configuring +the memory optimization techniques, both of which are straightforward +and well-documented (see tutorial). + +Our pipeline design and memory optimization techniques are +straightforward to implement, making Ulysses-Offload accessible to researchers and +practitioners aiming to train long-context LLMs efficiently. We provide +detailed [technical report](https://arxiv.org/abs/2408.16978), +documentation and examples to guide users through the setup process, +ensuring a smooth transition to using Ulysses-Offload. Additionally, Ulysses-Offload, in the +tradition of DeepSpeed provides user-friendly API which abstracts the +complexities of mixed precision training and memory optimization, +allowing users to focus on their research and development tasks. + +## General Availability of DeepSpeed Ulysses-Offload + +We are excited to release Ulysses-Offload. Ulysses-Offload has been +fully integrated with Megatron-DeepSpeed and accessible through both +DeepSpeed and Megatron-DeepSpeed GitHub repos. Click here for detailed +[tutorial](https://www.deepspeed.ai/tutorials/ulysses-offload/) on usage. + +We invite the community to explore our implementation, contribute to +further advancements, and join us in pushing the boundaries of what is +possible in LLM and AI. This release is part of the bigger DeepSpeed +ecosystem of large-scale AI training, finetuning and inference. For more +details on all DeepSpeed technologies and innovations, please visit our +[website]((https://www.deepspeed.ai/)) and follow us +on X, formerly Twitter, ([English](https://twitter.com/DeepSpeedAI), +[Japanese](https://twitter.com/DeepSpeedAI_JP)) and +[Chinese Zhihu](https://www.zhihu.com/people/deepspeed). diff --git a/blogs/ulysses-offload/media/image1.png b/blogs/ulysses-offload/media/image1.png new file mode 100644 index 000000000000..df63ef819763 Binary files /dev/null and b/blogs/ulysses-offload/media/image1.png differ diff --git a/blogs/ulysses-offload/media/image2.png b/blogs/ulysses-offload/media/image2.png new file mode 100644 index 000000000000..fee5ad05d313 Binary files /dev/null and b/blogs/ulysses-offload/media/image2.png differ diff --git a/blogs/ulysses-offload/media/image3.png b/blogs/ulysses-offload/media/image3.png new file mode 100644 index 000000000000..7777b1b245dd Binary files /dev/null and b/blogs/ulysses-offload/media/image3.png differ diff --git a/blogs/ulysses-offload/media/image4.png b/blogs/ulysses-offload/media/image4.png new file mode 100644 index 000000000000..05226922851b Binary files /dev/null and b/blogs/ulysses-offload/media/image4.png differ diff --git a/blogs/ulysses-offload/media/image5.png b/blogs/ulysses-offload/media/image5.png new file mode 100644 index 000000000000..447037975d1a Binary files /dev/null and b/blogs/ulysses-offload/media/image5.png differ diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md new file mode 100644 index 000000000000..5148fa80d122 --- /dev/null +++ b/blogs/windows/08-2024/README.md @@ -0,0 +1,101 @@ +
+ +# DeepSpeed on Windows + +
+ +# Introduction + +DeepSpeed is a popular open-source deep learning optimization library that makes distributed training and inference easy, efficient, and effective. DeepSpeed has been widely used to train a variety of state-of-the-art models, including Phi-3, Megatron-Turing-530B, BLOOM-176B, and Arctic because of its rich suite of sophisticated optimizations (e.g., ZeRO, 3D parallelism, MoE, etc.). However, the lack of native support for Microsoft Windows, the most popular operating system, means that DeepSpeed innovations are inaccessible to many AI developers and users. To address this problem, we started an effort to make DeepSpeed run natively with full features on Windows, while ensuring the same ease-of-use enjoyed on Linux. + +In this blog, we are pleased to announce some early achievements on this journey: DeepSpeed can now be installed in Windows and run natively for single-GPU training, finetuning, and inferencing. Importantly, both the installation and usage experiences are identical to those on Linux. Furthermore, the finetuning and inferencing workloads demonstrate the functioning of three critical DeepSpeed features, HuggingFace Transformers integration, LoRA support, and CPU Offloading. DeepSpeed on Windows is available in DeepSpeed versions 0.14.5 and above. In the rest of this blog, we present examples to demonstrate these achievements. + +# Evaluation Environment +We conducted the experiments on a Surface Laptop Studio 2 running Windows 11 Version 23H2 and Build 22631.3880. The laptop is equipped with a single NVIDIA RTX A2000 GPU with 4GB VRAM. We used Pytorch version 2.3.0 and HuggingFace Transformers version 4.41.2. The example scripts used are from the [DeepSpeedExamples repo](https://github.com/deepspeedai/DeepSpeedExamples), therefore you need to clone the repo before running any of the following examples. + +# Installation +DeepSpeed can be installed on Windows in one of two ways. The easier way is to use the pip package manager, while the other is to build from source. The prerequisites for in both cases are Python 3.x and Pytorch with CUDA support. + +## Installing via pip +To install DeepSpeed, simply run: `pip install deepspeed`. This will install the latest version of DeepSpeed (0.14.5 at this time). Unlike the Linux counterpart, the Windows version comes with all the operators already prebuilt, so there is no need to have a CUDA SDK or C++ compiler installed. + +
+ +
+ +
+ pip installation of DeepSpeed on Windows. +
+ + +## Building from Source +To build DeepSpeed from source, you need to clone the DeepSpeed repository and run the `build_win.bat` compilation script. + + +## Validating Installation +Regardless of the installation choice, you can check that the installation was successful by running ds_report. The output should look like this: + + +
+ +
+ +
+ ds_report output confirming Windows installation of DeepSpeed. +
+ +# Pretraining Examples +We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed. + +## Pretraining CIFAR10 +The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this: +
+ +
+ +
+ Pretraining CIFAR10 model on Windows using DeepSpeed. +
+ +## Pretraining BERT +The scripts and codes for the BERT pretraining example are available in the following path: DeepSpeedExamples\training\HelloDeepSpeed. You can launch the BERT pretraining experiment using the following command: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`. The final output should look like this: + +
+ +
+ +
+ Pretraining BERT model on Windows using DeepSpeed. +
+ +# Fine Tuning Example +We demonstrate fine tuning capability by using the supervised fine tuning (SFT) step of DeepSpeed-Chat application. We conduct SFT of the HuggingFace facebook/opt-125m model while enabling LoRA and CPU offloading memory optimizations. The command line for running this example is as follows:\ +`deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output`.\ +The output should look like this: + +
+ +
+ +
+ Supervised Finetuning of facebook/opt-125m model on Windows using DeepSpeed. +
+ +# Inference Example +We demonstrate inference capability by using ZeRO-Inference for token generation. ZeRO-Inference reduces hardware cost of inferencing by offloading to CPU or NVMe memories. We use the example scripts here to run token generation using Llama-2-7B model from HuggingFace. We offload the model weights to CPU memory since the 4GB VRAM is insufficient to host both the model and the generation working set. We use the following command line to generate 32 tokens from a prompt of 8 tokens:\ +`deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload`.\ +The output will look something like this: + +
+ +
+ +
+ LLAMA2-7B token generation on Windows using ZeRO-Inference. +
+ +# Summary +Enabling DeepSpeed, a popular deep learning framework, to run natively on Windows, the most popular operating system, is a crucial step towards empowering every person and every organization to benefit from the ongoing AI revolution. In this blog, we have shared early results of our work towards this goal. Although Windows support of DeepSpeed is a work-in-progress, we hope that the above updates are encouraging and already useful to users. The next items on our roadmap include running on multiple GPUs, weight quantization, and performance studies. + +# Acknowledgements +This work is a result of significant contributions from current and former DeepSpeed members including Costin Eseanu, Logan Adams, Elton Zheng, Reza Yazdani Aminabadi, Martin Cai, and Olatunji Ruwase. We also acknowledge the valuable contributions of DeepSpeed users who righteously demanded this feature, provided critical workarounds, partial solutions, and constructive feedback, and most importantly, stuck with us. diff --git a/blogs/windows/08-2024/chinese/README.md b/blogs/windows/08-2024/chinese/README.md new file mode 100644 index 000000000000..78b9b6213d89 --- /dev/null +++ b/blogs/windows/08-2024/chinese/README.md @@ -0,0 +1,103 @@ +
+ +# 在Windows系统上使用DeepSpeed + +
+ +# 简介 + +DeepSpeed是一个广受欢迎的开源深度学习优化库,它使得分布式训练和推理变得简单、高效且有效。凭借其众多复杂的优化技术(如ZeRO、3D并行、MoE等),DeepSpeed已被成功应用于包括Phi-3、Megatron-Turing-530B、BLOOM-176B和Arctic在内的多种前沿模型的训练。然而,由于缺乏对主流操作系统微软 Windows的原生支持,许多AI开发者与用户无法充分利用DeepSpeed的创新。为此,我们致力于让DeepSpeed在Windows上实现原生全功能运行,并保持与Linux相同的易用性。 + +在这篇博客中,我们很高兴地宣布我们开发工作中的一些早期成果:DeepSpeed 现在可以在 Windows 上安装并原生支持单 GPU 的训练、微调和推理。重要的是,安装和使用体验与 Linux 上完全相同。此外,微调和推理工作展示了 DeepSpeed 的三个关键特性:HuggingFace Transformers 的集成、LoRA 的支持和 CPU offload。DeepSpeed 在 Windows 上的支持从 DeepSpeed 0.14.5 开始。接下来,我们将通过一些例子展示这些成就。 + +# 测试环境 + +我们在一台运行 Windows 11 23H2 版本号 22631.3880 的 Surface Laptop Studio 2 上进行了测试。该笔记本配备了一块 4GB 显存的 NVIDIA RTX A2000 GPU。我们使用了 Pytorch 2.3.0 和 HuggingFace Transformers 4.41.2。测试所用的脚本来自 [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) 代码仓库,因此在运行以下任何示例前,你需要克隆该仓库。 + +# 安装指南 +DeepSpeed可以通过两种方式在Windows系统上安装。较为简单的方式是使用pip包管理器安装,另一种方法是从源代码安装。两种安装方式的前提条件都是系统已经安装了Python 3.x 和支持CUDA的Pytorch. + +## 通过pip安装 +要安装 DeepSpeed,只需运行:`pip install deepspeed`。它将安装最新版本的 DeepSpeed(目前为 0.14.5)。与 Linux 版本不同的是,Windows 版本已经预先编译了内部的自定义算子,因此不需要安装 CUDA 或 C++ 编译器。 + +
+ +
+ +
+ 通过pip在Windows上安装Deepspeed. +
+ + +## 通过源代码安装 +克隆DeepSpeed代码仓库后,运行build_win.bat脚本进行编译安装。 + + +## 验证安装 +无论选择哪种安装方式,你都可以通过运行 ds_report 来检查安装是否成功。输出应该如下所示: + + +
+ +
+ +
+ ds_report的输出结果,用于验证安装是否成功. +
+ +# 预训练(Pretraining) +我们使用图像分类模型 CIFAR10 和语言模型 BERT 来演示在 Windows 上使用 DeepSpeed 进行预训练。 + +## CIFAR10模型预训练 +用于 CIFAR10 预训练的脚本和代码可以在以下路径找到:`DeepSpeedExamples\training\cifar`。你可以运行以下命令启动 CIFAR10 预训练实验:`deepspeed cifar10_deepspeed.py --deepspeed`。最终输出应类似于: +
+ +
+ +
+ 在 Windows 上使用 Deepspeed 进行 CIFAR10 模型预训练 +
+ +## BERT模型预训练 +用于 BERT 预训练的脚本和代码可以在以下路径找到:`DeepSpeedExamples\training\HelloDeepSpeed`。你可以使用以下命令启动 BERT 预训练实验:`deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`。最终输出应如下所示: + +
+ +
+ +
+ 在 Windows 上使用 Deepspeed 进行 BERT 模型预训练 +
+ +# 微调(Fine Tuning) +我们使用 DeepSpeed-Chat 应用的监督微调(SFT)步骤来演示微调能力。我们对 HuggingFace 的 facebook/opt-125m 模型进行监督微调,同时启用 LoRA 和 CPU offload进行内存优化。运行命令行如下:\ +`deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output`\ +输出应如下所示: + +
+ +
+ +
+ 在 Windows 上使用 DeepSpeed 对 facebook/opt-125m 监督微调 +
+ +# 推理 +我们使用 ZeRO-Inference 的token生成来演示推理能力。ZeRO-Inference 通过转移存储到 CPU 内存或 NVMe 硬盘内存来减少推理的硬件成本。我们使用以下脚本运行 HuggingFace 的 Llama-2-7B 模型来进行 token 生成。由于 4GB 显存无法容纳模型和生成所需的内存,我们将模型权重转移到 CPU 内存。我们使用以下命令行从 8个token的提示词中生成 32 个token:\ +`deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload`\ +输出应类似于: + +
+ +
+ +
+ 在 Windows 上使用 ZeRO-Inference 进行 LLAMA2-7B 模型的token生成 +
+ +# 总结 + +让流行的深度学习框架 DeepSpeed 能够在最流行的操作系统 Windows 上原生运行,是让每个人和每个组织都能从正在进行的人工智能革命中受益的关键一步。在这篇博客中,我们分享了我们为实现这一目标所取得的早期成果。尽管 DeepSpeed 对 Windows 的支持仍在继续开发中,我们希望上述结果已经能够对我们的用户有实用价值,并且鼓舞他们。我们接下来的工作计划涵盖多GPU支持、权重量化以及性能优化。 + +# 致谢 +这给项目的完成得益于现任和前任 DeepSpeed 成员的大力合作,包括 Costin Eseanu、Logan Adams、Elton Zheng、Reza Yazdani Aminabadi、Martin Cai 和 Olatunji Ruwase。我们还要感谢那些及时提出此项需求、提供关键的临时解决方法、部分解决方案和建设性反馈的 DeepSpeed 用户,最重要的是,他们始终与我们同行. diff --git a/blogs/windows/08-2024/japanese/README.md b/blogs/windows/08-2024/japanese/README.md new file mode 100644 index 000000000000..c2f5b9ee2143 --- /dev/null +++ b/blogs/windows/08-2024/japanese/README.md @@ -0,0 +1,123 @@ +
+ +# DeepSpeedのWindowsサポート + +
+ +# はじめに + +DeepSpeedは、分散学習と推論を簡単かつ効率的に行うための人気のあるオープンソースの深層学習最適化ライブラリです。DeepSpeedは、その豊富かつ高度な最適化機能(例:ZeRO、3D parallelism, MoEなど)のおかげで、Phi-3、Megatron-Turing-530B、BLOOM-176B、Arcticなどの最先端モデルの学習に広く利用されています。しかし、最も普及しているオペレーティングシステムであるMicrosoft Windowsをネイティブにサポートしていなかったため、多くのAI開発者やユーザーが、DeepSpeedの革新的な機能を利用できない状態でした。この問題を解決するため、DeepSpeedの完全な機能をWindows上でネイティブに実行し、Linux上と同じ使いやすさを実現するための取り組みを開始しました。 + +このブログでは、この取り組みの最初の成果をお知らせします。現在、DeepSpeedはWindowsにインストールし、単一GPUでの学習、ファインチューニング、および推論をネイティブに実行できるようになりました。ここで重要なこととして、インストールと利用は、Linuxとまったく同じように行えます。ファインチューニングと推論のワークロードを通じて、HuggingFace Transformers との統合、LoRAのサポート、CPUオフロードの3つの重要なDeepSpeedの機能が、正しく動作していることが確認できました。このWindowsサポートは、バージョン0.14.5以降で利用可能です。このブログの残りの部分では、これらの成果を示す例を紹介します。 + +# テスト環境 + +Windows 11 Version 23H2 および Build 22631.3880 を実行している Surface Laptop Studio 2 でテストを行いました。このハードウェアには、4GBのVRAMを搭載した NVIDIA RTX A2000 GPU が1つ搭載されています。また、PyTorchバージョン 2.3.0 および HuggingFace Transformersバージョン 4.41.2 を使用しました。使用したサンプルスクリプトは[DeepSpeedExamplesリポジトリ](https://github.com/deepspeedai/DeepSpeedExamples)から取得できます。以下の例を実行する前にリポジトリをクローンしてください。 + +# インストール + +DeepSpeedは、2つの方法でWindowsにインストールできます。より簡単な方法は、pipパッケージマネージャーを使用することで、もう一方はソースからビルドする方法です。どちらの場合も、Python 3.xとCUDAサポート付きのPyTorchが必要です。 + +## pipを使用したインストール + +DeepSpeedをインストールするには、単に次のコマンドを実行します: `pip install deepspeed`。 +これにより、最新バージョンのDeepSpeed(現時点では0.14.5)がインストールされます。Linux版とは異なり、Windows版ではすべてのオペレーターがすでにビルド済みであるため、CUDA SDKやC++コンパイラをインストールする必要はありません。 + +
+ +
+ +
+ pipによるWindowsへのDeepSpeedのインストール +
+ + +## ソースからのビルド + +ソースからDeepSpeedをビルドするには、DeepSpeedリポジトリをクローンし、コンパイルスクリプトである `build_win.bat` を実行する必要があります。 + +## インストールの検証 + +インストール方法にかかわらず、`ds_report`を実行してインストールが成功したかどうかを確認できます。出力は次のようになります: + +
+ +
+ +
+ DeepSpeedのWindowsインストールを確認するds_reportの出力 +
+ +# 事前学習の例 + +Windows上でDeepSpeedを使用した事前学習の例として、画像分類モデルCIFAR10と言語モデルBERTの実行例を示します。 + +## CIFAR10の事前学習 + +CIFAR10の事前学習に必要なスクリプトとコードは、次のパスにあります: `DeepSpeedExamples\training\cifar` + +以下のコマンドを使用してCIFAR10の事前学習を開始できます: `deepspeed cifar10_deepspeed.py –deepspeed` + +出力は次のようになります。 + +
+ +
+ +
+ DeepSpeedによるWindowsでのCIFAR10モデルの事前学習 +
+ +## BERTの事前学習 + +BERTの事前学習に必要なスクリプトとコードは、次のパスにあります: `DeepSpeedExamples\training\HelloDeepSpeed` + +以下のコマンドを使用してBERTの事前学習を開始できます: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed` + +出力は次のようになります。 + +
+ +
+ +
+ DeepSpeedによるWindowsでのBERTモデルの事前学習 +
+ +# ファインチューニングの例 + +DeepSpeed-Chatアプリケーションの教師ありファインチューニング(supervised fine tuning; SFT)を使用して、ファインチューニングの機能を示します。LoRAおよびCPUオフロードメモリ最適化を有効にして、 HuggingFace の `facebook/opt-125m` モデルのSFTを実施します。この例を実行するためのコマンドラインは次のとおりです: `deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output` + +出力は次のようになります。 + +
+ +
+ +
+ DeepSpeedを使用したWindowsでの facebook/opt-125m モデルのファインチューニング +
+ +# 推論の例 + +推論の機能を示すために、トークン生成のためのZeRO-Inferenceを使用します。ZeRO-Inferenceは、CPUまたはNVMeメモリにオフロードすることで推論のハードウェアコストを削減します。ここでは、サンプルスクリプトを使用して、HuggingFaceのLlama-2-7Bモデルを使用したトークン生成を実行します。4GBのVRAMではモデルと生成処理の両方を実効するのに十分ではないため、モデルパラメータをCPUメモリにオフロードします。 + +次のコマンドラインを使用して、8トークンのプロンプトから32トークンを生成します: `deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload` + +出力は次のようになります。 + +
+ +
+ +
+ DeepSpeedのZeRO-InferenceによるWindowsでのLLAMA2-7Bのトークン生成 +
+ +# まとめ + +最も広く使われているオペレーティングシステムであるWindowsで、深層学習フレームワークであるDeepSpeedをネイティブに実行できるようにすることは、多くの人と組織が、今まさに進行中のAI革命の恩恵を受けるための重要な一歩です。このブログでは、この目標に向けたプロジェクトの、最初の成果を共有しました。Windowsのサポートは現在進行中のプロジェクトですが、今回の成果が多くのユーザにとって活用され、またさらに発展していけることを願っています。次のロードマップには、複数のGPUでの実行、モデルパラメータの量子化、パフォーマンスの詳細な分析が含まれます。 + +# 謝辞 + +このプロジェクトは、Costin Eseanu、Logan Adams、Elton Zheng、Reza Yazdani Aminabadi、Martin Cai、Olatunji Ruwaseを含むDeepSpeedメンバーによる大きな貢献の結果です。また、この機能を必要とし、様々な問題の解決策や、建設的なフィードバックを提供し、私たちと共に歩んでくれたDeepSpeedユーザーの重要な貢献に感謝します。 diff --git a/blogs/windows/08-2024/media/bert_training.png b/blogs/windows/08-2024/media/bert_training.png new file mode 100644 index 000000000000..c5935e47747e Binary files /dev/null and b/blogs/windows/08-2024/media/bert_training.png differ diff --git a/blogs/windows/08-2024/media/cifar10_training.png b/blogs/windows/08-2024/media/cifar10_training.png new file mode 100644 index 000000000000..99f3fa25bc70 Binary files /dev/null and b/blogs/windows/08-2024/media/cifar10_training.png differ diff --git a/blogs/windows/08-2024/media/ds_report.png b/blogs/windows/08-2024/media/ds_report.png new file mode 100644 index 000000000000..43d82d724ed2 Binary files /dev/null and b/blogs/windows/08-2024/media/ds_report.png differ diff --git a/blogs/windows/08-2024/media/llama2-7b_inference.png b/blogs/windows/08-2024/media/llama2-7b_inference.png new file mode 100644 index 000000000000..f5874468a854 Binary files /dev/null and b/blogs/windows/08-2024/media/llama2-7b_inference.png differ diff --git a/blogs/windows/08-2024/media/opt125m_finetuning.png b/blogs/windows/08-2024/media/opt125m_finetuning.png new file mode 100644 index 000000000000..ed6d1522e3b3 Binary files /dev/null and b/blogs/windows/08-2024/media/opt125m_finetuning.png differ diff --git a/blogs/windows/08-2024/media/win_pip_install_deepspeed.png b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png new file mode 100644 index 000000000000..3b87c95ef144 Binary files /dev/null and b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png differ diff --git a/blogs/zenflow-corebinding/README.md b/blogs/zenflow-corebinding/README.md new file mode 100644 index 000000000000..40a9c9a7d5b3 --- /dev/null +++ b/blogs/zenflow-corebinding/README.md @@ -0,0 +1,199 @@ +# Study of ZenFlow and ZeRO offload performance with DeepSpeed CPU core binding +**TL;DR:** ZenFlow is an improvement to ZeRO Offload contributed to DeepSpeed by Tingfeng Lan et al. After testing this feature, we explored the relationship between ZenFlow performance and DeepSpeed CPU core binding. + +## ZenFlow technology introduction +[ZenFlow](https://arxiv.org/abs/2505.12242) is a recent improvement to ZeRO Offload implemented in DeepSpeed. Its primary goal is to address the GPU stalls caused by ZeRO Offload. These stalls mainly originate from two sources: 1) the data transfer from the GPU to the CPU, which is limited by PCIe bandwidth, and 2) the computational overhead of executing the Adam optimizer on the CPU, which is constrained by CPU performance and memory bandwidth. + +The core idea of ZenFlow is to separate gradients into two groups based on their norm. A very small portion of gradients, which have larger norms, are classified as important gradients and are updated directly on the GPU. The vast majority of gradients, which have smaller norms, are used to update the weights on the CPU at a lower frequency than the important gradients. If the gradients are not scheduled for an update in the current training iteration, they are accumulated into a copy of the gradients. These accumulated gradients are then used for the weight update in a subsequent iteration. + +Furthermore, the weight updates on the CPU are designed to run in parallel with the computations on the GPU, thereby achieving the objective of reducing GPU stall. + +To achieve the goal of parallelizing weight updates on the CPU with GPU computations, ZenFlow creates an additional process for each rank. This dedicated process handles the weight updates, while the original process for each rank can continue executing GPU computation code. This design enables the concurrency between weight updates and GPU computations. In addition to these optimizations, ZenFlow also performs CPU core binding for the weight update processes. It binds the CPU update processes of different ranks to distinct CPU cores to enhance CPU performance. + +## DeepSpeed CPU core binding feature and its improvement to CPU offloading performance +This reminds us that DeepSpeed itself supports CPU core binding through the `--bind_cores_to_rank` flag. This switch was originally designed to improve multi-socket CPU inference performance. By binding cores, different workers can run on distinct CPU cores without interfering with each other, thereby enhancing locality. Additionally, DeepSpeed's core binding feature automatically configures the `OMP_NUM_THREADS` environment variable to ensure the OpenMP thread pool size matches the number of allocated cores. + +This raised a question: Could this switch also benefit ZeRO Offload? We conducted tests to explore this possibility. + +### Improvement to ZeRO Offload performance from DeepSpeed CPU core binding +| | Avg. time of first 51 iterations (1st run) | 2nd run | 3rd run | Average | +|-------------|--------------------------------------------|---------|---------|---------| +| No bind core| 2707.32ms | 3127.24ms | 2826.04ms | 2887ms | +| Bind core | 2649.06ms | 2641.82ms | 2200.76ms | 2497ms | + +**Model:** Qwen2.5-3B + +**Test environment:** 2xDGX-A100-SXM4-40GB, 2xAMD EPYC 7742 64-Core Processor, 1TB memory + +**Test URL:** [DeepSpeedExamples/training/DeepSpeed-ZenFlow/finetuning](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-ZenFlow/finetuning) (All following tests are using the same URL) + +**Test command:** + - **No core binding:** `deepspeed --num_gpus=2 finetune_llama.py --model_name Qwen/Qwen2.5-3B --output_dir output --lr 2e-5 --batch_size 8 --deepspeed_config zo_config.json --num_train_epochs 1` + - **With core binding:** `deepspeed --num_gpus=2 --bind_cores_to_rank finetune_llama.py --model_name Qwen/Qwen2.5-3B --output_dir output --lr 2e-5 --batch_size 8 --deepspeed_config zo_config.json --num_train_epochs 1` + +**Config file** (`zo_config.json`): +```json +{ + "train_batch_size": 8, + "bf16": { "enabled": true }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + } + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 2e-5, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 1.0, + "zero_allow_untested_optimizer": true, + "wall_clock_breakdown": true +} +``` + +From this data, DeepSpeed's core binding provides approximately a 15% performance improvement for ZeRO Offload. So, could it also benefit ZenFlow's performance? With this question in mind, we decided to comment out the core binding logic within ZenFlow and instead directly use the `--bind_cores_to_rank` flag to run ZenFlow: + +### Improvement to ZenFlow performance from DeepSpeed CPU core binding +| | Avg. time from iteration 5-51 (1st run) | 2nd run | 3rd run | Average | +|--------------------|-----------------------------------------|---------|---------|---------| +|ZenFlow core binding| 1337.66ms | 1443.87ms | 1475.04ms | 1419ms | +|DeepSpeed core binding| 1233.6ms | 1228.36ms | 1235ms | 1232ms | + +**Model:** Qwen2.5-3B + +**Test environment:** 2xDGX-A100-SXM4-40GB, 2xAMD EPYC 7742 64-Core Processor, 1TB memory + +**DeepSpeed commit:** 1d7b90adc48d57c2283e8825f5c668a3730ff899 + +*ZenFlow use 4 iterations to compute gradient importance, so we start from 5th iteration to measure time* + +**Test command:** + - **No core binding:** `deepspeed --num_gpus=2 finetune_llama.py --model_name Qwen/Qwen2.5-3B --output_dir output --lr 2e-5 --batch_size 8 --deepspeed_config zf_config.json --num_train_epochs 1` + - **With core binding:** `deepspeed --num_gpus=2 --bind_cores_to_rank finetune_llama.py --model_name Qwen/Qwen2.5-3B --output_dir output --lr 2e-5 --batch_size 8 --deepspeed_config zf_config.json --num_train_epochs 1` + + +**Config file** (`zf_config.json`): +```json +{ + "train_batch_size": 8, + "bf16": { "enabled": true }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "zenflow": { + "topk_ratio": 0.1, + "update_interval": 4, + "full_warm_up_rounds": 0, + "overlap_step": true, + "pt_reserved_cores_perc": 0.5 + } + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 2e-5, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 1.0, + "zero_allow_untested_optimizer": true +} +``` + + +We observed a performance improvement of approximately 15% from DeepSpeed CPU core binding against ZenFlow core binding. Why did this happen? + +## Our improvements to ZenFlow CPU core binding mechanism +After communicating with the authors of ZenFlow, we gained a new understanding of the core binding mechanism required by ZenFlow. + +First, the ZenFlow worker processes need to use a dedicated set of CPU cores, separate from those used by the main process of each rank. Second, the ZenFlow workers and the main processes should be bound to different physical cores, avoiding binding to virtual cores (hyper-threads). Third, the OpenMP thread pool size should be appropriately set to match the number of cores allocated to the ZenFlow workers. + +In the original ZenFlow implementation, all cores (including the virtual cores corresponding to physical cores) were used for core binding, meaning the workers were not properly isolated at the physical core level. In contrast, DeepSpeed's core binding specifically binds processes to physical cores only, which explains the performance improvement we observed. + +Based on this understanding, we collaborated with the ZenFlow authors to update its core binding mechanism. + +First, before each rank launches a ZenFlow worker process, it needs to enumerate the list of available physical cores. If these lists of physical cores differ across ranks, it indicates that DeepSpeed has already performed physical core binding. Otherwise, each rank needs to allocate its own list of available cores from the total pool. + +Finally, each rank allocates a subset of cores from its own list to the ZenFlow worker process and sets the corresponding `OMP_NUM_THREADS` environment variable. This ensures that all processes use distinct CPU cores, preventing interference, and also allows for proper configuration of the OpenMP thread pool size. [code](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py) + +Under this new core binding mechanism, we re-evaluated the performance of ZenFlow: + +### ZenFlow perf. with new core binding mechanism +| | Avg. time from iteration 5-51 (1st run) | 2nd run | 3rd run | Average | Improvement over original binding | +|--------------------|-----------------------------------------|---------|---------|---------|------| +| New ZenFlow worker core binding | 1321.21ms | 1269.83ms | 1384.47ms | 1325ms | 7% | +| DeepSpeed core binding + new ZenFlow worker core binding | 1111.68ms | 1125.38ms | 1111.91ms | 1116ms | 10% | + +**Model:** Qwen2.5-3B + +**Test environment:** 2xDGX-A100-SXM4-40GB, 2xAMD EPYC 7742 64-Core Processor, 1TB memory + +**DeepSpeed commit:** 80033a82938f6cd8ce4988a63c914941e7a8f324 + +The results indicate that ZenFlow's performance was further enhanced under the new core binding mechanism. Compared to the original binding method, performance improved by 7% when not using DeepSpeed's core binding. When DeepSpeed's core binding was enabled, the performance gain reached 10%. + +Why does DeepSpeed binding still provide an additional performance boost on top of the new ZenFlow binding? + +We initially hypothesized that it might be because DeepSpeed uses numactl, which can bind a process to a specific NUMA node, ensuring the process always accesses local memory. However, upon examining the DeepSpeed logs, we found that the -m switch was not enabled during runtime. Furthermore, when we replaced numactl with taskset, we still observed the performance improvement. + +Our current conjecture is that the difference lies in how the binding is applied. numactl (and taskset in this context) operates at the process level, applying the binding to the entire process from the start. In contrast, ZenFlow's binding is applied within the code at the point of use. This distinction in the scope and timing of the binding application could be the source of the performance difference. This point may require more detailed investigation in the future. + +Regardless, the key finding remains: the new ZenFlow binding mechanism improves performance irrespective of whether DeepSpeed binding is used. This conclusively demonstrates the effectiveness of physical core isolation for performance. + +We conducted a comparative analysis of the performance across several configurations: ZeRO Offload without core binding, ZeRO Offload with core binding, and ZenFlow both before and after our improvements. The results are summarized as follows: + +### Perf comparison table +| | Average time | Perf. improv. vs. baseline | +|-------------|--------------|----------------------------| +| ZeRO Offload without binding -- baseline | 2887ms | 1x | +| ZeRO Offload with DeepSpeed core binding | 2497ms | 1.16x | +| ZenFlow original worker core binding | 1419ms | 2.03x | +| DeepSpeed core binding +ZenFlow new worker core binding | 1116ms | 2.59x | + +**Model:** Qwen2.5-3B + +**Test environment:** 2xDGX-A100-SXM4-40GB, 2xAMD EPYC 7742 64-Core Processor, 1TB memory + +The result clearly shows that the improved ZenFlow achieves a 2.59x speedup compared to ZeRO Offload without core binding, and a 2.24x speedup compared to ZeRO Offload with core binding. + +Given that ZenFlow's core innovations involve reducing the frequency of weight updates and parallelizing CPU/GPU execution, the 2.24x improvement over the core-bound ZeRO Offload is particularly significant. This comparison provides a more accurate reflection of ZenFlow's inherent performance advantages. By using the core-bound ZeRO Offload as the baseline, we effectively isolate and quantify the performance gains attributable specifically to ZenFlow's algorithmic optimizations, rather than those coming from general core-binding techniques. This strongly validates the effectiveness of ZenFlow's fundamental design. + +Through our collaboration with the ZenFlow authors, the new core-binding mechanism has been integrated into the main branch of DeepSpeed. As a result, users can now achieve optimal offload performance by simply using ZenFlow in conjunction with the DeepSpeed `--bind_cores_to_rank` flag. This integration provides an out-of-the-box, high-performance experience that leverages the combined strengths of both the algorithmic innovations in ZenFlow and the low-level system optimizations in DeepSpeed's core binding. + +## Practicality metric, a metric to evaluate offloading technology +In addition to comparisons with ZeRO Offload, a performance comparison against scenarios without offloading better demonstrates the practicality of ZenFlow or ZeRO Offload. While it's true that ZeRO Offload or ZenFlow enables model optimization with relatively limited VRAM, achieving a breakthrough from impossibility to possibility, if the performance gap is too significant, the decision to use offloading becomes a dilemma. We consider the performance difference between scenarios with and without offloading as a practicality metric. A value of 1 represents the ideal scenario, indicating that offloading has no impact on performance. The smaller this value, the poorer the practicality, as users would need to wait considerably longer for fine-tuning. + +Since we couldn't run Qwen2.5-3B with ZeRO2 using the same config on two GPUs in our test environment, we conducted the practicality test using Qwen2.5-1.5B instead: + +### Practicality test +| | Average time | Practicality metric | +|-------------|--------------|---------------------| +| ZeRO2 | 240ms | | +| ZeRO Offload with DeepSpeed core binding | 1365ms | 17.6% | +| DeepSpeed core binding + new ZenFlow worker core binding | 569ms | 42.2% | + +**Model: Qwen2.5-1.5B** + +**Test environment:** 2xDGX-A100-SXM4-40GB, 2xAMD EPYC 7742 64-Core Processor, 1TB memory + +Based on the tests conducted on 2xA100 GPUs, the practicality metric for ZeRO Offload was 17.6%, while ZenFlow achieved a practicality metric of 42.2%. This result demonstrates that ZenFlow significantly improves the practicality of offloading. + +## Summary +ZeRO Offload is an effective technique for reducing VRAM pressure, making the fine-tuning of large models possible. We have now seen that ZenFlow, as a new technology, achieves a breakthrough improvement in the practicality of ZeRO Offload, bringing it to a usable level. When combined with DeepSpeed's core binding, ZenFlow is able to deliver its optimal performance. + +## Disclaimer +All performance data presented in this article is measured for the sole purpose of discussing the effects of specific optimization techniques. There is no guarantee that the data was obtained under optimal software or hardware configurations, nor does it represent a performance evaluation of any software or hardware products mentioned. This article discusses only the relative performance changes resulting from specific optimization methods. The performance gain depends on specific software or hardware configuration and may vary in your own run. diff --git a/blogs/zeropp/assets/images/eval1.png b/blogs/zeropp/assets/images/eval1.png new file mode 100644 index 000000000000..8312c1db6de1 Binary files /dev/null and b/blogs/zeropp/assets/images/eval1.png differ diff --git a/blogs/zeropp/assets/images/eval2.png b/blogs/zeropp/assets/images/eval2.png new file mode 100644 index 000000000000..b6fd05f8cd98 Binary files /dev/null and b/blogs/zeropp/assets/images/eval2.png differ diff --git a/blogs/zeropp/assets/images/eval3.png b/blogs/zeropp/assets/images/eval3.png new file mode 100644 index 000000000000..4675e2041d84 Binary files /dev/null and b/blogs/zeropp/assets/images/eval3.png differ diff --git a/blogs/zeropp/assets/images/hpz.png b/blogs/zeropp/assets/images/hpz.png new file mode 100644 index 000000000000..790903cff68b Binary files /dev/null and b/blogs/zeropp/assets/images/hpz.png differ diff --git a/blogs/zeropp/assets/images/overview.png b/blogs/zeropp/assets/images/overview.png new file mode 100644 index 000000000000..8e261b533528 Binary files /dev/null and b/blogs/zeropp/assets/images/overview.png differ diff --git a/blogs/zeropp/assets/images/qgz.gif b/blogs/zeropp/assets/images/qgz.gif new file mode 100644 index 000000000000..90716d325a04 Binary files /dev/null and b/blogs/zeropp/assets/images/qgz.gif differ diff --git a/blogs/zeropp/assets/images/qwz.png b/blogs/zeropp/assets/images/qwz.png new file mode 100644 index 000000000000..ae68c322668f Binary files /dev/null and b/blogs/zeropp/assets/images/qwz.png differ diff --git a/blogs/zeropp/assets/images/rlhf-eval.png b/blogs/zeropp/assets/images/rlhf-eval.png new file mode 100644 index 000000000000..d9b1f3d272c1 Binary files /dev/null and b/blogs/zeropp/assets/images/rlhf-eval.png differ diff --git a/blogs/zeropp/assets/images/zero-overview.gif b/blogs/zeropp/assets/images/zero-overview.gif new file mode 100644 index 000000000000..65051947f79d Binary files /dev/null and b/blogs/zeropp/assets/images/zero-overview.gif differ diff --git a/blogs/zeropp/chinese/README.md b/blogs/zeropp/chinese/README.md new file mode 100644 index 000000000000..7f35fe619140 --- /dev/null +++ b/blogs/zeropp/chinese/README.md @@ -0,0 +1,185 @@ +
+ +# DeepSpeed ZeRO++:降低4倍网络通信,显著提高大模型及类ChatGPT模型训练效率 + +
+
+ + + +图1: DeepSpeed ZeRO++ 简介 +
+ +大型 AI 模型正在改变数字世界。基于大型语言模型 (LLM)的 Turing-NLG、ChatGPT 和 GPT-4 等生成语言模型用途广泛,能够执行摘要、代码生成和翻译等任务。 同样,DALL·E、Microsoft Designer 和 Bing Image Creator 等大型多模态生成模型可以生成艺术、建筑、视频和其他数字资产,使内容创作者、建筑师和工程师能够探索全新的创意生产力。\ +\ +然而,训练这些大型模型需要在数百甚至数千个 GPU 设备上使用大量内存和计算资源。 例如,训练 [Megatron-Turing NLG 530B](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/)模型需要使用超过 4,000 个 NVidia A100 GPU。 有效地利用这些资源需要一个复杂的优化系统,以将模型合理分配到各个设备的内存中,并有效地并行化这些设备上的计算。 同时,为了使深度学习社区能够轻松进行大型模型训练,这些优化必须易于使用。 + +DeepSpeed 的 ZeRO [优化系列](https://www.deepspeed.ai/tutorials/zero/)为这些挑战提供了强大的解决方案,并已广泛用于大型深度学习模型例如TNLG-17B、Bloom-176B、MPT-7B、Jurrasic-1的训练中 。尽管它具有变革性的能力 ,在一些关键场景中,ZeRO 会在 GPU 之间产生大量数据传输开销,这降低了训练效率。 这种情况特别发生在以下场景中:a) 全局batch size较小,而 GPU数量多,这导致每个 GPU 上batch size较小,需要频繁通信;或者 b) 在低端集群上进行训练,其中跨节点网络带宽有限,导致高通信延迟。在这些情况下,ZeRO 的训练效率会受到限制。 + +为了解决这些限制,我们发布了 [ZeRO++](https://arxiv.org/abs/2306.10209) 。 ZeRO++相比 ZeRO将总通信量减少了 4 倍,而不会影响模型质量。 这有两个关键意义: + +1. *ZeRO++加速大型模型预训练和微调* + 1. 每个GPU上 batch size较小时: 无论是在数千个 GPU 上预训练大型模型,还是在数百个甚至数十个 GPU 上对其进行微调,当每个 GPU 的batch size较小时,ZeRO++ 提供比 ZeRO 高 2.2 倍的吞吐量,直接减少训练时间和成本。 + 2. 低带宽计算集群: ZeRO++ 使低带宽集群能够实现与带宽高 4 倍的高端集群类似的吞吐量。 因此,ZeRO++ 可以跨更广泛的集群进行高效的大型模型训练。 + +2. *ZeRO++加速 ChatGPT 类的 RLHF训练* + + 1. 虽然 ZeRO++ 主要是为训练而设计的,但它的优化也自动适用于 [ZeRO-Inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html#:~:text=ZeRO-Inference%20adapts%20and%20optimizes%20ZeRO-Infinity%20techniques%20for%20model,memory%2C%20thus%20hosting%20no%20%28zero%29%20weights%20in%20GPU.),因为通信开销对于 ZeRO 的训练和推理同样适用。 因此,ZeRO++ 可以提高人类反馈强化学习 (RLHF) 等算法的效率,因为RLHF结合了训练和推理。 + + 2. 通过与 DeepSpeed-Chat 的集成,与原始 ZeRO 相比,ZeRO++ 可以将 RLHF 训练的生成阶段效率提高多达 2 倍,强化学习训练阶段效率提高多达 1.3 倍。 + +接下来,我们将更深入地解释 ZeRO 及其通信开销,并讨论 ZeRO++ 中为解决这些问题而进行的关键优化。 然后我们将展示 ZeRO++ 对不同模型大小、批量大小和带宽限制的训练吞吐量的影响。我们还将讨论 ZeRO++ 如何应用于 DeepSpeed-Chat,以加速使用 RLHF的对话模型的训练。 + +## ZeRO++详解 + +
+ + + +图2: ZeRO optimizer 工作流程图 +
+ +ZeRO 是数据并行(Data Parallelism)的一种内存高效版本,其中模型状态会被分割储存在所有 GPU 上,而不需要在训练期间使用基于gather/broadcas的通信进行复制和重建。这使 ZeRO 能够有效地利用所有设备的聚合 GPU 内存和计算力,同时提供简单易用的数据并行训练。\ +\ +假设模型大小为 M。在前向传播过程中,ZeRO 执行全收集/广播(all-gather/broadcast)操作以在需要之时为每个模型层收集参数(总共大小为 M)。 在向后传递中,ZeRO 对每一层的参数采用类似的通信模式来计算其局部梯度(总大小为 M)。 此外,ZeRO 在对每个局部梯度计算完毕后会立刻使用 reduce 或 reduce-scatter 通信进行平均和分割储存(总大小为 M)。 因此,ZeRO 总共有 3M 的通信量,平均分布在两个全收集/广播(all-gather/broadcast)和一个减少分散/减少(reduce-scatter/reduce)操作中。 + +为了减少这些通信开销,ZeRO++ 进行了三组通信优化,分别针对上述三个通信集合: + +
+ + + +图3:qwZ的分区量化图例 +
+ + +### ZeRO通信过程中的权重量化 (qwZ) + +首先,为了减少 all-gather 期间的参数通信量,我们采用权重量化在通信前将每个模型参数从 FP16(两个字节)动态缩小为 INT8(一个字节)数据类型,并在通信后对权重进行反量化。 然而,简单地对权重进行量化会降低模型训练的准确性。 为了保持良好的模型训练精度,我们采用分区量化,即对模型参数的每个子集进行独立量化。目前尚且没有针对分区量化的高性能现有实现。 因此,我们自行从头开始实现了一套高度优化的量化 CUDA 内核,与基本量化相比,精度提高 3 倍,速度提高 5 倍。 + +
+ + + +图4: 权重的分层分割存储(hpZ) +
+ + +### ZeRO模型权重的分层分割存储 (hpZ) + +其次,为了减少向后传递期间全收集(all-gather)权重的通信开销,我们用 GPU 内存进行通信。 更具体地说,我们不像在 ZeRO 中那样将整个模型权重分布在所有机器上,而是在每台机器中维护一个完整的模型副本。 以更高的内存开销为代价,这允许我们用机器内的模型权重全收集/广播(all-gather/broadcast)代替昂贵的跨机器全收集/广播(all-gather/broadcast),由于机器内通信带宽更高,这使得通信速度大幅提升。 + +
+ + + +图5: qgZ 端到端的工作流程 + +
+ +### ZeRO通信过程中梯度量化 (qgZ) + +第三,要降低梯度的reduce-scatter通信成本更具挑战性。 因为直接应用量化来减少通信量是不可行的。 即使我们使用分区量化来降低量化误差,梯度reduce也会累积并放大量化误差。 为了解决这个问题,我们只在通信之前量化梯度,但在任何reduce操作之前将它们反量化到原有精度。 为了有效地做到这一点,我们发明了一种名为 qgZ 的基于 all-to-all 的新型量化梯度通信范式,它在功能上等同于压缩的归约-分散(reduce-scatter)操作。 + +qgZ 旨在解决两个挑战:i) 如果我们简单地在 INT4/INT8 中实施 reduce-scatter 会导致显著精度损失,以及 ii) 在传统tree或ring-based reduce-scatter中使用量化需要一长串量化和反量化步骤,这直接导致误差积累和显著的延迟,即使我们在全精度上进行reduce。为了解决这两个挑战,qgZ 不使用tree或ring-based reduce-scatter算法,而是基于一种新颖的分层 all-to-all 方法。 + +qgZ 中有三个主要步骤:i)梯度切片重新排序,ii)节点内通信和reduce,以及 iii)节点间通信和reduce。 首先,在任何通信发生之前,我们对梯度进行切片并对张量切片重新排序,以保证通信结束时每个 GPU 上的最终梯度位置(即图 5 中的绿色块)是正确的。 其次,我们量化重新排序的梯度切片,在每个节点内进行 all-to-all 通信,从 all-to-all 中对接收到的梯度切片进行反量化,并进行局部reduce。 第三,我们再次量化局部reduce后的梯度,进行节点间的all-to-all通信,再次对接收到的梯度进行反量化,并计算最终的高精度梯度reduce,得到图5中绿色块的结果。\ +\ +这种分层方法的原因是为了减少跨节点通信量。 更准确地说,给定每个节点 N 个 GPU、M 的模型大小和 Z 的量化比率,单跳 all-to-all 将生成 M\*N/Z 跨节点流量。 相比之下,通过这种分层方法,我们将每个 GPU 的跨节点流量从 M/Z 减少到 M/(Z\*N)。 因此,总通信量从 M\*N/Z 减少到 M\*N/(Z\*N) = M/Z。 我们通过重叠节点内和节点间通信以及融合 CUDA 内核来进一步优化 qgZ 的端到端延迟(张量切片重新排序 (Tensor Slice Reordering)+ 节点内量化(Intra-node quantization))和(节点内反量化 (Intra-node Dequantization) + 节点内梯度整合 (Intra-node Reduction) + 节点间量化(inter-node quantization))。 + +
+ +| Communication Volume | Forward all-gather on weights | Backward all-gather on weights | Backward reduce-scatter on gradients | Total | +|:---------------------------:|:------------------------------------:|:-------------------------------------:|:-------------------------------------------:|:------------:| +| ZeRO | M | M | M | 3M | +| ZeRO++ | 0.5M | 0 | 0.25M | 0.75M | + +
+ +### **通信总量优化** + +通过结合以上所有三个组件,我们将跨节点通信量从 3M 减少到 0.75M。 更具体地说,我们使用 qwZ 将模型权重的前向全收集/广播从 M 减少到 0.5M。 我们使用 hpZ 消除了反向传播期间的跨节点 all-gather,将通信从 M 减少到 0。最后,我们使用 qgZ 将反向传播期间的跨节点 reduce-scatter 通信从 M 减少到 0.25M。 + +## **ZeRO++ 加速大型语言模型训练** + +在这里,我们展示了 ZeRO++ 在 384 个 Nvidia V100 GPU 上的真实 LLM 训练场景的测试结果。 + +
+ + + +图6: 在 384 个 V100 GPU 上的各种模型大小下 ZeRO++ 与 ZeRO 的吞吐量,节点间使用 4 个 Infiniband (IB) 进行互连,每个以 100 Gbps 运行。 + +
+ +### **在GPU小batch size情况下ZeRO++实现更高的训练效率** + +**高带宽集群:** 如图 6 所示,我们首先展示了 ZeRO++ 相对于 ZeRO 的吞吐量改进,针对不同的模型大小和微批量(micro-batch size)大小,测试使用 4x Infiniband (IB) 以实现 400Gbps 跨节点互连带宽,每个以 100Gbps 运行。 在 micro-batch size为每 GPU 1k tokens时,ZeRO++ 比 ZeRO-3 的吞吐量提高了 28% 到 36%。 对于 2k tokens micro-batch size大小,ZeRO++ 比 ZeRO-3 实现了 24% 到 29% 的吞吐量增益。 + +
+ + + + +图7: 在 384 个 V00 GPU 上 100Gbps 跨节点带宽时各种 LLM 的吞吐量 + +
+ +**低带宽集群:** 在 100Gbps等低带宽网络环境中,ZeRO++ 的性能明显优于 ZeRO-3。 如图 7 所示,与 ZeRO-3 相比,ZeRO++ 在端到端吞吐量方面实现了高达 2.2 倍的加速。 平均而言,ZeRO++ 比 ZeRO-3 基线实现了大约 2 倍的加速。 + +
+ + + + +图8: ZeRO++ 以显着降低的带宽实现高带宽集群性能 + +
+ +### **实现高带宽ZeRO和低带宽ZeRO++集群之间的模型训练效率等效** + +此外,与 ZeRO 在高得多的带宽环境下相比,ZeRO ++ 可以在低带宽集群中实现相当的系统吞吐量。 如图 8 所示,对于 18B 和 138B 模型大小,具有 200Gbps 跨节点带宽的 ZeRO++ 可以达到与 800Gbps 跨节点带宽的 ZeRO-3 相似的 TFLOP。 + +鉴于 ZeRO++ 出色的可扩展性,我们将 ZeRO++ 视为用于训练大型 AI 模型的下一代 ZeRO。 + +## **DeepSpeed-Chat 与ZeRO++结合用于 RLHF 训练** + +### **RLHF训练简介** + +ChatGPT 类模型由 LLM 提供支持,并[使用 RLHF 进行微调](https://openai.com/blog/chatgpt)。 RLHF 由生成(推理)阶段和训练阶段组成。 在生成阶段,演员(actor)模型将部分对话作为输入,并使用一系列前向传递生成响应。 然后在训练阶段,评论(critic)模型根据质量对生成的响应进行排名,为演员模型提供强化信号。 使用这些排名对参与者模型进行微调,使其能够在后续迭代中生成更准确和适当的响应。 + +RLHF 训练带来了巨大的内存压力,因为它使用了四种模型(演员、参考、评论、奖励)。 常见的解决方案是采用低秩自适应训练 (LoRA) 来解决 RLHF 的内存压力。 LoRA 冻结了预训练模型的权重,并将可训练的秩分解矩阵注入到 Transformer 架构的每一层中,显着减少了可训练参数的数量。 LoRA 通过减少内存使用来加速 RLHF,允许更大的批处理(batch)大小,从而大大提高吞吐量。 + +### **DeepSpeed-Chat with ZeRO++ 用于 RLHF 训练** + +
+ + + + +图9: ZeRO++ 加速了 RLHF 训练的生成和训练阶段 + +
+ +ZeRO++在RLHF + LoRA的场景下有着独特的应用,因为大多数模型权重都被冻结了。 这意味着 ZeRO++ 可以将这些冻结的权重量化保存到INT4/8 中,而不是将它们存储在 fp16 中并在每次通信操作之前对其进行量化。 通信后的反量化仍然是为了让权重为计算做好准备,但反量化后的权重在计算后被简单地丢弃。 + +以这种方式使用 ZeRO++ 进行 RLHF 训练可以减少内存使用和通信量。 这意味着通过减少通信以及由于减少内存使用而启用更大的批处理大小来提高训练吞吐量。 在生成阶段,ZeRO++ 使用 hpZ 将所有权重通信保持在每个节点内,以利用更高的节点内通信带宽,减少通信量,进一步提高生成吞吐量。\ +\ +ZeRO++ 已集成到 DeepSpeed-Chat 中,以支持 ChatGPT 类模型的 RLHF 训练。 在图 9 中,我们比较了不同大小的 actor 模型的 RLHF 生成吞吐量。测试配置为 32个V100 GPU ,actor 模型大小为30B 和 66B以测试 ZeRO 和 ZeRO++性能。 结果表明,ZeRO++ 的 RLHF 生成吞吐量比 ZeRO 高出 2.25 倍。 我们还展示了在 16 个 V100 GPU 上训练阶段的加速,其中 ZeRO++ 实现了比 ZeRO 高 1.26 倍的吞吐量,这是由于 ZeRO++ 支持的更低通信量和更大批量大小。 + +## **DeepSpeed ZeRO++现已发布!** + +我们非常高兴能够发布 DeepSpeed ZeRO++ 并让 AI 社区中的每个人都可以使用它。请访问我们的 GitHub 页面以获取 [LLM训练教程](https://www.deepspeed.ai/tutorials/zeropp/)。 用于 DeepSpeed-Chat 的 ZeRO++ 将在未来几周内发布。\ +有关 ZeRO++ 的更多技术细节,请查看我们的[arxiv论文](https://arxiv.org/pdf/2306.10209.pdf)。 + +DeepSpeed-ZeRO++ 是 DeepSpeed 生态系统的一部分。 要了解更多信息,请访问我们的网站,在那里您可以找到详细的博客文章、教程和有用的文档。 + +您还可以在我们的[英文 Twitter](https://twitter.com/DeepSpeedAI)、[日文 Twitter](https://twitter.com/DeepSpeedAI_JP) 和[中文知乎](https://www.zhihu.com/people/deepspeed) 上获取最新的 DeepSpeed 新闻。 + +DeepSpeed 欢迎您的贡献! 我们鼓励您在 DeepSpeed GitHub 页面上报告问题、贡献 PR 并加入讨论。 有关更多详细信息,请参阅我们的贡献指南。 我们对与大学、研究实验室和公司的合作持开放态度。 对于此类请求(以及其他不适合 GitHub 的请求),请直接发送电子邮件至 。 + +**贡献者:** + +DeepSpeed 团队的以下人员的贡献使该项目成为可能: + +[Guanhua Wang](https://www.microsoft.com/en-us/research/people/guanhuawang/), Heyang Qin, Sam Ade Jacobs, Connor Holmes, [Samyam Rajbhandari](https://www.microsoft.com/en-us/research/people/samyamr/), [Olatunji Ruwase](https://www.microsoft.com/en-us/research/people/olruwase/), Ammar Ahmad Awan, Jeff Rasley, Michael Wyatt, [Yuxiong He](https://www.microsoft.com/en-us/research/people/yuxhe/) (team lead) diff --git a/blogs/zeropp/japanese/README.md b/blogs/zeropp/japanese/README.md new file mode 100644 index 000000000000..6f7949719981 --- /dev/null +++ b/blogs/zeropp/japanese/README.md @@ -0,0 +1,186 @@ +
+ +# DeepSpeed ZeRO++: LLMやチャットモデルの訓練を劇的に高速化 – 通信オーバヘッドを1/4に大幅削減 - + +
+
+ + + +図1: DeepSpeed ZeRO++ の概要 +
+ +大規模AIモデルは、まさに今デジタルの世界を変革しつつあります。大規模言語モデル(Large Language Model, LLM)を搭載したTuring-NLG、ChatGPT、GPT-4のような生成言語モデルは、驚くほど汎用性が高く、要約、コーディング、翻訳のようなタスクを実行できます。同様に、DALL·E、Microsoft Designer、Bing Image Creatorのような大規模なマルチモーダル生成モデルは、アート、建築、ビデオ、その他のデジタルアセットを生成することができ、コンテンツクリエイター、建築家、エンジニアがクリエイティブな生産性を発揮し、新たなフロンティアを開拓する力をもたらしています。 + +しかし、これらの大規模なモデルを訓練するには、何百、何千ものGPUデバイスを使用した膨大なメモリとコンピューティングリソースが必要です。例えば、[Megatron-Turing NLG 530Bモデル](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/)の訓練には、4,000以上のNVidia A100 GPUが使用されました。これらのリソースを効率的に活用するには、モデルを個々のGPUデバイスのメモリに収まるように分割し、これらのデバイス間で効率的に並列計算を行うための、複雑な最適化システムが必要になります。同時に、大規模なモデル学習をユーザーが容易に利用できるようにするには、そうした最適化が簡単に適用できる必要があります。 + +DeepSpeedが提供する[ZeRO](https://www.deepspeed.ai/tutorials/zero/)と呼ばれる一連の最適化技術は、これらの課題に対する強力なソリューションを提供し、大規模で強力な深層学習モデルであるTNLG-17B、Bloom-176B、MPT-7B、Jurrasic-1などの訓練に広く使用されています。ZeROはそうした強力な機能を持つ一方で、いくつかの利用シナリオでは、GPU間のデータ転送のオーバーヘッドが大きくなり、高い学習効率を達成することが難しいことがあります。これは特に、a) (グローバル)バッチサイズに対して多数のGPUで訓練するため、GPUごとのバッチサイズが小さくなり、頻繁な通信が必要になる場合 b) ローエンドの計算クラスタで訓練する際、ノード間のネットワーク帯域幅が十分ではなく、通信待ち時間が長くなる場合 に発生します。これらのシナリオでは、ZeROの使いやすさと計算効率という利点が十分に発揮できません。 + +今回リリースする[ZeRO++](https://arxiv.org/abs/2306.10209)は、ZeROの通信を最適化することで、こうした問題を解決するシステムです。バッチサイズの制限やデバイス間の帯域幅の制約に関係なく、大規模モデルの訓練で極めて高い効率を実現します。ZeRO++は、量子化および通信とデータの再マッピングを組み合わせることで、モデルの品質に影響を与えることなく、ZeROと比較して総通信量を4分の1に削減します。これにより、以下に示す2つの重要な効果が得られます。 + + +1. *大規模モデルの事前学習・ファインチューニングの高速化* + 1. GPUあたりのバッチサイズが小さい: 数千のGPUで大規模モデルを事前学習する場合でも、数百または数十のGPUでモデルをファインチューニングする場合でも、GPUあたりのバッチサイズが小さい場合、ZeRO++はZeROに比べて最大2.2倍のスループットを提供し、訓練時間とコストを削減します。 + + 2. 低帯域幅クラスタ: ZeRO++では、帯域幅の小さいクラスタでも、4倍の帯域幅を持つクラスタと同等のスループットを達成できます。そのため、ZeRO++を使用すれば、さまざまなクラスタで効率的な大規模モデルの訓練が可能になります。 + +2. *RLHFによるChatGPTライクなモデルの訓練の高速化* + + 1. ZeRO++は主に訓練の高速化を目的に設計されていますが、通信オーバーヘッドは、ZeROを用いた訓練と推論に共通の課題であるため、ZeRO++の最適化は、推論のための機構である[ZeRO-Inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html#:~:text=ZeRO-Inference%20adapts%20and%20optimizes%20ZeRO-Infinity%20techniques%20for%20model,memory,%20thus%20hosting%20no%20(zero)%20weights%20in%20GPU.)でも有効です。その結果、ZeRO++は、対話モデルの推論に使用される、人間のフィードバックからの強化学習(RLHF)のような、訓練と推論の両方を組み合わせたワークロードの効率を向上させます。 + + 2. DeepSpeed-Chatとの統合により、ZeRO++はオリジナルのZeROと比較して、RLHF訓練の生成フェーズを最大2倍、訓練フェーズを最大1.3倍高速化することができます。 + +次に、ZeROとその通信オーバーヘッドについて詳しく掘り下げた上で、ZeRO++における主要な最適化について説明します。また、モデルサイズ、バッチサイズ、帯域幅の制約を変えて、ZeRO++が訓練の実行速度に与える影響も実証します。また、ZeRO++をDeepSpeed-Chatに適用して、RLHFを使用した対話モデルの学習を高速化する方法についても説明します。 + +## ZeRO++の詳細 + +
+ + + +図2: ZeROによる最適化 +
+ +ZeROは、データ並列のメモリ効率を向上させた技術であり、モデルの状態を全てのGPUに複製する代わりに、GPUごとに分割し、訓練中にgather/broadcastといった集合通信を必要になる都度実行して、分割されたモデル状態を再構築します。これにより、ZeROは、データ並列のシンプルさ・使いやすさを保ちつつ、すべてのGPUデバイスのメモリと計算を集約して、効果的に活用することができます。 + +順伝播(forward)の計算では、ZeROはallgather/broadcast通信によって、モデルの各レイヤーのパラメータを、使用する直前に収集します(パラメータの合計のサイズをMとします)。逆伝播(backward)では、ZeRO は各レイヤーのパラメータについて同様の通信パターンによって、各GPU上でローカルに勾配を計算します(勾配の合計サイズは同じく Mになります)。さらに、ZeROは、ローカルに計算された勾配を、reduceまたはreduce-scatter通信(合計サイズM)を使用して平均化し、分割します。2回のallgather/broadcast、及び1回のreduceまたはreduce-scatter通信で、合計の通信データサイズは3Mになります。 + +これらの通信オーバーヘッドを削減するために、ZeRO++では、上記の3回の通信を対象とした一連の最適化技術を実現しました: + +
+ + + +図3: qwZにおけるブロックベース量子化 +
+ + +### パラメータの量子化と通信 (qwZ) + +まず、allgather時のパラメータの通信量を削減するために、パラメータの量子化を使用します。通信の直前に各モデルパラメータをFP16(2バイト)からINT8(1バイト)データ型に変換し、通信後に元に戻します。しかし、単純にパラメータの量子化を行うと、モデルの学習精度が低下する可能性があります。そこで、モデルの学習精度を維持するために、モデルパラメータの各サブセットに対して、独立した量子化を行うブロックベースの量子化を採用しています。これまでに、高性能なブロックベース量子化の実装は存在しなかったため、ZeRO++のために、高度に最適化された量子化CUDAカーネルをゼロから実装し、基本的な量子化と比較して、3倍の高精度と、5倍の高速化を実現しました。 + +
+ + + +図4: hpZにおける階層的なパラメータの分割 +
+ + +### ZeROのための階層的なパラメータの分割 (hpZ) + +次に、逆伝播において、GPUメモリの必要サイズの増加と引き換えに、パラメータのallgatherの通信オーバヘッドを削減します。具体的には、ZeROのようにモデル全体のパラメータを全てのサーバのGPUデバイスに分散させるのではなく、各サーバごとに完全なモデルのコピーを保持します。これにより、必要メモリサイズは増加しますが、一般に通信帯域幅が限られるサーバ間でのallgather/broadcastではなく、通信帯域幅の大きいサーバ内通信によるallgather/broadcastのみを使用することになり、大幅に高速化できます。 + +
+ + + +図5: qgZの処理の流れ + +
+ +### 勾配の量子化と通信 (qgZ) + +次に取り上げる、reduce-scatterを使った勾配の通信コストの削減は、上述の他の課題よりさらに困難です。通信量を減らすために単純に量子化を適用すると、ブロックベースの量子化を使用したとしても、reduceでの加算の過程で誤差が累積されてしまいます。そこで我々は、勾配を送信前に量子化し、受信後、reduceでの加算の前に量子化を解除します。これを効率的に行うために、我々はqgZと呼ばれるall-to-allベースの新しい量子化勾配通信パラダイムを考案しました。 + +qgZは、次の2つの課題を解決するために設計されています。i) 単純にINT4/INT8でreduce-scatterを実装した場合、reduceを低精度で計算することによって生じる大幅な精度低下を克服すること、及び ii) (元の精度でreduce-scatterを行う場合でも)リングベースまたはツリーベースの従来のreduce-scatterにおいて、量子化と復元の一連の処理から生じる精度低下と大幅なレイテンシオーバーヘッドを回避すること です。qgZは、リングベースまたはツリーベースの散布度削減アルゴリズムの代わりに、新しい階層的なall-to-all通信によるアプローチを用います。 + +qgZには3つの主要なステップがあります:i) 勾配スライスの並べ替え、ii) ノード内通信と加算、iii) ノード間通信と加算。まず、通信が行われる前に、勾配テンソルのスライスと、スライスの並べ替えを行い、通信終了時に各GPU上で正しい勾配の配置(図5の緑色の勾配のスライス)が得られるようにします。第2に、並べ替えられた勾配スライスを量子化し、各ノード内でall-to-all通信を行います。all-to-allから受信した勾配スライスは、量子化から復元され、ローカルでreduction(加算)の計算を行います。第3に、ローカルでreductionされた勾配を再び量子化し、ノード間で全ノード間通信を行います。受信した勾配を再び量子化から復元し、元の精度でreductionの計算を行い、図5の緑の勾配のスライスを得ます。 + +このような階層的なアプローチをとる理由は、ノード間の通信量を削減するためです。より正確には、ノードあたりN個のGPU、モデルサイズM、および量子化の比率Zが与えられた場合、シングルホップのall-to-all通信では、M*N/Z個のノード間通信が発生します。これに対し、この階層的アプローチでは、各GPUのノード間通信をM/ZからM/(Z*N)に減らすことができます。したがって、総通信量はM*N/ZからM*N/(Z*N)=M/Zに減少します。さらに、ノード内通信とノード間通信をオーバーラップさせ、(テンソルスライス並べ替え+ノード内量子化)と(ノード内非量子化+ノード内加算+ノード間量子化)のCUDAカーネルを融合させることで、qgZのend-to-endのレイテンシを最適化します。 + +
+ +| Communication Volume | Forward all-gather on weights | Backward all-gather on weights | Backward reduce-scatter on gradients | Total | +|:---------------------------:|:------------------------------------:|:-------------------------------------:|:-------------------------------------------:|:------------:| +| ZeRO | M | M | M | 3M | +| ZeRO++ | 0.5M | 0 | 0.25M | 0.75M | + +
+ +### **通信量の削減** + +上述の3つの最適化技術をすべて組み込むことで、ノード間の通信量を3Mから0.75Mに減らすことができます。具体的には、qwZを用いて、モデルパラメータに関する順伝播のallgather/broadcast通信をMから0.5Mに削減します。また、qgZを使用して、逆伝播のノード間のreduce-scatter通信をMから0.25Mに削減します。 + +## **ZeRO++によるLLM訓練の高速化** + +ここでは、384台のNVIDIA V100 GPUを使用した、実際のLLM訓練シナリオでのZeRO++の評価結果を示します。 + +
+ + + +図6: 様々なモデルサイズでのZeRO++とZeROの速度の比較(384台のV100 GPU、400Gbps (100Gbps×4) のノード間接続) + +
+ +### **GPUあたりのバッチサイズが小さい場合でも高い効率を実現** + +**高帯域幅クラスタ:** 図6は、それぞれ100Gbpsで動作する4つのインフィニバンド(IB)接続を使用した400Gbpsノード間接続で、異なるモデルサイズとマイクロバッチサイズについて、ZeRO++のスループットがZeROを上回ったことを示しています。GPUあたり1kトークンを使用した場合、ZeRO++はZeRO-3に対して28%から36%のスループット向上を達成しました。マイクロバッチサイズが2kの場合では、ZeRO++はZeRO-3に対して24%から29%のスループット向上を達成しています。 + +
+ + + + +図7: 異なるサイズのLLMのスループット比較(384台のGPU・100Gbpsのノード間接続) +
+ +**低帯域幅クラスタ:** 100Gbpsネットワークのような低速なネットワーク環境では、ZeRO++は大幅に優れた性能を発揮します。図 7 に示すように、ZeRO++ は ZeRO-3 と比較して、end-to-endのスループットで最大 2.2 倍の高速化を達成しています。平均して、ZeRO++はZeRO-3をベースラインとして、約2倍の高速化を達成しています。 + +
+ + + + +図8: ZeRO++により、低い帯域幅のクラスタでも、ZeROを高い帯域幅のクラスタで使用した場合と同等の性能を実現 + +
+ +### **低帯域幅クラスタでも高帯域幅クラスタで従来技術と用いたのと同様の効率を実現** + +さらに、ZeRO ++は、低帯域幅クラスタで、はるかに高い帯域幅クラスタでのZeROを使用した場合と比較して、同等のシステムスループットを達成できます。図8に示すように、18Bと138Bの両モデルで、200Gbpsノード間通信が可能な環境でのZeRO++は、800Gbpsノード間通信が可能な環境のZeRO-3と同等のTFLOPを達成できます。その優れたスケーラビリティから、ZeRO++は大規模AIモデルを訓練するための次世代のZeROと位置付けられます。 + +## **DeepSpeed-Chatを用いたRLHF訓練におけるZeRO++の適用** + +### **RLHF訓練の背景** + +ChatGPTのようなモデルは、LLMの学習と、[RLHFによるファインチューニング](https://openai.com/blog/chatgpt)によって構築されます。RLHFは生成(推論)フェーズと学習フェーズから構成されます。生成フェーズでは、アクターモデルが部分的な会話を入力とし、一連の順伝播の計算を用いて応答を生成します。そして訓練フェーズでは、クリティックモデルが生成された応答を品質によってランク付けし、アクターモデルに強化信号を与えます。アクターモデルはこれらのランク付けを用いてファインチューニングされ、その後の反復においてより正確で適切な応答を生成できるようになります。 + +RLHFトレーニングは4つのモデル(アクター、リファレンス、クリティック、リウォード)を利用するため、きわめて大きなメモリが必要となります。この問題に対処するため、低ランク適応(LoRA)を採用しています。LoRAは事前学習されたモデルのパラメータを固定し、学習可能なランク分解行列をTransformerアーキテクチャの各層に追加することで、学習可能なパラメータ数を大幅に削減することができます。LoRAを用いてメモリ使用量を削減することでRLHFを高速化し、より大きなバッチサイズでの計算が可能になり、スループットを大幅に向上できます。 + +### **RLHF訓練のためのDeepSpeed-ChatへのZeRO++の適用** + +
+ + + + +図9: ZeRO++によりRLHF訓練の生成フェーズと訓練フェーズの両方を高速化 + +
+ +LoRAを使用する場合、RLHFでは、ほとんどのモデルパラメータが固定されています。ZeRO++は、この特徴を利用した特別な機能を提供しています。ZeRO++は通常、固定されたパラメータをFP16で保持し、各通信操作の前に量子化します。RLHFではその代わりに、前もってINT4/8に量子化しておくことができます。通信後の量子化からの復元は必要ですが、復元されたパラメータは、それを使用する計算が終わった後に破棄されます。 + +このようにZeRO++をRLHF訓練に使用することで、メモリ使用量と通信量の両方を削減できます。通信量だけでなく、メモリ使用量が削減されるため、バッチサイズが大きくすることができ、訓練のスループットが向上します。生成フェーズでは、ZeRO++はhpZを使用してすべてのパラメータの通信を各ノード内で行うようにし、通信量を削減しながらノード内の高い通信帯域幅を利用することで、生成スループットをさらに向上させます。 + +ZeRO++はDeepSpeed-Chatに統合され、ChatGPTライクなモデルのRLHF訓練を強力にサポートします。図 9 では、32 個の V100 GPU 上で、30B および 66B のアクターモデルについて、ZeRO と ZeRO++ を比較し、アクターモデルのサイズが異なる場合の RLHF 生成のスループットを比較しています。その結果、ZeRO++はZeROよりもRLHF生成スループットが最大2.25倍向上することが確認されました。また、16個のV100 GPU上での訓練フェーズでは、ZeRO++によって可能になった通信量の低減とバッチサイズの拡大により、ZeRO++はZeROよりも1.26倍優れたスループットを達成しています。 + +## **早速試してみましょう!** + +DeepSpeed ZeRO++をリリースし、AIコミュニティの誰もが利用できるようになることを大変嬉しく思っています。まずは、LLM訓練の[チュートリアル](https://www.deepspeed.ai/tutorials/zeropp/)をご覧ください。ZeRO++ for DeepSpeed-Chatは数週間以内にリリースされる予定です。 + +ZeRO++の技術的な詳細については、arXivにアップロードされた[論文](https://arxiv.org/pdf/2306.10209.pdf)をご覧ください。 + +DeepSpeed-ZeRO++は、DeepSpeedエコシステムの一部です。詳細については、我々の[Webサイト](https://www.deepspeed.ai/)をご覧ください。詳細なブログ記事、チュートリアル、ドキュメントが掲載されています。 + +また、[英語版Twitter](https://twitter.com/DeepSpeedAI)、[日本語版Twitter](https://twitter.com/DeepSpeedAI_JP)、[中国語版Zhihuアカウント](https://www.zhihu.com/people/deepspeed)でも最新のDeepSpeedニュースを発信しています。 + +DeepSpeedは、皆様の開発への参加を歓迎しています。DeepSpeedのGitHubページで、バグ報告、Pull Request、ディスカッションへの参加が可能です。詳細は[ガイドライン](https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md)をご覧ください。また、大学、研究所、企業とのコラボレーションも行っています。こうしたコラボレーションについてのご要望(およびGitHubには適さないその他の話題)については まで直接メールをお送りください。 + + +**Contributors:** + +このプロジェクトは、DeepSpeedチームの以下のメンバーによって実施されました。 + +[Guanhua Wang](https://www.microsoft.com/en-us/research/people/guanhuawang/), Heyang Qin, Sam Ade Jacobs, Connor Holmes, [Samyam Rajbhandari](https://www.microsoft.com/en-us/research/people/samyamr/), [Olatunji Ruwase](https://www.microsoft.com/en-us/research/people/olruwase/), Ammar Ahmad Awan, Jeff Rasley, Michael Wyatt, [Yuxiong He](https://www.microsoft.com/en-us/research/people/yuxhe/) (team lead) diff --git a/build_win.bat b/build_win.bat index ec8c8a362a78..64ba99633d50 100644 --- a/build_win.bat +++ b/build_win.bat @@ -1,19 +1,17 @@ @echo off +set CUDA_HOME=%CUDA_PATH% +set DISTUTILS_USE_SDK=1 + set DS_BUILD_AIO=0 +set DS_BUILD_CUTLASS_OPS=0 +set DS_BUILD_EVOFORMER_ATTN=0 +set DS_BUILD_FP_QUANTIZER=0 +set DS_BUILD_GDS=0 +set DS_BUILD_RAGGED_DEVICE_OPS=0 set DS_BUILD_SPARSE_ATTN=0 +set DS_BUILD_DEEP_COMPILE=0 -echo Administrative permissions required. Detecting permissions... - -net session >nul 2>&1 -if %errorLevel% == 0 ( - echo Success: Administrative permissions confirmed. -) else ( - echo Failure: Current permissions inadequate. - goto end -) - - -python setup.py bdist_wheel +python -m build --wheel --no-isolation :end diff --git a/ci/__init__.py b/ci/__init__.py new file mode 100644 index 000000000000..6f5f5619004b --- /dev/null +++ b/ci/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/ci/accelerate.py b/ci/accelerate.py new file mode 100644 index 000000000000..21570a85f40b --- /dev/null +++ b/ci/accelerate.py @@ -0,0 +1,58 @@ +# Copyright (c) Snowflake. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from pathlib import Path + +import modal + +ROOT_PATH = Path(__file__).parents[1] + +# yapf: disable +image = (modal.Image + .from_registry("pytorch/pytorch:2.9.1-cuda12.8-cudnn9-devel", add_python="3.10") + .apt_install("git") + .pip_install("uv") + # uv_pip_install already includes --compile-bytecode + .uv_pip_install("datasets==3.6.0", extra_options="--system") + .pip_install_from_requirements(ROOT_PATH / "requirements/requirements.txt", gpu="any") + .pip_install_from_requirements(ROOT_PATH / "requirements/requirements-dev.txt", gpu="any") + .add_local_dir(ROOT_PATH , remote_path="/root/", copy=True) + .run_commands("pip install /root") + .add_local_dir(ROOT_PATH / "accelerator", remote_path="/root/deepspeed/accelerator") + .add_local_dir(ROOT_PATH / "csrc", remote_path="/root/deepspeed/ops/csrc") + .add_local_dir(ROOT_PATH / "op_builder", remote_path="/root/deepspeed/ops/op_builder") + ) + +app = modal.App("deepspeedai-accelerate-ci", image=image) + +@app.function( + gpu="l40s:1", + timeout=1800, +) +def pytest(): + import subprocess + + cmd = "git clone https://github.com/huggingface/accelerate" + print(f"running: {cmd}") + subprocess.run( + cmd.split(), + check=True, + cwd=ROOT_PATH / ".", + ) + cmd = "uv pip install --system --compile-bytecode ./accelerate[testing]" + print(f"running: {cmd}") + subprocess.run( + cmd.split(), + check=True, + cwd=ROOT_PATH / ".", + ) + + cmd = "pytest ./accelerate/tests/deepspeed" + print(f"running: {cmd}") + subprocess.run( + cmd.split(), + check=True, + cwd=ROOT_PATH / ".", + ) diff --git a/ci/torch_latest.py b/ci/torch_latest.py new file mode 100644 index 000000000000..1d2d354c11f0 --- /dev/null +++ b/ci/torch_latest.py @@ -0,0 +1,181 @@ +# Copyright (c) Snowflake. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import shlex +from pathlib import Path + +import modal + +ROOT_PATH = Path(__file__).parents[1] +DEFAULT_MODAL_TORCH_PRESET = "2.10.0-cuda12.8" +DEFAULT_MODAL_TRANSFORMERS_SOURCE = "git" +MODAL_TORCH_PRESETS = { + "2.7.1-cuda12.8": { + "image": "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel", + "torch_package": "torch==2.7.1", + "torchvision_package": "torchvision==0.22.1", + "torch_test_version": "2.7", + "cuda_test_version": "12.8", + }, + "2.8.0-cuda12.8": { + "image": "pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel", + "torch_package": "torch==2.8.0", + "torchvision_package": "torchvision==0.23.0", + "torch_test_version": "2.8", + "cuda_test_version": "12.8", + }, + "2.9.1-cuda12.8": { + "image": "pytorch/pytorch:2.9.1-cuda12.8-cudnn9-devel", + "torch_package": "torch==2.9.1", + "torchvision_package": "torchvision==0.24.1", + "torch_test_version": "2.9", + "cuda_test_version": "12.8", + }, + "2.10.0-cuda12.8": { + "image": "pytorch/pytorch:2.10.0-cuda12.8-cudnn9-devel", + "torch_package": "torch==2.10.0", + "torchvision_package": "torchvision==0.25.0", + "torch_test_version": "2.10", + "cuda_test_version": "12.8", + }, + "2.11.0-cuda12.8": { + "image": "pytorch/pytorch:2.11.0-cuda12.8-cudnn9-devel", + "torch_package": "torch==2.11.0", + "torchvision_package": "torchvision==0.26.0", + "torch_test_version": "2.11", + "cuda_test_version": "12.8", + }, +} +PYTORCH_CUDA_128_INDEX_URL = "https://download.pytorch.org/whl/cu128" + + +def resolve_modal_torch_config(): + selected_preset = os.environ.get("MODAL_TORCH_PRESET") or DEFAULT_MODAL_TORCH_PRESET + try: + preset_config = MODAL_TORCH_PRESETS[selected_preset] + except KeyError as exc: + supported = ", ".join(sorted(MODAL_TORCH_PRESETS)) + raise ValueError(f"Unsupported MODAL_TORCH_PRESET={selected_preset!r}; supported values: {supported}") from exc + + return { + "preset": selected_preset, + **preset_config, + } + + +def resolve_modal_transformers_config(): + transformers_source = os.environ.get("MODAL_TRANSFORMERS_SOURCE") or DEFAULT_MODAL_TRANSFORMERS_SOURCE + supported_sources = {"requirements", "git"} + if transformers_source not in supported_sources: + supported = ", ".join(sorted(supported_sources)) + raise ValueError( + f"Unsupported MODAL_TRANSFORMERS_SOURCE={transformers_source!r}; supported values: {supported}") + + transformers_ref = os.environ.get("MODAL_TRANSFORMERS_REF") or "" + if transformers_source == "git" and not transformers_ref: + transformers_ref = "main" + + return { + "source": transformers_source, + "ref": transformers_ref, + } + + +def transformers_override_commands(): + if MODAL_TRANSFORMERS_CONFIG["source"] == "requirements": + return () + + transformers_ref = shlex.quote(MODAL_TRANSFORMERS_CONFIG["ref"]) + return ( + "rm -rf /tmp/transformers", + "git clone --filter=blob:none https://github.com/huggingface/transformers /tmp/transformers", + "cd /tmp/transformers && " + f"git checkout {transformers_ref} && " + "resolved_ref=$(git rev-parse HEAD) && " + "echo \"Resolved Transformers git ref: ${resolved_ref}\" && " + "pip install .", + ) + + +def torch_package_reinstall_command(): + command = [ + "pip", + "install", + "--force-reinstall", + "--no-cache-dir", + "--index-url", + PYTORCH_CUDA_128_INDEX_URL, + MODAL_TORCH_CONFIG["torch_package"], + MODAL_TORCH_CONFIG["torchvision_package"], + ] + return " ".join(shlex.quote(part) for part in command) + + +MODAL_TORCH_CONFIG = resolve_modal_torch_config() +MODAL_TRANSFORMERS_CONFIG = resolve_modal_transformers_config() +MODAL_TORCH_IMAGE = MODAL_TORCH_CONFIG["image"] +MODAL_TORCH_TEST_VERSION = MODAL_TORCH_CONFIG["torch_test_version"] +MODAL_CUDA_TEST_VERSION = MODAL_TORCH_CONFIG["cuda_test_version"] + +# yapf: disable +image = (modal.Image + .from_registry(MODAL_TORCH_IMAGE, add_python="3.10") + .env({ + "MODAL_TORCH_PRESET": MODAL_TORCH_CONFIG["preset"], + "MODAL_TRANSFORMERS_SOURCE": MODAL_TRANSFORMERS_CONFIG["source"], + "MODAL_TRANSFORMERS_REF": MODAL_TRANSFORMERS_CONFIG["ref"], + }) + .run_commands("apt update && apt install -y git libaio-dev") + .pip_install_from_requirements(ROOT_PATH / "requirements/requirements.txt", gpu="any") + .pip_install_from_requirements(ROOT_PATH / "requirements/requirements-dev.txt", gpu="any") + .pip_install_from_requirements(ROOT_PATH / "requirements/requirements-deepcompile.txt", gpu="any") + .run_commands(torch_package_reinstall_command()) + ) + +transformers_commands = transformers_override_commands() +if transformers_commands: + image = image.run_commands(*transformers_commands) + +image = (image + .add_local_dir(ROOT_PATH , remote_path="/root/", copy=True) + .run_commands("pip install /root") + .add_local_dir(ROOT_PATH / "accelerator", remote_path="/root/deepspeed/accelerator") + .add_local_dir(ROOT_PATH / "csrc", remote_path="/root/deepspeed/ops/csrc") + .add_local_dir(ROOT_PATH / "op_builder", remote_path="/root/deepspeed/ops/op_builder") + ) + + +app = modal.App("deepspeedai-torch-latest-ci", image=image) + + +@app.function( + gpu="l40s:2", + timeout=3600, +) +def pytest(): + import subprocess + + subprocess.run( + [ + "python", + "-c", + "import json, torch, torchvision, transformers; " + "print('Modal Python package versions: ' + json.dumps({" + "'torch': torch.__version__, " + "'torch_cuda': torch.version.cuda, " + "'torchvision': torchvision.__version__, " + "'transformers': transformers.__version__" + "}, sort_keys=True))", + ], + check=True, + cwd=ROOT_PATH / ".", + ) + subprocess.run( + f"pytest -n 4 --verbose tests/unit/v1/ --torch_ver={MODAL_TORCH_TEST_VERSION} " + f"--cuda_ver={MODAL_CUDA_TEST_VERSION}".split(), + check=True, + cwd=ROOT_PATH / ".", + ) diff --git a/csrc/adagrad/cpu_adagrad.cpp b/csrc/adagrad/cpu_adagrad.cpp index 8eebe00349be..e276ad0856dd 100644 --- a/csrc/adagrad/cpu_adagrad.cpp +++ b/csrc/adagrad/cpu_adagrad.cpp @@ -5,53 +5,38 @@ #include "cpu_adagrad.h" #include +#include #include +#include #include #include #include -#if defined(__ENABLE_CUDA__) -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" -#endif +using namespace std::string_literals; static std::unordered_map> s_optimizers; // C++ interface -void Adagrad_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adagrad_Optimizer::Step_1(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) { float step_size = -1 * _alpha; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } for (size_t t = rounded_size; t < _param_size; t += TILE) { size_t copy_size = TILE; if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#endif #pragma omp parallel for for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float grad = (float)grads[k]; + float param = (float)_params[k]; float momentum = grads[k]; float variance = _exp_avg_sq[k]; if (_weight_decay > 0) { grad = param * _weight_decay + grad; } @@ -62,47 +47,30 @@ void Adagrad_Optimizer::Step_1(float* _params, grad += _eps; grad = momentum / grad; param = grad * step_size + param; -#if defined(__ENABLE_CUDA__) - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; -#endif - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; + _params[k] = param; // STORE UPDATE TERM TO GRAD'S MEMORY grads[k] = grad * step_size; _exp_avg_sq[k] = variance; } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - _buf_index = !_buf_index; - } -#endif } } } -void Adagrad_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adagrad_Optimizer::Step_4(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), (grads + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); } int create_adagrad_optimizer(int optimizer_id, @@ -136,25 +104,77 @@ int create_adagrad_optimizer(int optimizer_id, return 0; } -void Adagrad_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) +template +void Adagrad_Optimizer::Step_8(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg_sq, _param_size); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), (grads + rounded_size), (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); + (_param_size - rounded_size)); +} + +template +void step_invoker(std::shared_ptr opt, + void* _params, + void* grads, + void* _exp_avg_sq, + size_t _param_size) +{ + opt->Step_8((ds_params_precision_t*)(_params), + (ds_params_precision_t*)(grads), + (ds_state_precision_t*)(_exp_avg_sq), + _param_size); +} + +std::map, + std::function, void*, void*, void*, size_t)>> + invokers; + +// Fill map with template functions for each type +template +void create_invoker() +{ + invokers[std::tuple(c10::CppTypeToScalarType(), + c10::CppTypeToScalarType())] = + step_invoker; +} +struct InvokerInitializer { + InvokerInitializer() + { + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + } +} _invoker_initializer; + +void invoke(std::shared_ptr opt, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq, + size_t param_size) +{ + c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); + c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg_sq.options().dtype()); + + auto it = invokers.find(std::tuple(params_type, state_type)); + if (it == invokers.end()) { + throw std::runtime_error("Adagrad optimizer with param type "s + + c10::toString(params_type) + " and state type "s + + c10::toString(state_type) + + " is not supported on current hardware"s); + } + + it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg_sq.data_ptr(), param_size); } int ds_adagrad_step(int optimizer_id, @@ -170,58 +190,13 @@ int ds_adagrad_step(int optimizer_id, auto grads_c = grads.contiguous(); auto exp_avg_sq_c = exp_avg_sq.contiguous(); - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - std::shared_ptr opt = std::static_pointer_cast(s_optimizers[optimizer_id]); opt->IncrementStep(step); opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel()); -#if defined(__ENABLE_CUDA__) - opt->SynchronizeStreams(); -#endif - return 0; -} + invoke(opt, params_c, grads_c, exp_avg_sq_c, params_c.numel()); -int ds_adagrad_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float epsilon, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ -#if defined(__ENABLE_CUDA__) - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step); - opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_sq_ptr, - params_c.numel(), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); -#else - assert(false); -#endif return 0; } @@ -235,9 +210,6 @@ int destroy_adagrad_optimizer(int optimizer_id) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); - m.def("adagrad_update_copy", - &ds_adagrad_step_plus_copy, - "DeepSpeed CPU Adagrad update and param copy (C++)"); m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); } diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 4d3d5a45e628..f4c242ff9229 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -4,306 +4,11 @@ // DeepSpeed Team #include "cpu_adam.h" -#include -#include -#include -#include -#include -#include - -#if defined(__ENABLE_CUDA__) -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" -#endif - -static std::unordered_map> s_optimizers; - -// C++ interface - -void Adam_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) { - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; - - float step_size = -1 * _alpha / _bias_correction1; - float w_decay = -1 * _alpha * _weight_decay; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } - - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#endif -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; - if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } - momentum = momentum * _betta1; - momentum = grad * betta1_minus1 + momentum; - - variance = variance * _betta2; - grad = grad * grad; - variance = grad * betta2_minus1 + variance; - - grad = sqrt(variance); - grad = grad * _bias_correction2 + _eps; - grad = momentum / grad; - if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } - param = grad * step_size + param; -#if defined(__ENABLE_CUDA__) - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; -#endif - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; - } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#endif - } - } -} - -void Adam_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int create_adam_optimizer(int optimizer_id, - float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true, - bool should_log = false) -{ - auto opt = - std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", - alpha, - betta1, - betta2, - weight_decay, - (int)adamw_mode); - } - - return 0; -} - -void Adam_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int ds_adam_step(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - nullptr, - (params.options().dtype() == at::kHalf)); - -#if defined(__ENABLE_CUDA__) - opt->SynchronizeStreams(); -#endif - return 0; -} - -int ds_adam_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ -#if defined(__ENABLE_CUDA__) - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); -#else - assert(false); -#endif - return 0; -} - -int destroy_adam_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); - - return 0; -} PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); - m.def("adam_update_copy", - &ds_adam_step_plus_copy, - "DeepSpeed CPU Adam update and param copy (C++)"); + m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); } diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp new file mode 100644 index 000000000000..1f2b8cf0df47 --- /dev/null +++ b/csrc/adam/cpu_adam_impl.cpp @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include +#include +#include +#include +#include "cpu_adam.h" + +using namespace std::string_literals; +static std::unordered_map> s_optimizers; + +// C++ interface + +template +void Adam_Optimizer::Step_1(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); +#endif + if (_param_size > rounded_size) { + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = (float)grads[k]; + float param = (float)_params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + } + } +} + +template +void Adam_Optimizer::Step_4(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size)); +} + +int create_adam_optimizer(int optimizer_id, + float alpha, + float betta1, + float betta2, + float eps, + float weight_decay, + bool adamw_mode, + bool should_log) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +template +void Adam_Optimizer::Step_8(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size)); +} + +template +void step_invoker(std::shared_ptr opt, + void* _params, + void* grads, + void* _exp_avg, + void* _exp_avg_sq, + size_t _param_size) +{ + opt->Step_8((ds_params_precision_t*)(_params), + (ds_params_precision_t*)(grads), + (ds_state_precision_t*)(_exp_avg), + (ds_state_precision_t*)(_exp_avg_sq), + _param_size); +} + +std::map, + std::function, void*, void*, void*, void*, size_t)>> + invokers; + +// Fill map with template functions for each type +template +void create_invoker() +{ + invokers[std::tuple(c10::CppTypeToScalarType(), + c10::CppTypeToScalarType())] = + step_invoker; +} +struct InvokerInitializer { + InvokerInitializer() + { + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + } +} _invoker_initializer; + +void invoke(std::shared_ptr opt, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + size_t param_size) +{ + c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); + c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype()); + + auto it = invokers.find(std::tuple(params_type, state_type)); + if (it == invokers.end()) { + throw std::runtime_error("Adam optimizer with param type "s + c10::toString(params_type) + + " and state type "s + c10::toString(state_type) + + " is not supported on current hardware"s); + } + + it->second(opt, + params.data_ptr(), + grads.data_ptr(), + exp_avg.data_ptr(), + exp_avg_sq.data_ptr(), + param_size); +} + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel()); + + return 0; +} + +void adamw_rollback_inplace(float* params, + const float* grads, + float* momentum, + float* variance, + size_t param_size, + float learning_rate, + float beta1, + float beta2, + float eps, + float weight_decay, + int& step_count) +{ + const float lr = learning_rate; + const float lambda = weight_decay; + const float beta1_pow = std::pow(beta1, step_count); + const float beta2_pow = std::pow(beta2, step_count); + const float one_minus_beta1 = 1.0f - beta1; + const float one_minus_beta2 = 1.0f - beta2; + const float lr_lambda = lr * lambda; + const float one_minus_lr_lambda = 1.0f - lr_lambda; + +#pragma omp parallel for + for (size_t i = 0; i < param_size; ++i) { + const float bias_correction1 = 1.0f - beta1_pow; + const float bias_correction2 = 1.0f - beta2_pow; + + const float m_hat = momentum[i] / bias_correction1; + const float v_hat = variance[i] / bias_correction2; + + const float denominator = std::sqrt(v_hat) + eps; + + // Rollback parameter update + const float update = lr * m_hat / denominator; + float new_param = (params[i] + update) / one_minus_lr_lambda; + + // Handle numerical instability + if (!std::isfinite(new_param)) { new_param = 0.0f; } + + params[i] = new_param; + + const float grad = grads[i]; + momentum[i] = (momentum[i] - one_minus_beta1 * grad) / beta1; + variance[i] = (variance[i] - one_minus_beta2 * grad * grad) / beta2; + } + + --step_count; +} + +int ds_adam_rollback(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + try { + // Validate tensor types - rollback currently only supports float32 + if (params.scalar_type() != torch::kFloat32 || grads.scalar_type() != torch::kFloat32 || + exp_avg.scalar_type() != torch::kFloat32 || + exp_avg_sq.scalar_type() != torch::kFloat32) { + printf("Error: Adam rollback currently only supports float32 tensors\n"); + return -1; + } + + float* params_ptr = params.data_ptr(); + const float* grads_ptr = grads.data_ptr(); + float* momentum_ptr = exp_avg.data_ptr(); + float* variance_ptr = exp_avg_sq.data_ptr(); + const size_t param_size = params.numel(); + int step_count = static_cast(step); + + adamw_rollback_inplace(params_ptr, + grads_ptr, + momentum_ptr, + variance_ptr, + param_size, + lr, + beta1, + beta2, + epsilon, + weight_decay, + step_count); + + return 0; + } catch (const std::exception& e) { + printf("Error in Adam rollback for optimizer #%d: %s\n", optimizer_id, e.what()); + return -1; + } +} + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu index 1b697d989b1a..a1fc7d15aec9 100644 --- a/csrc/adam/multi_tensor_adam.cu +++ b/csrc/adam/multi_tensor_adam.cu @@ -23,14 +23,14 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85 #define BLOCK_SIZE 512 #define ILP 4 -typedef enum { +typedef enum : int { ADAM_MODE_0 = 0, // L2 regularization mode ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) } adamMode_t; using MATH_T = float; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, @@ -48,13 +48,13 @@ struct AdamFunctor { // if(*noop_gmem == 1) // return; - int tensor_loc = tl.block_to_tensor[blockIdx.x]; + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; // potentially use to pass in list of scalar // int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; T* g = (T*)tl.addresses[0][tensor_loc]; g += chunk_idx * chunk_size; @@ -71,7 +71,8 @@ struct AdamFunctor { n -= chunk_idx * chunk_size; // see note in multi_tensor_scale_kernel.cu - for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + for (index_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; @@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size, bias_correction2 = 1 - std::pow(beta2, step); } + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { break; } + } + // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), - 0, - "adam", - multi_tensor_apply<4>(BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor(), - beta1, - beta2, - bias_correction1, - bias_correction2, - epsilon, - lr, - (adamMode_t)mode, - weight_decay);) + if (requires_64bit_indexing) { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, + (int64_t)chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } else { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh index 12f41cb49c6b..ea028e91946b 100644 --- a/csrc/adam/multi_tensor_apply.cuh +++ b/csrc/adam/multi_tensor_apply.cuh @@ -28,14 +28,14 @@ constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; template struct TensorListMetadata { void* addresses[n][depth_to_max_tensors[n - 1]]; - int sizes[depth_to_max_tensors[n - 1]]; + int64_t sizes[depth_to_max_tensors[n - 1]]; unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. int start_tensor_this_launch; }; template -__global__ void multi_tensor_apply_kernel(int chunk_size, +__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable, @@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size, } template -void multi_tensor_apply(int block_size, - int chunk_size, +void multi_tensor_apply(int64_t block_size, + int64_t chunk_size, const at::Tensor& noop_flag, const std::vector>& tensor_lists, T callable, @@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size, tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); loc_tensor_info++; - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { // std::cout << chunks_this_tensor << std::endl; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_chunk[loc_block_info] = chunk; diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index f35760a99a5c..9d7ff5093017 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -68,8 +68,8 @@ static void _get_aio_latencies(std::vector>& raw_l std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); } -static void _do_io_submit_singles(const long long int n_iocbs, - const long long int iocb_index, +static void _do_io_submit_singles(const int64_t n_iocbs, + const int64_t iocb_index, std::unique_ptr& aio_ctxt, std::vector>& submit_times) { @@ -89,8 +89,8 @@ static void _do_io_submit_singles(const long long int n_iocbs, } } -static void _do_io_submit_block(const long long int n_iocbs, - const long long int iocb_index, +static void _do_io_submit_block(const int64_t n_iocbs, + const int64_t iocb_index, std::unique_ptr& aio_ctxt, std::vector>& submit_times) { @@ -109,16 +109,19 @@ static void _do_io_submit_block(const long long int n_iocbs, assert(submit_ret > 0); } -static int _do_io_complete(const long long int min_completes, - const long long int max_completes, +static int _do_io_complete(const int64_t min_completes, + const int64_t max_completes, std::unique_ptr& aio_ctxt, std::vector>& reap_times) { const auto start_time = std::chrono::high_resolution_clock::now(); - const auto n_completes = io_getevents( - aio_ctxt->_io_ctxt, min_completes, max_completes, aio_ctxt->_io_events.data(), nullptr); + int64_t n_completes = io_pgetevents(aio_ctxt->_io_ctxt, + min_completes, + max_completes, + aio_ctxt->_io_events.data(), + nullptr, + nullptr); reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); - assert(n_completes >= min_completes); return n_completes; } @@ -131,7 +134,7 @@ void do_aio_operation_sequential(const bool read_op, { struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); - const auto num_io_blocks = static_cast( + const auto num_io_blocks = static_cast( ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); #if DEBUG_DS_AIO_PERF const auto io_op_name = std::string(read_op ? "read" : "write"); @@ -142,15 +145,14 @@ void do_aio_operation_sequential(const bool read_op, std::vector> submit_times; std::vector> reap_times; const auto max_queue_bytes = - static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); + static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); auto start = std::chrono::high_resolution_clock::now(); - for (long long iocb_index = 0; iocb_index < num_io_blocks; - iocb_index += aio_ctxt->_queue_depth) { + for (int64_t iocb_index = 0; iocb_index < num_io_blocks; iocb_index += aio_ctxt->_queue_depth) { const auto start_offset = iocb_index * aio_ctxt->_block_size; const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; const auto n_iocbs = - min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); + min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); @@ -265,6 +267,10 @@ void report_file_error(const char* filename, const std::string file_op, const in int open_file(const char* filename, const bool read_op) { const int flags = read_op ? (O_RDONLY | O_DIRECT) : (O_WRONLY | O_CREAT | O_DIRECT); +#if defined(__ENABLE_CANN__) + int* flags_ptr = (int*)&flags; + *flags_ptr = read_op ? (O_RDONLY) : (O_WRONLY | O_CREAT); +#endif const int mode = 0600; const auto fd = open(filename, flags, mode); if (fd == -1) { @@ -278,13 +284,14 @@ int open_file(const char* filename, const bool read_op) int regular_read(const char* filename, std::vector& buffer) { - long long int num_bytes; - const auto f_size = get_file_size(filename, num_bytes); - assert(f_size != -1); - buffer.resize(num_bytes); const auto fd = open(filename, O_RDONLY, 0600); assert(fd != -1); - long long int read_bytes = 0; + struct stat fs; + const auto result = fstat(fd, &fs); + assert(result != -1); + int64_t num_bytes = fs.st_size; + buffer.resize(num_bytes); + int64_t read_bytes = 0; auto r = 0; do { const auto buffer_ptr = buffer.data() + read_bytes; @@ -294,16 +301,15 @@ int regular_read(const char* filename, std::vector& buffer) } while (r > 0); if (read_bytes != num_bytes) { - std::cerr << "read error " - << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes - << std::endl; + std::cerr << "read error " << " read_bytes (read) = " << read_bytes + << " num_bytes (fstat) = " << num_bytes << std::endl; } assert(read_bytes == num_bytes); close(fd); return 0; } -static bool _validate_buffer(const char* filename, void* aio_buffer, const long long int num_bytes) +static bool _validate_buffer(const char* filename, void* aio_buffer, const int64_t num_bytes) { std::vector regular_buffer; const auto reg_ret = regular_read(filename, regular_buffer); @@ -311,7 +317,7 @@ static bool _validate_buffer(const char* filename, void* aio_buffer, const long std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" << std::endl; - if (static_cast(regular_buffer.size()) != num_bytes) { return false; } + if (static_cast(regular_buffer.size()) != num_bytes) { return false; } return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); } @@ -319,7 +325,7 @@ static bool _validate_buffer(const char* filename, void* aio_buffer, const long bool validate_aio_operation(const bool read_op, const char* filename, void* aio_buffer, - const long long int num_bytes) + const int64_t num_bytes) { const auto msg_suffix = std::string("deepspeed_aio_") + std::string(read_op ? "read()" : "write()") + diff --git a/csrc/aio/common/deepspeed_aio_common.h b/csrc/aio/common/deepspeed_aio_common.h index 2940de945ee8..aa4e49f4f4ed 100644 --- a/csrc/aio/common/deepspeed_aio_common.h +++ b/csrc/aio/common/deepspeed_aio_common.h @@ -35,4 +35,4 @@ int regular_read(const char* filename, std::vector& buffer); bool validate_aio_operation(const bool read_op, const char* filename, void* aio_buffer, - const long long int num_bytes); + const int64_t num_bytes); diff --git a/csrc/aio/common/deepspeed_aio_utils.cpp b/csrc/aio/common/deepspeed_aio_utils.cpp index 763b2c253a34..c8e577f299ae 100644 --- a/csrc/aio/common/deepspeed_aio_utils.cpp +++ b/csrc/aio/common/deepspeed_aio_utils.cpp @@ -18,10 +18,15 @@ const int c_block_size = 128 * 1024; const int c_io_queue_depth = 8; io_xfer_ctxt::io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, + const int64_t file_offset, + const int64_t buffer_offset, + const int64_t num_bytes, const void* buffer) - : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) + : _fd(fd), + _file_base_offset(file_offset), + _buffer_base_offset(buffer_offset), + _mem_buffer(buffer), + _num_bytes(num_bytes) { } @@ -36,14 +41,15 @@ io_prep_context::io_prep_context(const bool read_op, void io_prep_context::prep_iocbs(const int n_iocbs, const size_t num_bytes, const void* start_buffer, - const long long int start_offset) + const int64_t start_offset) { assert(static_cast(n_iocbs) <= _iocbs->size()); for (auto i = 0; i < n_iocbs; ++i) { const auto shift = i * _block_size; - const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; - const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; + const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_buffer_base_offset + shift; + const auto xfer_offset = _xfer_ctxt->_file_base_offset + start_offset + shift; auto byte_count = _block_size; + if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } if (_read_op) { @@ -64,25 +70,25 @@ io_prep_generator::io_prep_generator(const bool read_op, _next_iocb_index(0) { _num_io_blocks = - static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); + static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); _remaining_io_blocks = _num_io_blocks; } int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) { if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { - assert(static_cast(_remaining_bytes) == _remaining_io_blocks); + assert(static_cast(_remaining_bytes) == _remaining_io_blocks); return 0; } assert(static_cast(n_iocbs) <= iocbs->size()); - auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); + auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { - const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); - const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; - const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); - + const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + _xfer_ctxt->_buffer_base_offset + + (_next_iocb_index * _block_size); + const auto xfer_offset = _xfer_ctxt->_file_base_offset + (_next_iocb_index * _block_size); + const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); if (_read_op) { io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); } else { @@ -95,7 +101,7 @@ int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* return actual_n_iocbs; } -int get_file_size(const char* filename, long long int& size) +int64_t get_file_size(const char* filename, int64_t& size) { struct stat st; if (stat(filename, &st) == -1) { return -1; } @@ -103,7 +109,15 @@ int get_file_size(const char* filename, long long int& size) return 0; } -void* ds_page_aligned_alloc(const size_t size, const bool lock) +int64_t get_fd_file_size(const int fd, int64_t& size) +{ + struct stat st; + if (fstat(fd, &st) == -1) { return -1; } + size = st.st_size; + return 0; +} + +void* ds_page_aligned_alloc(const int64_t size, const bool lock) { void* ptr; int retval; diff --git a/csrc/aio/common/deepspeed_aio_utils.h b/csrc/aio/common/deepspeed_aio_utils.h index 9c58c2286610..8742bf5bff54 100644 --- a/csrc/aio/common/deepspeed_aio_utils.h +++ b/csrc/aio/common/deepspeed_aio_utils.h @@ -30,13 +30,15 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. struct io_xfer_ctxt { const int _fd; - const long long int _base_offset; + const int64_t _file_base_offset; + const int64_t _buffer_base_offset; const void* _mem_buffer; - const long long int _num_bytes; + const int64_t _num_bytes; io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, + const int64_t file_offset, + const int64_t buffer_offset, + const int64_t num_bytes, const void* buffer); }; @@ -54,7 +56,7 @@ struct io_prep_context { void prep_iocbs(const int n_iocbs, const size_t num_bytes, const void* start_buffer, - const long long int start_offset); + const int64_t start_offset); }; struct io_prep_generator { @@ -62,10 +64,10 @@ struct io_prep_generator { const std::unique_ptr& _xfer_ctxt; const size_t _block_size; - long long int _remaining_bytes; - long long int _num_io_blocks; - long long int _remaining_io_blocks; - long long int _next_iocb_index; + int64_t _remaining_bytes; + int64_t _num_io_blocks; + int64_t _remaining_io_blocks; + int64_t _next_iocb_index; io_prep_generator(const bool read_op, const std::unique_ptr& xfer_ctxt, @@ -74,6 +76,7 @@ struct io_prep_generator { int prep_iocbs(const int n_iocbs, std::vector* iocbs); }; -void* ds_page_aligned_alloc(const size_t size, const bool lock = false); +void* ds_page_aligned_alloc(const int64_t size, const bool lock = false); -int get_file_size(const char* filename, long long int& size); +int64_t get_file_size(const char* filename, int64_t& size); +int64_t get_fd_file_size(const int fd, int64_t& size); diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp new file mode 100644 index 000000000000..8387e667b332 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepspeed_aio_op_desc.h" + +using namespace std; + +io_op_desc_t::io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int intra_op_parallelism, + const bool validate, + const int64_t file_offset) + : _read_op(read_op), + _buffer(buffer), + _fd(fd), + _filename((filename == nullptr) ? std::string() : filename), + _file_offset(file_offset), + _intra_op_parallelism(intra_op_parallelism), + _num_bytes_per_thread(static_cast(buffer.nbytes()) / intra_op_parallelism), + _validate(validate) +{ + if (validate) { assert(nullptr != filename); } +} + +char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void io_op_desc_t::finish() {} + +void io_op_desc_t::validate() {} + +void io_op_desc_t::run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config) +{ +} diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h new file mode 100644 index 000000000000..cc7f15d74658 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#ifndef _IO_OP_DESC_T_ +#define _IO_OP_DESC_T_ +#include +#include +#include "deepspeed_py_aio.h" + +struct io_op_desc_t { + const bool _read_op; + torch::Tensor _buffer; + int _fd; + std::string _filename; + const int _intra_op_parallelism; + const int64_t _num_bytes_per_thread; + torch::Tensor _contiguous_buffer; + const bool _validate; + const int64_t _file_offset; + + io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int intra_op_parallelism, + const bool validate, + const int64_t file_offset); + + virtual void run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config); + + virtual char* data_ptr() const; + + virtual void validate(); + + virtual void finish(); +}; +#endif // _IO_OP_DESC_T_ diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 055db8798a6b..30c3b4914397 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -11,30 +11,6 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. using namespace std; -io_op_desc_t::io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate) - : _read_op(read_op), - _buffer(buffer), - _fd(fd), - _filename(filename), - _num_bytes(num_bytes), - _validate(validate) -{ - _cpu_buffer = _buffer.is_cuda() ? _buffer.to(torch::kCPU).pin_memory() : _buffer; - _contiguous_buffer = _cpu_buffer.contiguous(); -} - -char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } - -void io_op_desc_t::fini() -{ - if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } -} - deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config) : _tid(tid), _aio_config(aio_config), @@ -61,18 +37,7 @@ void deepspeed_aio_thread_t::run() } if (next_io_op) { - const auto base_offset = next_io_op->_num_bytes * _tid; - - std::unique_ptr xfer_ctxt(new io_xfer_ctxt( - next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr())); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } + next_io_op->run(_tid, _aio_ctxt, &_aio_config); { std::lock_guard lock(_complete_sync._mutex); diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h index 20799ecbb018..a192804db13d 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.h +++ b/csrc/aio/py_lib/deepspeed_aio_thread.h @@ -10,28 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include #include #include -#include "deepspeed_py_aio.h" - -struct io_op_desc_t { - const bool _read_op; - torch::Tensor _buffer; - int _fd; - const std::string _filename; - const long long int _num_bytes; - torch::Tensor _cpu_buffer; - torch::Tensor _contiguous_buffer; - const bool _validate; - - io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate); - - char* data_ptr() const; - void fini(); -}; +#include "deepspeed_cpu_op.h" struct thread_sync_t { std::mutex _mutex; diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp new file mode 100644 index 000000000000..4e2e13f5fd98 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepspeed_cpu_op.h" +#include "deepspeed_pin_tensor.h" + +using namespace std; + +cpu_op_desc_t::cpu_op_desc_t( + const std::unique_ptr& pinned_tensor_mgr, + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int intra_op_parallelism, + const bool validate, + const int64_t file_offset) + : io_op_desc_t(read_op, buffer, fd, filename, intra_op_parallelism, validate, file_offset), + _cpu_buffer(buffer), + _pinned_tensor_mgr(pinned_tensor_mgr), + _is_managed_bounce_buffer(false) +{ + // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory. + _use_bounce_buffer = + !(_buffer.is_cpu() && (_buffer.is_pinned() || _pinned_tensor_mgr->is_managed(_buffer))); + if (_use_bounce_buffer) { + _alloc_bounce_buffer(); + if (!_read_op) { _cpu_buffer.copy_(_buffer); } + } + _contiguous_buffer = _cpu_buffer.contiguous(); +} + +char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void cpu_op_desc_t::finish() +{ + if (_use_bounce_buffer) { + if (_read_op) { + if (_buffer.is_cuda()) { + _buffer.copy_(_cpu_buffer.to(torch::Device(torch::kCUDA, _buffer.get_device()), + /*non_blocking=*/true)); + } + if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); } + if (_buffer.is_cpu()) { _buffer.copy_(_cpu_buffer); } +#if defined(__ENABLE_CANN__) + // `DS_BUILD_OPS=1 install.sh` complains that ‘torch_npu’ has not + // been declared, so inline `torch_npu::utils::is_npu`. + if (_buffer.is_privateuseone()) { + auto device = at::Device("npu:0"); + _buffer.copy_(_cpu_buffer.to(device)); + } +#endif + } + + _free_bounce_buffer(); + } +} + +void cpu_op_desc_t::validate() +{ + const auto num_io_bytes = static_cast(_contiguous_buffer.nbytes()); + validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), num_io_bytes); +} + +void cpu_op_desc_t::run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config) +{ + assert(tid < _intra_op_parallelism); + const auto buffer_base_offset = _num_bytes_per_thread * tid; + const auto file_base_offset = _file_offset + (_num_bytes_per_thread * tid); + + std::unique_ptr xfer_ctxt(new io_xfer_ctxt( + _fd, file_base_offset, buffer_base_offset, _num_bytes_per_thread, data_ptr())); + + if (aio_config->_overlap_events) { + do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr); + } else { + do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr); + } +} + +void cpu_op_desc_t::_alloc_bounce_buffer() +{ + auto options = torch::TensorOptions() + .dtype(_buffer.dtype()) + .layout(_buffer.layout()) + .device(torch::kCPU) + .requires_grad(false); + +#if defined(__CUDA_ARCH__) + _cpu_buffer = torch::empty(_buffer.numel(), options).pin_memory(); +#else + _is_managed_bounce_buffer = true; + _cpu_buffer = _pinned_tensor_mgr->alloc(_buffer.numel(), options); +#endif +} + +void cpu_op_desc_t::_free_bounce_buffer() +{ + if (_is_managed_bounce_buffer) { _pinned_tensor_mgr->free(_cpu_buffer); } +} diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h new file mode 100644 index 000000000000..7cc648bace8e --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_cpu_op.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include "deepspeed_aio_op_desc.h" + +struct cpu_op_desc_t : io_op_desc_t { + torch::Tensor _cpu_buffer; + bool _use_bounce_buffer; + bool _is_managed_bounce_buffer; + const std::unique_ptr& _pinned_tensor_mgr; + + cpu_op_desc_t(const std::unique_ptr& pinned_tensor_mgr, + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int intra_op_parallelism, + const bool validate, + const int64_t file_offset); + + void run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config); + + char* data_ptr() const; + + void validate(); + + void finish(); + + void _alloc_bounce_buffer(); + void _free_bounce_buffer(); +}; diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp index 752823dc7dd2..a97a4ac18ba8 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.cpp +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.cpp @@ -15,21 +15,28 @@ deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t() { for (auto iter = _locked_tensors.begin(); iter != _locked_tensors.end(); ++iter) { munlock(iter->first, iter->second); + std::free((void*)iter->first); } _locked_tensors.clear(); } -torch::Tensor deepspeed_pin_tensor_t::alloc(const size_t num_elem, const at::ScalarType& elem_type) +torch::Tensor deepspeed_pin_tensor_t::alloc(const int64_t num_elem, + const torch::TensorOptions& options) { - const auto num_bytes = num_elem * elementSize(elem_type); + const auto scalar_dtype = torch::typeMetaToScalarType(options.dtype()); + const auto num_bytes = num_elem * torch::elementSize(scalar_dtype); auto pinned_buffer = ds_page_aligned_alloc(num_bytes, true); assert(nullptr != pinned_buffer); _locked_tensors[pinned_buffer] = num_bytes; - auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU); + return at::from_blob(pinned_buffer, static_cast(num_elem), options); +} - return at::from_blob(pinned_buffer, static_cast(num_bytes), options); +torch::Tensor deepspeed_pin_tensor_t::alloc(const int64_t num_elem, const at::ScalarType& elem_type) +{ + auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU).requires_grad(false); + return alloc(num_elem, options); } bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) @@ -37,9 +44,18 @@ bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor) auto addr = locked_tensor.data_ptr(); if (_locked_tensors.find(addr) != _locked_tensors.end()) { munlock(addr, _locked_tensors[addr]); + std::free(addr); _locked_tensors.erase(addr); return true; } return false; } + +bool deepspeed_pin_tensor_t::is_managed(const torch::Tensor& buffer) +{ + if (!buffer.is_cpu()) { return false; } + auto addr = buffer.data_ptr(); + if (_locked_tensors.find(addr) != _locked_tensors.end()) { return true; } + return false; +}; diff --git a/csrc/aio/py_lib/deepspeed_pin_tensor.h b/csrc/aio/py_lib/deepspeed_pin_tensor.h index 4350a4ac7df6..4b8ad7e76085 100644 --- a/csrc/aio/py_lib/deepspeed_pin_tensor.h +++ b/csrc/aio/py_lib/deepspeed_pin_tensor.h @@ -15,13 +15,16 @@ Functionality for managing CPU tensors occupying page-locked memory. #include "deepspeed_py_aio.h" struct deepspeed_pin_tensor_t { - std::map _locked_tensors; + std::map _locked_tensors; deepspeed_pin_tensor_t() = default; ~deepspeed_pin_tensor_t(); - torch::Tensor alloc(const size_t num_elem, const at::ScalarType& elem_type); + torch::Tensor alloc(const int64_t num_elem, const at::ScalarType& elem_type); + torch::Tensor alloc(const int64_t num_elem, const torch::TensorOptions& options); bool free(torch::Tensor& locked_tensor); + + bool is_managed(const torch::Tensor& buffer); }; diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index 387b713f2bfc..1ff0397043fa 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -4,9 +4,6 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ @@ -54,8 +51,10 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer, if (fd == -1) { return -1; } auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + const auto num_write_bytes = static_cast(buffer.nbytes()); + + std::unique_ptr xfer_ctxt( + new io_xfer_ctxt(fd, 0, 0, num_write_bytes, write_buffer)); std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); if (config._overlap_events) { @@ -72,9 +71,8 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; return 0; } @@ -87,7 +85,7 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, const bool validate) { const auto start_time = std::chrono::high_resolution_clock::now(); - long long num_file_bytes; + int64_t num_file_bytes; if (-1 == get_file_size(filename, num_file_bytes)) { const auto error_code = errno; report_file_error(filename, " fstat for read", error_code); @@ -99,9 +97,10 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, if (fd == -1) { return -1; } auto read_buffer = (char*)buffer.data_ptr(); - assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert(static_cast(buffer.nbytes()) == num_file_bytes); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + std::unique_ptr xfer_ctxt( + new io_xfer_ctxt(fd, 0, 0, num_file_bytes, read_buffer)); std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); if (config._overlap_events) { @@ -118,8 +117,7 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; return 0; } diff --git a/csrc/aio/py_lib/deepspeed_py_aio.h b/csrc/aio/py_lib/deepspeed_py_aio.h index 11d5225de9f1..ba794db5440d 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.h +++ b/csrc/aio/py_lib/deepspeed_py_aio.h @@ -4,10 +4,7 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +Functionality for swapping tensors to/from (NVMe) storage devices. */ #include diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index c21e92de9449..2b1093e99286 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -4,295 +4,25 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ #include "deepspeed_py_aio_handle.h" +#include using namespace std; -static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } - deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads) - : _aio_ctxt(new aio_context(block_size, queue_depth)), - _single_submit(single_submit), - _overlap_events(overlap_events), - _num_threads(num_threads), - _aio_config(block_size, queue_depth, single_submit, overlap_events, false), - _num_pending_ops(0), - _pinned_tensor_mgr(new deepspeed_pin_tensor_t()) -{ - for (auto i = 0; i < num_threads; ++i) { - _thread_contexts.push_back(std::make_shared(i, _aio_config)); - } - - for (auto& ctxt : _thread_contexts) { - _threads.push_back(std::thread(_start_aio_thread, ctxt)); - } -} - -deepspeed_aio_handle_t::~deepspeed_aio_handle_t() -{ - _stop_threads(); - for (auto& thr : _threads) { thr.join(); } -} - -const int deepspeed_aio_handle_t::get_block_size() const -{ - return _aio_ctxt ? _aio_ctxt->_block_size : -1; -} - -const int deepspeed_aio_handle_t::get_queue_depth() const -{ - return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; -} - -const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; } - -const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; } - -const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; } - -int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - - assert(_aio_ctxt); - - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto read_buffer = (char*)buffer.data_ptr(); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - - close(fd); - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, - const char* filename, - const bool validate) + const int intra_op_parallelism) + : deepspeed_io_handle_t(block_size, + queue_depth, + single_submit, + overlap_events, + intra_op_parallelism) { - assert(_aio_ctxt); - - const auto start_time = std::chrono::high_resolution_clock::now(); - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; } -void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) -{ - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_work_queue.push(scheduled_op); - } - ctxt->_work_sync._cond_var.notify_one(); - } - _num_pending_ops++; -} - -std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work() -{ - std::shared_ptr completed_op = nullptr; - for (auto& ctxt : _thread_contexts) { - std::unique_lock lock(ctxt->_complete_sync._mutex); - ctxt->_complete_sync._cond_var.wait(lock, - [ctxt] { return !ctxt->_complete_queue.empty(); }); - completed_op = ctxt->_complete_queue.front(); - ctxt->_complete_queue.pop(); - } - return completed_op; -} - -void deepspeed_aio_handle_t::_stop_threads() -{ - assert(0 == _num_pending_ops); - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_time_to_exit = true; - } - ctxt->_work_sync._cond_var.notify_one(); - } -} - -int deepspeed_aio_handle_t::wait() -{ - assert(_num_pending_ops > 0); - auto num_completed_ops = 0; - - while (_num_pending_ops > 0) { - auto completed_op = _wait_for_aio_work(); - - completed_op->fini(); - - close(completed_op->_fd); - - if (completed_op->_validate) { - validate_aio_operation(completed_op->_read_op, - completed_op->_filename.c_str(), - completed_op->data_ptr(), - _num_threads * completed_op->_num_bytes); - } - --_num_pending_ops; - ++num_completed_ops; - } - - return num_completed_ops; -} - -bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op, - const long long int num_bytes) -{ - const auto op_string = read_op ? "Read" : "Write"; - if (num_bytes % get_thread_count()) { - std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes - << " not divisible by thread count = " << get_thread_count() << std::endl; - return false; - } - - return true; -} - -int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - const auto buffer_bytes = static_cast(buffer.nbytes()); - if (buffer_bytes != num_file_bytes) { - std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes - << " != " << num_file_bytes << std::endl; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - assert((num_file_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - true, buffer, fd, filename, (num_file_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - const auto num_write_bytes = static_cast(buffer.nbytes()); - assert((num_write_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - false, buffer, fd, filename, (num_write_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, true); -} - -int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, true); -} - -at::Tensor deepspeed_aio_handle_t::new_cpu_locked_tensor(const size_t num_elem, - const torch::Tensor& example_tensor) -{ - return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); -} - -bool deepspeed_aio_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor) -{ - return _pinned_tensor_mgr->free(locked_tensor); -} +deepspeed_aio_handle_t::~deepspeed_aio_handle_t() {} diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index 3a254c3814a2..c9fcb6d2b462 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -6,72 +6,16 @@ /* Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ - #include #include -#include "deepspeed_aio_thread.h" -#include "deepspeed_pin_tensor.h" - -struct deepspeed_aio_handle_t { - std::unique_ptr _aio_ctxt; - const bool _single_submit; - const bool _overlap_events; - const int _num_threads; - deepspeed_aio_config_t _aio_config; - - std::vector> _thread_contexts; - std::vector _threads; - int _num_pending_ops; - std::unique_ptr _pinned_tensor_mgr; +#include "deepspeed_py_io_handle.h" +struct deepspeed_aio_handle_t : deepspeed_io_handle_t { deepspeed_aio_handle_t(const int block_size, const int queue_depth, const bool single_submit, const bool overlap_events, - const int num_threads); + const int intra_op_parallelism); ~deepspeed_aio_handle_t(); - - const int get_block_size() const; - const int get_queue_depth() const; - const bool get_single_submit() const; - const bool get_overlap_events() const; - const int get_thread_count() const; - - int read(torch::Tensor& buffer, const char* filename, const bool validate); - - int write(const torch::Tensor& buffer, const char* filename, const bool validate); - - int pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int sync_pread(torch::Tensor& buffer, const char* filename); - - int sync_pwrite(const torch::Tensor& buffer, const char* filename); - - int async_pread(torch::Tensor& buffer, const char* filename); - - int async_pwrite(const torch::Tensor& buffer, const char* filename); - - // TODO: Make API's args to be shape and dtype. - torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor); - - bool free_cpu_locked_tensor(torch::Tensor&); - - int wait(); - - void _stop_threads(); - - void _schedule_aio_work(std::shared_ptr scheduled_op); - - std::shared_ptr _wait_for_aio_work(); - - bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); }; diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index 8a59107dd347..f5480e9d9d83 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -4,13 +4,13 @@ // DeepSpeed Team /* -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +Functionality for swapping tensors to/from (NVMe) storage devices. */ #include "deepspeed_py_copy.h" #include -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) or defined(__AVX256__) union AVX_Data { diff --git a/csrc/aio/py_lib/deepspeed_py_copy.h b/csrc/aio/py_lib/deepspeed_py_copy.h index 19ba28317d00..f443571a3e7b 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.h +++ b/csrc/aio/py_lib/deepspeed_py_copy.h @@ -4,9 +4,6 @@ // DeepSpeed Team /* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp new file mode 100644 index 000000000000..48668a842949 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp @@ -0,0 +1,366 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_io_handle.h" +#include + +#define O_DIRECT_ALIGNMENT 512 + +using namespace std; + +static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } + +static bool is_valid_bytes_to_read(const char* filename, + const int64_t file_offset, + const int64_t num_bytes_to_read) +{ + int64_t num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return false; + } + if ((file_offset + num_bytes_to_read) > num_file_bytes) { + std::cout << filename << ": file_offset + buffer nbytes > file bytes " + << (file_offset + num_bytes_to_read) << " > " << num_file_bytes << std::endl; + } + assert((file_offset + num_bytes_to_read) <= num_file_bytes); + return true; +} + +deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int intra_op_parallelism) + : _aio_ctxt(new aio_context(block_size, queue_depth)), + _single_submit(single_submit), + _overlap_events(overlap_events), + _intra_op_parallelism(intra_op_parallelism), + _aio_config(block_size, queue_depth, single_submit, overlap_events, false), + _num_pending_ops(0), + _pinned_tensor_mgr(new deepspeed_pin_tensor_t()) +{ + for (auto i = 0; i < intra_op_parallelism; ++i) { + _thread_contexts.push_back(std::make_shared(i, _aio_config)); + } + + for (auto& ctxt : _thread_contexts) { + _threads.push_back(std::thread(_start_aio_thread, ctxt)); + } +} + +deepspeed_io_handle_t::~deepspeed_io_handle_t() +{ + _stop_threads(); + for (auto& thr : _threads) { thr.join(); } +} + +const int deepspeed_io_handle_t::get_block_size() const +{ + return _aio_ctxt ? _aio_ctxt->_block_size : -1; +} + +const int deepspeed_io_handle_t::get_queue_depth() const +{ + return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; +} + +const bool deepspeed_io_handle_t::get_single_submit() const { return _single_submit; } + +const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; } + +const int deepspeed_io_handle_t::get_intra_op_parallelism() const { return _intra_op_parallelism; } + +const int deepspeed_io_handle_t::get_alignment() const +{ + return _intra_op_parallelism * O_DIRECT_ALIGNMENT; +} + +int deepspeed_io_handle_t::read(torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + + assert(_aio_ctxt); + + int64_t num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + std::unique_ptr xfer_ctxt( + new io_xfer_ctxt(fd, file_offset, 0, num_file_bytes, read_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + close(fd); + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +int deepspeed_io_handle_t::write(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset) +{ + assert(_aio_ctxt); + + const auto start_time = std::chrono::high_resolution_clock::now(); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt( + new io_xfer_ctxt(fd, file_offset, 0, num_write_bytes, write_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; + return 0; +} + +void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) +{ + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_work_queue.push(scheduled_op); + } + ctxt->_work_sync._cond_var.notify_one(); + } + _num_pending_ops++; +} + +std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work() +{ + std::shared_ptr completed_op = nullptr; + for (auto& ctxt : _thread_contexts) { + std::unique_lock lock(ctxt->_complete_sync._mutex); + ctxt->_complete_sync._cond_var.wait(lock, + [ctxt] { return !ctxt->_complete_queue.empty(); }); + completed_op = ctxt->_complete_queue.front(); + ctxt->_complete_queue.pop(); + } + return completed_op; +} + +void deepspeed_io_handle_t::_stop_threads() +{ + assert(0 == _num_pending_ops); + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_time_to_exit = true; + } + ctxt->_work_sync._cond_var.notify_one(); + } +} + +int deepspeed_io_handle_t::wait() +{ + assert(_num_pending_ops > 0); + auto num_completed_ops = 0; + + while (_num_pending_ops > 0) { + auto completed_op = _wait_for_aio_work(); + + if (completed_op->_validate) { completed_op->validate(); } + + completed_op->finish(); + + if (!completed_op->_filename.empty()) { (completed_op->_fd); } + + --_num_pending_ops; + ++num_completed_ops; + } + + return num_completed_ops; +} + +bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, const int64_t num_bytes) +{ + const auto op_string = read_op ? "Read" : "Write"; + if (num_bytes % get_intra_op_parallelism()) { + std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes + << " not divisible by intra op parallelism = " << get_intra_op_parallelism() + << std::endl; + return false; + } + + return true; +} + +std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc( + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const int64_t file_offset) +{ + return std::make_shared(_pinned_tensor_mgr, + read_op, + buffer, + fd, + filename, + _intra_op_parallelism, + validate, + file_offset); +} + +int deepspeed_io_handle_t::_pread(const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset) +{ + auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, validate, file_offset); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_io_handle_t::pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset) +{ + const auto buffer_bytes = static_cast(buffer.nbytes()); + + if (!is_valid_bytes_to_read(filename, file_offset, buffer_bytes)) { return -1; } + + if (!_is_valid_parallel_aio_op(true, buffer_bytes)) { return -1; } + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + return _pread(buffer, fd, filename, validate, async, file_offset); +} + +int deepspeed_io_handle_t::_pwrite(const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset) +{ + auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, validate, file_offset); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset) +{ + const auto num_write_bytes = static_cast(buffer.nbytes()); + + if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + return _pwrite(buffer, fd, filename, validate, async, file_offset); +} + +int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) +{ + return pread(buffer, filename, false, false, file_offset); +} + +int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) +{ + return pwrite(buffer, filename, false, false, file_offset); +} + +int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) +{ + return pread(buffer, filename, false, true, file_offset); +} + +int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, + const char* filename, + const int64_t file_offset) +{ + return pwrite(buffer, filename, false, true, file_offset); +} + +int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, + const int fd, + const int64_t file_offset = 0) +{ + const auto num_write_bytes = static_cast(buffer.nbytes()); + if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } + + return _pwrite(buffer, fd, nullptr, false, true, file_offset); +} + +at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const int64_t num_elem, + const torch::Tensor& example_tensor) +{ + return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type()); +} + +bool deepspeed_io_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor) +{ + return _pinned_tensor_mgr->free(locked_tensor); +} diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h new file mode 100644 index 000000000000..8cb43c5b38e5 --- /dev/null +++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include "deepspeed_aio_thread.h" +#include "deepspeed_pin_tensor.h" + +struct deepspeed_io_handle_t { + std::unique_ptr _aio_ctxt; + const bool _single_submit; + const bool _overlap_events; + const int _intra_op_parallelism; + deepspeed_aio_config_t _aio_config; + + std::vector> _thread_contexts; + std::vector _threads; + int _num_pending_ops; + std::unique_ptr _pinned_tensor_mgr; + + deepspeed_io_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int intra_op_parallelism); + + virtual ~deepspeed_io_handle_t() = 0; + + const int get_block_size() const; + const int get_queue_depth() const; + const bool get_single_submit() const; + const bool get_overlap_events() const; + const int get_intra_op_parallelism() const; + const int get_alignment() const; + + int read(torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset); + + int write(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const int64_t file_offset); + + int pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset); + + int pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset); + + int sync_pread(torch::Tensor& buffer, const char* filename, const int64_t file_offset); + + int sync_pwrite(const torch::Tensor& buffer, const char* filename, const int64_t file_offset); + + int async_pread(torch::Tensor& buffer, const char* filename, const int64_t file_offset); + + int async_pwrite(const torch::Tensor& buffer, const char* filename, const int64_t file_offset); + int async_pwrite(const torch::Tensor& buffer, const int fd, const int64_t file_offset); + + // TODO: Make API's args to be shape and dtype. + torch::Tensor new_cpu_locked_tensor(const int64_t num_elem, + const torch::Tensor& example_tensor); + + bool free_cpu_locked_tensor(torch::Tensor&); + + int wait(); + + void _stop_threads(); + + void _schedule_aio_work(std::shared_ptr scheduled_op); + + std::shared_ptr _wait_for_aio_work(); + + bool _is_valid_parallel_aio_op(const bool read_op, const int64_t num_bytes); + + int _pread(const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset); + + int _pwrite(const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const bool async, + const int64_t file_offset); + + virtual std::shared_ptr _create_io_op_desc(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const int64_t file_offset); +}; diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp old mode 100755 new mode 100644 index 9033549bc0d2..cf9838cf8191 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -6,10 +6,10 @@ /* Functionality for swapping optimizer tensors to/from (NVMe) storage devices. */ - #include #include "deepspeed_py_aio_handle.h" #include "deepspeed_py_copy.h" +using namespace pybind11::literals; PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -20,27 +20,110 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy"); py::class_(m, "aio_handle") - .def(py::init()) + .def(py::init(), + "AIO handle constructor", + "block_size"_a = 1024 * 1024, + "queue_depth"_a = 128, + "single_submit"_a = false, + "overlap_events"_a = false, + "intra_op_parallelism"_a = 1) .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) .def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit) .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) - .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) + .def("get_intra_op_parallelism", &deepspeed_aio_handle_t::get_intra_op_parallelism) + .def("get_alignment", &deepspeed_aio_handle_t::get_alignment) + + .def("read", + &deepspeed_aio_handle_t::read, + "Synchronous and non-parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "file_offset"_a = 0) + + .def("write", + &deepspeed_aio_handle_t::write, + "Synchronous and non-parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "file_offset"_a = 0) + + .def("pread", + &deepspeed_aio_handle_t::pread, + "Parallel file read with option of asynchronous completion. If synchronous, returns " + "count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a, + "file_offset"_a = 0) + + .def("pwrite", + &deepspeed_aio_handle_t::pwrite, + "Parallel file write with option of asynchronous completion. If synchronous, returns " + "count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a, + "file_offset"_a = 0) + + .def("sync_pread", + &deepspeed_aio_handle_t::sync_pread, + "Synchronous parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) + + .def("sync_pwrite", + &deepspeed_aio_handle_t::sync_pwrite, + "Synchronous parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) + + .def("async_pread", + &deepspeed_aio_handle_t::async_pread, + "Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and " + "subsequent wait() returns count of completed ops.", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) - .def("read", &deepspeed_aio_handle_t::read) - .def("write", &deepspeed_aio_handle_t::write) + .def( + "async_pwrite", + py::overload_cast( + &deepspeed_aio_handle_t::async_pwrite), + "Asynchronous parallel file write. Returns 0 on success, and subsequent wait() returns " + "count of completed ops.", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) - .def("pread", &deepspeed_aio_handle_t::pread) - .def("pwrite", &deepspeed_aio_handle_t::pwrite) + .def("async_pwrite", + py::overload_cast( + &deepspeed_aio_handle_t::async_pwrite), + "Asynchronous parallel file write using opened python file object.", + "buffer"_a, + "fd"_a, + "file_offset"_a = 0) - .def("sync_pread", &deepspeed_aio_handle_t::sync_pread) - .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite) - .def("async_pread", &deepspeed_aio_handle_t::async_pread) - .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite) + .def("new_cpu_locked_tensor", + &deepspeed_aio_handle_t::new_cpu_locked_tensor, + "Allocate pinned CPU tensor.", + "num_elem"_a, + "example_tenosr"_a) - .def("new_cpu_locked_tensor", &deepspeed_aio_handle_t::new_cpu_locked_tensor) - .def("free_cpu_locked_tensor", &deepspeed_aio_handle_t::free_cpu_locked_tensor) + .def("free_cpu_locked_tensor", + &deepspeed_aio_handle_t::free_cpu_locked_tensor, + "Free pinned CPU tensor.", + "tensor"_a) - .def("wait", &deepspeed_aio_handle_t::wait); + .def("wait", + &deepspeed_aio_handle_t::wait, + "Wait for (ongoing) asynchronous operations to complete", + py::call_guard()); } diff --git a/csrc/aio/py_test/aio_bench_generate_param.py b/csrc/aio/py_test/aio_bench_generate_param.py index 09d0e03c7ef6..7a0ab59ed73d 100644 --- a/csrc/aio/py_test/aio_bench_generate_param.py +++ b/csrc/aio/py_test/aio_bench_generate_param.py @@ -41,9 +41,9 @@ def convert_to_param(key): return { "single_submit": "true" if key[0] == "single" else "false", "overlap_events": "true" if key[1] == "overlap" else "false", - "thread_count": int(key[3]), - "queue_depth": int(key[4]), - "block_size": int(key[5]) + "thread_count": int(key[5]), + "queue_depth": int(key[3]), + "block_size": int(key[4]) } diff --git a/csrc/aio/py_test/aio_bench_perf_sweep.py b/csrc/aio/py_test/aio_bench_perf_sweep.py index 7d55f7ded65c..b63fb8dd1d21 100644 --- a/csrc/aio/py_test/aio_bench_perf_sweep.py +++ b/csrc/aio/py_test/aio_bench_perf_sweep.py @@ -10,75 +10,47 @@ import argparse import json import itertools -import subprocess import shutil -from test_ds_aio_utils import refine_integer_value +from ds_aio_job import Job, run_job from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \ - READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR + READ_LOG_DIR, WRITE_LOG_DIR from deepspeed.ops.op_builder import AsyncIOBuilder -OTHER_OPTIONS = '--handle' +OTHER_OPTIONS = '--engine aio_handle' PERF_SCRIPT = 'test_ds_aio.py' DEFAULT_SWEEP_CONFIG = { - "block_size": ["128K", "256K"], - "queue_depth": [4, 16, 32], - "overlap_events": [True, False], - "io_parallel": [2, 8], - "single_submit": [False] + "block_size": ["128K", "1M"], + "queue_depth": [32, 64, 128], + "sequential_requests": [True, False], + "single_submit": [False], + "io_parallel": [1, 2, 8], } -class Job(object): - - def __init__(self, cmd_line, output_file=None, work_dir=None): - self.cmd_line = cmd_line - self.output_file = output_file - self.work_dir = work_dir - self.output_fd = None - - def cmd(self): - return self.cmd_line - - def get_stdout(self): - return self.output_fd - - def get_stderr(self): - return self.output_fd - - def get_cwd(self): - return self.work_dir - - def open_output_file(self): - if self.output_file is not None: - self.output_fd = open(self.output_file, 'w') - - def close_output_file(self): - if self.output_fd is not None: - self.output_fd.close() - self.output_fd = None - - class SweepConfig(object): def __init__(self, args): - self.nvme_dir = args.nvme_dir - self.io_size = args.io_size + self.folder_to_device_mapping = get_ftd_map(args.nvme_dir) self.search_space = get_sweep_config_dict(args.sweep_config) + self.search_space.update(self.folder_to_device_mapping) self.read = not args.no_read self.write = not args.no_write self.flush_cache = not args.no_sudo self.log_dir = args.log_dir - self.loops = args.loops - self.other_options = f'{OTHER_OPTIONS} --loops {args.loops}' + self.other_options = f'{OTHER_OPTIONS} --loops {args.loops} --io_size {args.io_size}' + if args.gpu: + self.other_options += ' --gpu' + if args.gds: + self.other_options += ' --use_gds' def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--nvme_dir', + nargs='+', required=True, - type=str, help='Directory in which to perform I/O tests. A writeable directory on a NVMe device.') parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.') @@ -92,6 +64,10 @@ def parse_arguments(): default="400M", help='Number of I/O bytes to read/write for performance measurements.') + parser.add_argument('--gpu', action='store_true', help='Test tensor transfers between GPU device and NVME device.') + + parser.add_argument('--gds', action='store_true', help='Run the sweep over NVIDIA GPUDirectStorage operator') + parser.add_argument( '--no_sudo', action='store_true', @@ -118,6 +94,12 @@ def dump_cmd_lines(cmd_lines): print(f'{i}: {cmd}') +def get_ftd_map(nvme_dir_list): + ftd_list = [f'{dir}:{dev}' for dev, dir in enumerate(nvme_dir_list)] + ftd_arg = [' '.join(ftd for ftd in ftd_list)] + return {'folder_to_device_mapping': ftd_arg} + + def get_sweep_config_dict(sweep_config_json): if sweep_config_json is None: return DEFAULT_SWEEP_CONFIG @@ -127,6 +109,20 @@ def get_sweep_config_dict(sweep_config_json): return sweep_config +QUEUE_DEPTH = "--queue_depth" +BLOCK_SIZE = "--block_size" +SINGLE_SUBMIT = "--single_submit" +SEQUENTIAL_REQUESTS = "--sequential_requests" +THREAD_COUNT = "--threads" +IO_PARALLEL = "--io_parallel" + +DEPRECATED_KEYS = {THREAD_COUNT: "multi_process"} + + +def _handle_key_deprecation(key): + return DEPRECATED_KEYS.get(f'--{key}', key) + + def get_sweep_cmd_lines(sweep_config_dict): def flatten_options(key, value_list): @@ -141,23 +137,13 @@ def flatten_options(key, value_list): return flat_list - flat_list = [flatten_options(key, value) for key, value in sweep_config_dict.items()] + flat_list = [flatten_options(_handle_key_deprecation(key), value) for key, value in sweep_config_dict.items()] cmd_list = list(itertools.product(*flat_list)) cmd_list = [list(cmd) for cmd in cmd_list] #dump_cmd_lines(cmd_list) return cmd_list -def run_job(job): - args = ' '.join(job.cmd()) - print(f'args = {args}') - job.open_output_file() - proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd()) - job.close_output_file() - assert proc.returncode == 0, \ - f"This command failed: {job.cmd()}" - - def launch_sweep(sweep_jobs, sync_job, flush_cache_job): for perf_job in sweep_jobs: if flush_cache_job is not None: @@ -176,7 +162,12 @@ def create_cmd_tags(cmd_line): if len(fields) == 1: tags[fields[0]] = None elif len(fields) == 2: - tags[fields[0]] = fields[1] + if fields[0] == '--folder_to_device_mapping': + tags[fields[0]] = len(fields[1:]) + else: + tags[fields[0]] = fields[1] + elif len(fields) > 2: + tags[fields[0]] = len(fields[1:]) return tags @@ -184,16 +175,16 @@ def get_log_file(io_op_desc, cmd_line): QUEUE_DEPTH = "--queue_depth" BLOCK_SIZE = "--block_size" SINGLE_SUBMIT = "--single_submit" - OVERLAP_EVENTS = "--overlap_events" - THREAD_COUNT = "--threads" + SEQUENTIAL_REQUESTS = "--sequential_requests" + FTD_MAP = "--folder_to_device_mapping" IO_PARALLEL = "--io_parallel" tag_map = { QUEUE_DEPTH: "d", BLOCK_SIZE: "bs", SINGLE_SUBMIT: "single", - OVERLAP_EVENTS: "overlap", - THREAD_COUNT: "t", + SEQUENTIAL_REQUESTS: "sequential", + FTD_MAP: "ftd", IO_PARALLEL: "p" } @@ -201,14 +192,14 @@ def get_log_file(io_op_desc, cmd_line): QUEUE_DEPTH: 1, BLOCK_SIZE: "1M", SINGLE_SUBMIT: "block", - OVERLAP_EVENTS: "sequential", - THREAD_COUNT: 1, + SEQUENTIAL_REQUESTS: "overlap", + FTD_MAP: 1, IO_PARALLEL: 1 } def get_default_value(tag): value = tag_default[tag] - if tag in [SINGLE_SUBMIT, OVERLAP_EVENTS]: + if tag in [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS]: return value return f'{tag_map[tag]}{value}' @@ -218,7 +209,7 @@ def get_config_value(tag, value): return tag_key return f'{tag_key}{value}' - tag_list = [SINGLE_SUBMIT, OVERLAP_EVENTS, THREAD_COUNT, IO_PARALLEL, QUEUE_DEPTH, BLOCK_SIZE] + tag_list = [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS, FTD_MAP, QUEUE_DEPTH, BLOCK_SIZE, IO_PARALLEL] log_tags = [io_op_desc] cmd_tags = create_cmd_tags(cmd_line) for tag in tag_list: @@ -252,40 +243,14 @@ def async_io_setup(): return AsyncIOBuilder().is_compatible() -def get_block_size_and_count(io_bytes): - block_size = 1 - block_count = io_bytes - bytes_in_KB = 1024 - - while block_count % bytes_in_KB == 0: - block_size *= bytes_in_KB - block_count /= bytes_in_KB - - return int(block_size), int(block_count) - - -def create_read_file(sweep_config): - read_folder = os.path.join(sweep_config.nvme_dir, f'{READ_IO_DIR}') - os.makedirs(read_folder, exist_ok=True) - read_file_name = os.path.join(read_folder, f'random_{sweep_config.io_size}B.pt') - block_size, block_count = get_block_size_and_count(refine_integer_value(sweep_config.io_size)) - dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={read_file_name} bs={block_size} count={block_count}']) - print(f'[Start] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....') - run_job(dd_job) - print(f'[Done] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....') - return read_folder, read_file_name - - def remove_folder(folder): assert os.path.isdir(folder), f"Error: cannot remove {folder} - folder not found" shutil.rmtree(folder) def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): - read_folder, read_file_name = create_read_file(sweep_config) - read_option = f'--read_file {read_file_name}' - read_cmd_lines = [[f'{read_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines] - #dump_cmd_lines(read_cmd_lines) + read_cmd_lines = [[f'--read {sweep_config.other_options}'] + cmd for cmd in cmd_lines] + #dump_cmd_lines(cmd_lines) log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}') os.makedirs(log_folder, exist_ok=True) @@ -294,15 +259,9 @@ def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job) - remove_folder(read_folder) - def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): - write_folder = os.path.join(sweep_config.nvme_dir, f'{WRITE_IO_DIR}') - os.makedirs(write_folder, exist_ok=True) - write_file_name = os.path.join(write_folder, f'random_{sweep_config.io_size}B.pt') - write_option = f'--write_size {sweep_config.io_size} --write_file {write_file_name}' - write_cmd_lines = [[f'{write_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines] + write_cmd_lines = [[f'{sweep_config.other_options}'] + cmd for cmd in cmd_lines] #dump_cmd_lines(write_cmd_lines) log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}') @@ -312,8 +271,6 @@ def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job) - remove_folder(write_folder) - def main(): print("Running performance sweep of deepspeed nvme library") diff --git a/csrc/aio/py_test/dgx2_v100_optimal_read.sh b/csrc/aio/py_test/dgx2_v100_optimal_read.sh new file mode 100755 index 000000000000..fdee14bd6eb9 --- /dev/null +++ b/csrc/aio/py_test/dgx2_v100_optimal_read.sh @@ -0,0 +1,21 @@ +python test_ds_aio.py \ + --read \ + --handle --io_size 400M \ + --loops 3 \ + --folder_to_device_mapping \ + /mnt/nvme23/aio:0 \ + /mnt/nvme23/aio:1 \ + /mnt/nvme23/aio:2 \ + /mnt/nvme23/aio:3 \ + /mnt/nvme45/aio:4 \ + /mnt/nvme45/aio:5 \ + /mnt/nvme45/aio:6 \ + /mnt/nvme45/aio:7 \ + /mnt/nvme67/aio:8 \ + /mnt/nvme67/aio:9 \ + /mnt/nvme67/aio:10 \ + /mnt/nvme67/aio:11 \ + /mnt/nvme89/aio:12 \ + /mnt/nvme89/aio:13 \ + /mnt/nvme89/aio:14 \ + /mnt/nvme89/aio:15 \ diff --git a/csrc/aio/py_test/dgx2_v100_optimal_write.sh b/csrc/aio/py_test/dgx2_v100_optimal_write.sh new file mode 100755 index 000000000000..fdd9c63e9387 --- /dev/null +++ b/csrc/aio/py_test/dgx2_v100_optimal_write.sh @@ -0,0 +1,20 @@ +python test_ds_aio.py \ + --handle --io_size 400M \ + --loops 3 \ + --folder_to_device_mapping \ + /mnt/nvme23/aio:0 \ + /mnt/nvme23/aio:1 \ + /mnt/nvme23/aio:2 \ + /mnt/nvme23/aio:3 \ + /mnt/nvme45/aio:4 \ + /mnt/nvme45/aio:5 \ + /mnt/nvme45/aio:6 \ + /mnt/nvme45/aio:7 \ + /mnt/nvme67/aio:8 \ + /mnt/nvme67/aio:9 \ + /mnt/nvme67/aio:10 \ + /mnt/nvme67/aio:11 \ + /mnt/nvme89/aio:12 \ + /mnt/nvme89/aio:13 \ + /mnt/nvme89/aio:14 \ + /mnt/nvme89/aio:15 \ diff --git a/csrc/aio/py_test/dgx2_v100_suboptimal_read.sh b/csrc/aio/py_test/dgx2_v100_suboptimal_read.sh new file mode 100755 index 000000000000..31b815b82331 --- /dev/null +++ b/csrc/aio/py_test/dgx2_v100_suboptimal_read.sh @@ -0,0 +1,6 @@ +python test_ds_aio.py \ + --read \ + --handle --io_size 400M \ + --loops 3 \ + --folder /mnt/nvme23/aio \ + --multi_process 16 diff --git a/csrc/aio/py_test/dgx2_v100_suboptimal_write.sh b/csrc/aio/py_test/dgx2_v100_suboptimal_write.sh new file mode 100755 index 000000000000..f083cd36dfba --- /dev/null +++ b/csrc/aio/py_test/dgx2_v100_suboptimal_write.sh @@ -0,0 +1,5 @@ +python test_ds_aio.py \ + --handle --io_size 400M \ + --loops 3 \ + --folder /mnt/nvme23/aio \ + --multi_process 16 diff --git a/csrc/aio/py_test/ds_aio_args.py b/csrc/aio/py_test/ds_aio_args.py new file mode 100644 index 000000000000..840bac8d519b --- /dev/null +++ b/csrc/aio/py_test/ds_aio_args.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import argparse +import os +from test_ds_aio_utils import refine_integer_value +from ds_aio_constants import AIO_HANDLE, AIO_BASIC, TORCH_FAST_IO, TORCH_IO, VALID_ENGINES +from deepspeed.accelerator import get_accelerator + +MAPPING_DELIMITER = ':' + + +def refine_args(args): + if args.io_size and type(args.io_size) == str: + args.io_size = refine_integer_value(args.io_size) + + if args.block_size and type(args.block_size) == str: + args.block_size = refine_integer_value(args.block_size) + + if args.fast_io_size and type(args.fast_io_size) == str: + args.fast_io_size = refine_integer_value(args.fast_io_size) + + return args + + +def _get_mapping_dict(args): + if args.folder is not None: + d = {i: args.folder for i in range(args.multi_process)} + else: + d = {} + for m in args.folder_to_device_mapping: + fields = m.split(MAPPING_DELIMITER) + d[fields[1]] = fields[0] + + return d + + +def _validate_folder_mapping(args): + no_error = True + error_messages = [] + invalid_mappings = [m for m in args.folder_to_device_mapping if MAPPING_DELIMITER not in m] + if len(invalid_mappings) > 0: + error_messages.append( + f'Missing delimiter ({MAPPING_DELIMITER}) in folder_to_device_mapping {invalid_mappings}') + no_error = False + + folder_list = [m.split(MAPPING_DELIMITER)[0] for m in args.folder_to_device_mapping] + invalid_folders = [d for d in folder_list if not os.path.exists(d)] + if len(invalid_folders) > 0: + error_messages.append(f'Invalid folders in folder_to_device_mapping: {invalid_folders}') + no_error = False + + if args.gpu: + device_list = [int(m.split(MAPPING_DELIMITER)[1]) for m in args.folder_to_device_mapping] + invalid_device_list = [dev_id for dev_id in device_list if not dev_id < get_accelerator().device_count()] + if len(invalid_device_list) > 0: + error_messages.append(f'Invalid device ids in folder_to_device_mapping: {invalid_device_list}') + no_error = False + + return no_error, error_messages + + +def validate_args(args): + no_error = True + error_messages = [] + + if args.folder is not None and len(args.folder_to_device_mapping) > 0: + error_messages.append('--folder and --folder_to_device_mapping cannot be specified together.') + no_error = False + elif args.folder is None and len(args.folder_to_device_mapping) == 0: + error_messages.append('At least one of --folder or --folder_to_device_mapping must be specified.') + no_error = False + + # Validate --folder + if args.folder is not None and not os.path.exists(args.folder): + no_error = False + error_messages.append(f'Invalid folder in --folder: {args.folder} ') + + # Validate --folder_mapping_to_device + if len(args.folder_to_device_mapping) > 0: + no_mapping_error, mapping_error_messages = _validate_folder_mapping(args) + no_error = no_error and no_mapping_error + error_messages += mapping_error_messages + + # Validate --engine + if args.engine not in VALID_ENGINES: + no_error = False + error_messages.append(f'Invalid engine {args.engine}. Valid options = {VALID_ENGINES}') + + # Validate --engine=torch_io + if args.engine == TORCH_IO: + if args.read: + no_error = False + error_messages.append(f'Read not currently supported for --engine={TORCH_IO}') + + if not no_error: + print(f'Found {len(error_messages)} validation error(s)') + # Validate --gpu, --use_gds + if args.use_gds and not args.gpu: + error_messages.append('--gpu must be set to transfer with --use_gds') + no_error = False + + if not no_error: + print(f'Found {len(error_messages)} validation errors') + for i, msg in enumerate(error_messages): + print(f'{i+1}: {msg}') + + return no_error + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--folder', default=None, type=str, help='Folder to use for I/O.') + + parser.add_argument('--folder_to_device_mapping', + default=[], + nargs='+', + help='Specification of mapping of folder to (gpu) device id, (ignored for cpu accesses).' + 'Can be specified multiple times for multi-process runs,' + 'e.g. --folder_to_device_mapping /mnt/nvme0:0 --folder_to_device_mapping /mnt/nvme1:15 --gpu' + 'means access /mnt/nvme0 with gpu 0 and /mnt/nvme1 with gpu 15') + + parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.') + + parser.add_argument('--fast_io_size', type=str, default='64M', help='Size of fast_io pinned buffer (bytes).') + + parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)') + + parser.add_argument('--multi_process', + type=int, + default=1, + help='Number of parallel processes doing I/O (default 1).') + + parser.add_argument('--block_size', + type=str, + default='1M', + help='I/O block size. Can use K, M, or G suffix (default 1M for 1 megabytes).') + + parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth (default 32).') + + parser.add_argument('--single_submit', + action='store_true', + help='Submit I/O requests in singles (default is submit queue_depth amount at once.).') + + parser.add_argument( + '--sequential_requests', + action='store_true', + help= + 'Delay I/O request submission until completion of prior requests (default is overlap I/O submission and completion requests.).' + ) + + parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.') + + parser.add_argument( + '--engine', + type=str, + default=AIO_HANDLE, + help= + f'Engine to perform I/O. Options are [{AIO_HANDLE}, {AIO_BASIC}, {TORCH_IO}, {TORCH_FAST_IO}]. Default is aio_handle' + ) + + parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions') + + parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism') + + parser.add_argument('--gpu', action='store_true', help='Use GPU memory') + + parser.add_argument('--use_gds', action='store_true', help='Enable GDS AIO') + + parser.add_argument('--slow_bounce_buffer', + action='store_true', + help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.') + + parser.add_argument('--torch_legacy_save', action='store_true', help='Use torch legacy save approach') + + parser.add_argument('--use_accelerator_pin_memory', + action='store_true', + help='Obtain pinned (CPU page-locked) tensors from accelerator') + + parser.add_argument('--warmup_loops', type=int, default=1, help='Count of operation warmup repetitions') + + parser.add_argument('--include_warmup_time', action='store_true', help='Include warmup latency in results') + + parser.add_argument('--different_file_each_iteration', + action='store_true', + help='Read/write a different file on each iteration.') + + args = parser.parse_args() + print(f'args = {args}') + return args + + +def get_validated_args(): + args = parse_arguments() + args = refine_args(args) + if not validate_args(args): + quit() + print('Successful validation of command line arguments') + args.total_loops = args.warmup_loops + args.loops + peer_tag = 'gpu' if args.gpu else 'process' + args.mapping_dict = _get_mapping_dict(args) + args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()] + assert len(args.mapping_dict) == len(args.mapping_list) + print(f'Configuring {len(args.mapping_list)} {peer_tag} to folder mapping') + for i, (device_id, folder) in enumerate(args.mapping_list): + print(f'[{i}]: {peer_tag} {device_id} <----> {folder}') + + return args diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py index ad2a4349cd0c..6003bcbf2ea1 100755 --- a/csrc/aio/py_test/ds_aio_basic.py +++ b/csrc/aio/py_test/ds_aio_basic.py @@ -6,128 +6,59 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ -import torch import os import time -from multiprocessing import Pool, Barrier -from test_ds_aio_utils import report_results, task_log, task_barrier -from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import AsyncIOBuilder +from deepspeed.ops.aio import AsyncIOBuilder +from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from ds_aio_constants import * -def pre_basic(args, tid, read_op): - io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' +class AIOBasic_Engine(object): - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu')) - task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}') + def __init__(self, args, tid, read_op): + self.ctxt = self._create_context(args, tid, read_op) - ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes - ctxt['buffer'] = buffer - ctxt['elapsed_sec'] = 0 + def fini(self): + self.ctxt[BUFFER].detach() + self.ctxt[BUFFER] = None - return ctxt - - -def pre_basic_read(pool_params): - args, tid = pool_params - ctxt = pre_basic(args, tid, True) - return ctxt - - -def pre_basic_write(pool_params): - args, tid = pool_params - ctxt = pre_basic(args, tid, False) - return ctxt - - -def post_basic(pool_params): - _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None - return ctxt - - -def main_basic_read(pool_params): - args, tid, ctxt = pool_params - start_time = time.time() - AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth, - args.single_submit, args.overlap_events, args.validate) - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_basic_write(pool_params): - args, tid, ctxt = pool_params - start_time = time.time() - AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth, - args.single_submit, args.overlap_events, args.validate) - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def get_schedule(args, read_op): - schedule = {} - if read_op: - schedule['pre'] = pre_basic_read - schedule['post'] = post_basic - schedule['main'] = main_basic_read - else: - schedule['pre'] = pre_basic_write - schedule['post'] = post_basic - schedule['main'] = main_basic_write - - return schedule - - -def _aio_handle_tasklet(pool_params): - args, tid, read_op = pool_params - - # Create schedule - schedule = get_schedule(args, read_op) - task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) + def read(self, args, tid, loop_id): + start_time = time.time() + AsyncIOBuilder().load().aio_read(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth, + args.single_submit, not args.sequential_requests, args.validate) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time - # Run pre task - task_log(tid, f'running pre-task') - ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) + def write(self, args, tid, loop_id): + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(self.ctxt[FILE]): + os.remove(self.ctxt[FILE]) - # Run main tasks in a loop - ctxt["main_task_sec"] = 0 - for i in range(args.loops): - task_log(tid, f'running main task {i}') start_time = time.time() - ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - stop_time = time.time() - ctxt["main_task_sec"] += stop_time - start_time - - # Run post task - task_log(tid, f'running post-task') - ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) + AsyncIOBuilder().load().aio_write(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth, + args.single_submit, not args.sequential_requests, args.validate) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time - return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') -def _init_tasklet(b): - global aio_barrier - aio_barrier = b + buffer = create_page_locked_tensor(args.io_size, True) + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}') -def aio_basic_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: - pool_results = p.map(_aio_handle_tasklet, pool_params) + task_log(tid, 'created deepspeed aio basic engine') - report_results(args, read_op, pool_results) + ctxt = {} + ctxt[FILE] = filename + ctxt[NUM_BYTES] = args.io_size + ctxt[BUFFER] = buffer + ctxt[ELAPSED_SEC] = 0 + return ctxt diff --git a/csrc/aio/py_test/ds_aio_constants.py b/csrc/aio/py_test/ds_aio_constants.py new file mode 100644 index 000000000000..e2ebdbe9f01f --- /dev/null +++ b/csrc/aio/py_test/ds_aio_constants.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +AIO_HANDLE = 'aio_handle' +AIO_BASIC = 'aio_basic' +TORCH_IO = 'torch_io' +TORCH_FAST_IO = 'torch_fastio' +VALID_ENGINES = [AIO_HANDLE, AIO_BASIC, TORCH_IO, TORCH_FAST_IO] + +BUFFER = 'buffer' +BOUNCE_BUFFER = 'bounce_buffer' +NUM_BYTES = 'num_bytes' +FILE = 'file' +HANDLE = 'handle' +ELAPSED_SEC = 'elapsed_sec' +FAST_IO_BUFFER = 'fast_io_buffer' +USE_CPU_LOCKED_TENSOR = 'cpu_locked_tensor' diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py index d35b2713edae..aeb8f0862f23 100755 --- a/csrc/aio/py_test/ds_aio_handle.py +++ b/csrc/aio/py_test/ds_aio_handle.py @@ -2,176 +2,105 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -""" -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" import torch import os import time -from multiprocessing import Pool, Barrier -from test_ds_aio_utils import report_results, task_log, task_barrier -from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import AsyncIOBuilder - - -def pre_handle(args, tid, read_op): - io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' - - io_parallel = args.io_parallel if args.io_parallel else 1 - handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, - args.overlap_events, io_parallel) - task_log(tid, f'Created deepspeed aio handle') - - if args.gpu: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device=get_accelerator().device_name()) - else: - if args.use_accelerator_pin_memory: - buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu')) - else: - buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8)) - - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - - ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes - ctxt['handle'] = handle - ctxt['buffer'] = buffer - ctxt['elapsed_sec'] = 0 - - task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}') - - return ctxt - - -def pre_handle_read(pool_params): - args, tid = pool_params - ctxt = pre_handle(args, tid, True) - return ctxt - - -def pre_handle_write(pool_params): - args, tid = pool_params - ctxt = pre_handle(args, tid, False) - return ctxt - - -def post_handle(pool_params): - _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None - return ctxt - - -def main_parallel_read(pool_params): - args, tid, ctxt = pool_params - handle = ctxt['handle'] - - start_time = time.time() - ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True) - assert ret != -1 - handle.wait() - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_parallel_write(pool_params): - args, tid, ctxt = pool_params - handle = ctxt['handle'] - start_time = time.time() - ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True) - assert ret != -1 - handle.wait() - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - +from deepspeed.ops.aio import AsyncIOBuilder +from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from ds_aio_constants import * -def main_handle_read(pool_parms): - args, tid, ctxt = pool_parms - handle = ctxt['handle'] - start_time = time.time() - ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate) - assert ret != -1 - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time +class AIOHandle_Engine(object): - return ctxt + def __init__(self, args, tid, read_op): + self.ctxt = self._create_context(args, tid, read_op) + def fini(self): + for buf in [BUFFER, BOUNCE_BUFFER]: + if self.ctxt[buf] is not None: + if self.ctxt[USE_CPU_LOCKED_TENSOR]: + self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf]) -def main_handle_write(pool_parms): - args, tid, ctxt = pool_parms - handle = ctxt['handle'] - start_time = time.time() - ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate) - assert ret != -1 - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time + self.ctxt[buf].detach() + self.ctxt[buf] = None - return ctxt + def read(self, args, tid, loop_id): + handle = self.ctxt[HANDLE] - -def get_schedule(args, read_op): - schedule = {} - if read_op: - schedule['pre'] = pre_handle_read - schedule['post'] = post_handle - schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read - else: - schedule['pre'] = pre_handle_write - schedule['post'] = post_handle - schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write - - return schedule - - -def _aio_handle_tasklet(pool_params): - args, tid, read_op = pool_params - - # Create schedule - schedule = get_schedule(args, read_op) - task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) - - # Run pre task - task_log(tid, f'running pre-task') - ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) - - # Run main tasks in a loop - ctxt["main_task_sec"] = 0 - for i in range(args.loops): - task_log(tid, f'running main task {i}') start_time = time.time() - ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - stop_time = time.time() - ctxt["main_task_sec"] += stop_time - start_time - - # Run post task - task_log(tid, f'running post-task') - ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - - return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops - - -def _init_tasklet(b): - global aio_barrier - aio_barrier = b - - -def aio_handle_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: - pool_results = p.map(_aio_handle_tasklet, pool_params) - - report_results(args, read_op, pool_results) + dest_buffer = BOUNCE_BUFFER if self.ctxt[BOUNCE_BUFFER] is not None else BUFFER + ret = handle.pread(self.ctxt[dest_buffer], self.ctxt[FILE][loop_id], args.validate, True) + assert ret != -1 + handle.wait() + if dest_buffer == BOUNCE_BUFFER: + self.ctxt[BUFFER].data.copy_(self.ctxt[BOUNCE_BUFFER].data) + end_time = time.time() + self.ctxt[ELAPSED_SEC].append(end_time - start_time) + + def write(self, args, tid, loop_id): + handle = self.ctxt[HANDLE] + start_time = time.time() + if self.ctxt[BOUNCE_BUFFER] is not None: + source_buffer = BOUNCE_BUFFER + self.ctxt[BOUNCE_BUFFER].data.copy_(self.ctxt[BUFFER].data) + else: + source_buffer = BUFFER + ret = handle.pwrite(self.ctxt[source_buffer], self.ctxt[FILE][loop_id], args.validate, True) + assert ret != -1 + handle.wait() + end_time = time.time() + self.ctxt[ELAPSED_SEC].append(end_time - start_time) + + def _create_files(self, args, folder, tid): + if args.different_file_each_iteration: + filenames = [ + create_filename(folder, args.read, args.io_size, f'{tid}_{l}') for l in range(args.total_loops) + ] + else: + filenames = [ + create_filename(folder, args.read, args.io_size, f'{tid}_{0}') for _ in range(args.total_loops) + ] + + if args.read: + for f in filenames: + if not (os.path.isfile(f) and os.path.getsize(f) == args.io_size): + create_file(f, args.io_size) + else: + for f in filenames: + if os.path.isfile(f): + os.remove(f) + + return filenames + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filenames = self._create_files(args, folder, tid) + io_parallel = args.io_parallel if args.io_parallel else 1 + handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + task_log(tid, 'created deepspeed aio handle engine') + + bounce_buffer = None + if args.gpu: + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}') + bounce_buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle) + else: + buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle) + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + ctxt = {} + ctxt[FILE] = filenames + ctxt[NUM_BYTES] = args.io_size + ctxt[HANDLE] = handle + ctxt[BUFFER] = buffer + ctxt[BOUNCE_BUFFER] = bounce_buffer + ctxt[ELAPSED_SEC] = [] + ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory + + task_log(tid, + f'{io_string} file {filenames} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) + + return ctxt diff --git a/csrc/aio/py_test/ds_aio_job.py b/csrc/aio/py_test/ds_aio_job.py new file mode 100644 index 000000000000..e9579a48fe4d --- /dev/null +++ b/csrc/aio/py_test/ds_aio_job.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping tensors to/from (NVMe) storage devices. +""" +import subprocess +import shlex + + +class Job(object): + + def __init__(self, cmd_line, output_file=None, work_dir=None): + self.cmd_line = cmd_line + self.output_file = output_file + self.work_dir = work_dir + self.output_fd = None + + def cmd(self): + return self.cmd_line + + def get_stdout(self): + return self.output_fd + + def get_stderr(self): + return self.output_fd + + def get_cwd(self): + return self.work_dir + + def open_output_file(self): + if self.output_file is not None: + self.output_fd = open(self.output_file, 'w') + + def close_output_file(self): + if self.output_fd is not None: + self.output_fd.close() + self.output_fd = None + + +def run_job(job): + args = shlex.split(' '.join(job.cmd())) + print(f'args = {args}') + job.open_output_file() + proc = subprocess.run(args=args, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd()) + job.close_output_file() + assert proc.returncode == 0, \ + f"This command failed: {job.cmd()}" diff --git a/csrc/aio/py_test/io_engine.py b/csrc/aio/py_test/io_engine.py new file mode 100644 index 000000000000..b62628fe517e --- /dev/null +++ b/csrc/aio/py_test/io_engine.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +from multiprocessing import Pool, Barrier + +from ds_aio_constants import AIO_BASIC, TORCH_FAST_IO, TORCH_IO +from test_ds_aio_utils import report_results, task_log, task_barrier +from ds_aio_handle import AIOHandle_Engine +from ds_aio_basic import AIOBasic_Engine +from torch_io import TorchIO_Engine +from torch_fastio_engine import Torch_FastIO_Engine + + +def prepare_operation(args, tid, read_op): + if args.engine == TORCH_IO: + io_engine = TorchIO_Engine(args, tid, read_op) + elif args.engine == AIO_BASIC: + io_engine = AIOBasic_Engine(args, tid, read_op) + elif args.engine == TORCH_FAST_IO: + io_engine = Torch_FastIO_Engine(args, tid, read_op) + else: + io_engine = AIOHandle_Engine(args, tid, read_op) + + return io_engine + + +def prepare_read(pool_params): + args, tid = pool_params + return prepare_operation(args, tid, True) + + +def prepare_write(pool_params): + args, tid = pool_params + return prepare_operation(args, tid, False) + + +def post_operation(pool_params): + _, _, io_engine = pool_params + io_engine.fini() + + +def read_operation(pool_params): + args, tid, loop_id, io_engine = pool_params + return io_engine.read(args, tid, loop_id) + + +def write_operation(pool_params): + args, tid, loop_id, io_engine = pool_params + return io_engine.write(args, tid, loop_id) + + +def get_schedule(args, read_op): + schedule = {} + if read_op: + schedule['pre'] = prepare_read + schedule['post'] = post_operation + schedule['main'] = read_operation + else: + schedule['pre'] = prepare_write + schedule['post'] = post_operation + schedule['main'] = write_operation + + return schedule + + +def io_engine_tasklet(pool_params): + args, tid, read_op = pool_params + num_processes = len(args.mapping_dict) + + # Create schedule + schedule = get_schedule(args, read_op) + task_log(tid, f'schedule = {schedule}') + task_barrier(aio_barrier, num_processes) + + # Run pre task + task_log(tid, 'running pre-task') + io_engine = schedule["pre"]((args, tid)) + task_barrier(aio_barrier, num_processes) + + # Run main tasks in a loop + io_engine.ctxt["main_task_sec"] = [] + for i in range(args.total_loops): + task_log(tid, f'running main task {i}') + start_time = time.time() + schedule["main"]((args, tid, i, io_engine)) + task_barrier(aio_barrier, num_processes) + stop_time = time.time() + io_engine.ctxt["main_task_sec"].append(stop_time - start_time) + + # Run post task + task_log(tid, 'running post-task') + schedule["post"]((args, tid, io_engine)) + task_barrier(aio_barrier, num_processes) + + ctxt = io_engine.ctxt + # return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + if args.include_warmup_time: + e2e_latency_sec = sum(ctxt["main_task_sec"]) + task_latency_sec = sum(ctxt["elapsed_sec"]) + actual_loops = args.total_loops + else: + e2e_latency_sec = sum(ctxt["main_task_sec"][args.warmup_loops:]) + task_latency_sec = sum(ctxt["elapsed_sec"][args.warmup_loops:]) + actual_loops = args.loops + + l = ctxt["elapsed_sec"] + task_log(tid, f'task_latency_sec = {l}') + return e2e_latency_sec, task_latency_sec, ctxt["num_bytes"] * actual_loops + + +def _init_takslet(b): + global aio_barrier + aio_barrier = b + + +def io_engine_multiprocessing(args, read_op): + num_processes = len(args.mapping_dict) + b = Barrier(num_processes) + pool_params = [(args, p, read_op) for p in range(num_processes)] + with Pool(processes=num_processes, initializer=_init_takslet, initargs=(b, )) as p: + pool_results = p.map(io_engine_tasklet, pool_params) + + report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/parse_aio_stats.py b/csrc/aio/py_test/parse_aio_stats.py index 09c79ada5b36..54c77608cf0d 100755 --- a/csrc/aio/py_test/parse_aio_stats.py +++ b/csrc/aio/py_test/parse_aio_stats.py @@ -50,7 +50,7 @@ def extract_value(key, file): return int(v[0]) * 1024 * 1024 else: return int(key[2:]) - except: + except Exception: print(f"{file}: extract_value fails on {key}") return None diff --git a/csrc/aio/py_test/run_read_sweep.sh b/csrc/aio/py_test/run_read_sweep.sh index b9d7e050454a..9ce6af2291a3 100755 --- a/csrc/aio/py_test/run_read_sweep.sh +++ b/csrc/aio/py_test/run_read_sweep.sh @@ -1,13 +1,22 @@ #!/bin/bash -if [[ $# -ne 2 ]]; then - echo "Usage: $0 " +if [[ $# -lt 2 ]]; then + echo "Usage: $0 " exit 1 fi +function prep_folder() +{ + folder=$1 + if [[ -d ${folder} ]]; then + rm -f ${folder}/* + else + mkdir -p ${folder} + fi +} function validate_environment() { - validate_cmd="python ./validate_async_io.py" + validate_cmd="TORCH_EXTENSIONS_DIR=./torch_extentions python3 ./validate_async_io.py" eval ${validate_cmd} res=$? if [[ $res != 0 ]]; then @@ -17,18 +26,27 @@ function validate_environment() fi } +function fileExists() { + local file="$1" + if [[ -f "$file" ]]; then + return 0 + else + return 1 + fi +} validate_environment -INPUT_FILE=$1 -if [[ ! -f ${INPUT_FILE} ]]; then - echo "Input file not found: ${INPUT_FILE}" - exit 1 -fi - -LOG_DIR=$2/aio_perf_sweep +IO_SIZE=$1 +LOG_DIR=./aio_perf_sweep +MAP_DIR=$2/aio +GPU_MEM=$3 +USE_GDS=$4 RUN_SCRIPT=./test_ds_aio.py -READ_OPT="--read_file ${INPUT_FILE}" +READ_OPT="--read" + +prep_folder ${MAP_DIR} +prep_folder ${LOG_DIR} if [[ -d ${LOG_DIR} ]]; then rm -f ${LOG_DIR}/* @@ -36,37 +54,60 @@ else mkdir -p ${LOG_DIR} fi -DISABLE_CACHE="sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' " -SYNC="sync" +if [[ ${GPU_MEM} == "gpu" ]]; then + gpu_opt="--gpu" +else + gpu_opt="" +fi +if [[ ${USE_GDS} == "gds" ]]; then + gds_opt="--use_gds" +else + gds_opt="" +fi + +DISABLE_CACHE="sudo sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' " +SYNC="sudo sync" -for sub in single block; do - if [[ $sub == "single" ]]; then - sub_opt="--single_submit" +for xtype in cpu gpu gds; do + if [[ $xtype == "cpu" ]]; then + gpu_opt="" + gds_opt="" + elif [[ $xtype == "gpu" ]]; then + gpu_opt="--gpu" + gds_opt="" else - sub_opt="" + gpu_opt="--gpu" + gds_opt="--use_gds" fi for ov in overlap sequential; do - if [[ $ov == "overlap" ]]; then - ov_opt="--overlap_events" + if [[ $ov == "sequential" ]]; then + ov_opt="--sequential_requests" else - ov_opt="" + sub_opt="" fi - for t in 1 2 4 8; do - for p in 1 ; do - for d in 1 2 4 8 16 32; do - for bs in 128K 256K 512K 1M; do - SCHED_OPTS="${sub_opt} ${ov_opt} --handle --threads ${t}" - OPTS="--io_parallel ${p} --queue_depth ${d} --block_size ${bs}" - LOG="${LOG_DIR}/read_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt" - cmd="python ${RUN_SCRIPT} ${READ_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" - echo ${DISABLE_CACHE} - echo ${cmd} - echo ${SYNC} + for ov in overlap sequential; do + if [[ $ov == "sequential" ]]; then + ov_opt="--sequential_requests" + else + ov_opt="" + fi + for p in 1 2 4 8; do + for t in 1 2 4 8; do + for d in 8 16 32 64 128; do + for bs in 128K 256K 512K 1M 2M 4M 8M 16M; do + SCHED_OPTS="${sub_opt} ${ov_opt} --handle ${gpu_opt} ${gds_opt} --folder_to_device_mapping /mnt/nvme01:0" + OPTS="--queue_depth ${d} --block_size ${bs} --io_size ${IO_SIZE} --io_parallel ${t}" + LOG="${LOG_DIR}/read_${xtype}_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt" + cmd="/usr/bin/time python ${RUN_SCRIPT} ${READ_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" - eval ${DISABLE_CACHE} - eval ${cmd} - eval ${SYNC} - sleep 2 + echo ${DISABLE_CACHE} + echo ${cmd} + echo ${SYNC} + eval ${DISABLE_CACHE} + eval ${cmd} + eval ${SYNC} + sleep 2 + done done done done diff --git a/csrc/aio/py_test/run_write_sweep.sh b/csrc/aio/py_test/run_write_sweep.sh index 99f2113dda6f..7954e6ada8de 100755 --- a/csrc/aio/py_test/run_write_sweep.sh +++ b/csrc/aio/py_test/run_write_sweep.sh @@ -26,46 +26,85 @@ function validate_environment() validate_environment if [[ $# -ne 3 ]]; then - echo "Usage: $0 " + echo "Usage: $0 " exit 1 fi -SIZE="$1M" +SIZE=$1 WRITE_DIR=$2 LOG_DIR=$3/aio_perf_sweep -OUTPUT_FILE=${WRITE_DIR}/ds_aio_write_${SIZE}B.pt -WRITE_OPT="--write_file ${OUTPUT_FILE} --write_size ${SIZE}" - +WRITE_OPT="--folder ${WRITE_DIR} --io_size ${SIZE} --loops 3" +IO_ENGINE="torch_fastio" +ENGINE_OPTS="" +if [[ $IO_ENGINE == "aio_handle" ]]; then + IO_PARALLEL="1" # "1 2 4 8" + QUEUE_DEPTH="8 16 32 64 128" + BLOCK_SIZE="128K 256K 512K 1M 2M 4M 8M 16M" + SUBMIT="block" + OVERLAP="overlap" +elif [[ $IO_ENGINE == "torch_fastio" ]]; then + IO_PARALLEL="1" # "1 2 4 8" + QUEUE_DEPTH="8 16 32 64 128" + BLOCK_SIZE="128K 256K 512K 1M 2M 4M 8M 16M" + SUBMIT="block" + OVERLAP="overlap" + ENGINE_OPTS="--torch_legacy --fast_io_size ${SIZE}" +else + IO_PARALLEL="1" + QUEUE_DEPTH="8" + BLOCK_SIZE="128K" + SUBMIT="single" + OVERLAP="sequential" +fi prep_folder ${WRITE_DIR} prep_folder ${LOG_DIR} RUN_SCRIPT=./test_ds_aio.py -DISABLE_CACHE="sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' " +OUTPUT_FILE=${MAP_DIR}/ds_aio_write_${SIZE}B.pt +WRITE_OPT="" + + +prep_folder ${MAP_DIR} +prep_folder ${LOG_DIR} + + +if [[ ${GPU_MEM} == "gpu" ]]; then + gpu_opt="--gpu" +else + gpu_opt="" +fi +if [[ ${USE_GDS} == "gds" ]]; then + gds_opt="--use_gds" +else + gds_opt="" +fi + +DISABLE_CACHE="sync; bash -c 'echo 1 > /proc/sys/vm/drop_caches' " SYNC="sync" -for sub in single block; do +for sub in ${SUBMIT}; do if [[ $sub == "single" ]]; then sub_opt="--single_submit" else sub_opt="" fi - for ov in overlap sequential; do - if [[ $ov == "overlap" ]]; then - ov_opt="--overlap_events" + for ov in ${OVERLAP}; do + if [[ $ov == "sequential" ]]; then + ov_opt="--sequential_requests" else ov_opt="" fi - for t in 1 2 4 8; do - for p in 1; do - for d in 1 2 4 8 16 32; do - for bs in 128K 256K 512K 1M; do - SCHED_OPTS="${sub_opt} ${ov_opt} --handle --threads ${t}" - OPTS="--io_parallel ${p} --queue_depth ${d} --block_size ${bs}" + for p in 1; do + for t in ${IO_PARALLEL}; do + for d in ${QUEUE_DEPTH}; do + for bs in ${BLOCK_SIZE}; do + SCHED_OPTS="${sub_opt} ${ov_opt} --engine ${IO_ENGINE} --io_parallel ${t} ${ENGINE_OPTS}" + OPTS="--multi_process ${p} --queue_depth ${d} --block_size ${bs}" LOG="${LOG_DIR}/write_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt" - cmd="python ${RUN_SCRIPT} ${WRITE_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" + cmd="python ${RUN_SCRIPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}" echo ${DISABLE_CACHE} echo ${cmd} echo ${SYNC} diff --git a/csrc/aio/py_test/single_process_config.json b/csrc/aio/py_test/single_process_config.json index 275c54135cd8..4a224711b5ff 100644 --- a/csrc/aio/py_test/single_process_config.json +++ b/csrc/aio/py_test/single_process_config.json @@ -2,12 +2,17 @@ "block_size": [ "128K", "256K", - "1M" + "1M", + "2M", + "4M", + "8M", + "16M" ], "queue_depth": [ - 4, + 8, 16, - 32 + 32, + 64 ], "io_parallel": [ 1, @@ -19,7 +24,7 @@ true, false ], - "overlap_events": [ + "sequential_requests": [ true, false ], diff --git a/csrc/aio/py_test/test_ds_aio.py b/csrc/aio/py_test/test_ds_aio.py index e6242cb35789..32aa74611b9a 100755 --- a/csrc/aio/py_test/test_ds_aio.py +++ b/csrc/aio/py_test/test_ds_aio.py @@ -6,79 +6,18 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ -import os -import argparse import multiprocessing as mp -from ds_aio_basic import aio_basic_multiprocessing -from ds_aio_handle import aio_handle_multiprocessing -from test_ds_aio_utils import refine_args - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - parser.add_argument('--read_file', type=str, default=None, help='Read file.') - - parser.add_argument('--write_file', type=str, default=None, help='Write file.') - - parser.add_argument('--write_size', type=str, default=None, help='Number of bytes to write.') - - parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.') - - parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.') - - parser.add_argument('--threads', type=int, default=1, help='Thread parallelism count.') - - parser.add_argument('--single_submit', - action='store_true', - help='Submit I/O requests in singles (default is submit queue_depth amount at once.).') - - parser.add_argument('--overlap_events', - action='store_true', - help='Overlap I/O submission and completion requests.') - - parser.add_argument('--validate', action='store_true', help='Perform validation in library.') - - parser.add_argument('--handle', action='store_true', help='Use AIO handle.') - - parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions') - - parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism') - - parser.add_argument('--gpu', action='store_true', help='Use GPU memory') - - parser.add_argument('--use_accelerator_pin_memory', - action='store_true', - help='Obtain pinned (CPU page-locked) tensors from accelerator') - - args = parser.parse_args() - print(f'args = {args}') - return args - - -def validate_args(args): - if args.read_file and not os.path.isfile(args.read_file): - print(f'args validation error: {args.read_file} not found') - return False - - return True +from ds_aio_args import get_validated_args +from io_engine import io_engine_multiprocessing def main(): - print(f'Testing deepspeed_aio python frontend') - - args = parse_arguments() - refine_args(args) - if not validate_args(args): - quit() - - mp.set_start_method('spawn') - multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing - if args.read_file: - multiprocess_function(args, True) + print('Testing deepspeed_aio python frontend') - if args.write_file: - multiprocess_function(args, False) + args = get_validated_args() + mp.set_start_method('spawn', force=True) + multiprocess_function = io_engine_multiprocessing + multiprocess_function(args, args.read) if __name__ == "__main__": diff --git a/csrc/aio/py_test/test_ds_aio_utils.py b/csrc/aio/py_test/test_ds_aio_utils.py index 6aad114c0bdc..aa67084c706d 100755 --- a/csrc/aio/py_test/test_ds_aio_utils.py +++ b/csrc/aio/py_test/test_ds_aio_utils.py @@ -6,12 +6,19 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ +import os +from ds_aio_job import Job, run_job +import torch +from deepspeed.accelerator import get_accelerator + BYTES_PER_GB = 1024**3 +BYTES_PER_MB = 1024**2 +BYTES_PER_KB = 1024 LOG_TIDS = [0] -def task_log(tid, msg): - if tid in LOG_TIDS: +def task_log(tid, msg, force=False): + if force or tid in LOG_TIDS: print(f'tid {tid}: {msg}') @@ -31,16 +38,29 @@ def report_results(args, read_op, pool_results): total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) task_latency_sec = max([sec for _, sec, _ in pool_results]) - task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB + task_speed_GB = 0 if task_latency_sec == 0 else total_bytes / task_latency_sec / BYTES_PER_GB print(f'Task {io_string} Latency = {task_latency_sec} sec') print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') e2e_latency_sec = max([sec for sec, _, _ in pool_results]) - e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB + e2e_speed_GB = 0 if e2e_latency_sec == 0 else total_bytes / e2e_latency_sec / BYTES_PER_GB print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') +def get_block_size_and_count(io_bytes): + if io_bytes > BYTES_PER_MB and io_bytes % BYTES_PER_MB == 0: + block_size = BYTES_PER_MB + block_size_string = '1M' + else: + assert io_bytes % BYTES_PER_KB == 0 + block_size = BYTES_PER_KB + block_size_string = '1K' + block_count = io_bytes / block_size + + return block_size_string, int(block_count) + + def refine_integer_value(value): unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} @@ -50,9 +70,22 @@ def refine_integer_value(value): return int(value) -def refine_args(args): - if args.write_size and type(args.write_size) == str: - args.write_size = refine_integer_value(args.write_size) +def create_filename(folder, read_op, size, tid): + io_string = "read" if read_op else "write" + return os.path.join(folder, f'_aio_{io_string}_{size}.pt.{tid}') + + +def create_file(filename, num_bytes): + block_size, block_count = get_block_size_and_count(num_bytes) + dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={filename} bs={block_size} count={block_count}']) + print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....') + run_job(dd_job) + print(f'[Done] Create read file of {num_bytes} bytes by running {dd_job.cmd()} ....') + - if args.block_size and type(args.block_size) == str: - args.block_size = refine_integer_value(args.block_size) +def create_page_locked_tensor(num_elem, use_accelerator, aio_handle=None): + if use_accelerator: + return get_accelerator().pin_memory(torch.randint(high=128, size=(num_elem, ), dtype=torch.uint8, + device='cpu')) + else: + return aio_handle.new_cpu_locked_tensor(num_elem, torch.empty(0, dtype=torch.uint8)) diff --git a/csrc/aio/py_test/torch_fastio_engine.py b/csrc/aio/py_test/torch_fastio_engine.py new file mode 100644 index 000000000000..e16ac4c0417d --- /dev/null +++ b/csrc/aio/py_test/torch_fastio_engine.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import os +import time +from deepspeed.ops.aio import AsyncIOBuilder +from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from ds_aio_constants import * +from deepspeed.io import FastFileWriter + + +class Torch_FastIO_Engine(object): + + def __init__(self, args, tid, read_op): + assert read_op is False, 'Read operation is not currently supported' + self.ctxt = self._create_context(args, tid, read_op) + self.zipfile_serialization = not args.torch_legacy_save + + def fini(self): + if self.ctxt[USE_CPU_LOCKED_TENSOR]: + for buf in [BUFFER, FAST_IO_BUFFER]: + self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf]) + + self.ctxt[BUFFER].detach() + self.ctxt[BUFFER] = None + + def read(self, args, tid): + start_time = time.time() + torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def write(self, args, tid): + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(self.ctxt[FILE]): + os.remove(self.ctxt[FILE]) + + ds_file_writer = FastFileWriter(file_path=self.ctxt[FILE], + aio_handle=self.ctxt[HANDLE], + pinned_tensor=self.ctxt[FAST_IO_BUFFER]) + + start_time = time.time() + torch.save(obj=self.ctxt[BUFFER], f=ds_file_writer, _use_new_zipfile_serialization=self.zipfile_serialization) + ds_file_writer.close() # Force flush to storage + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + ds_file_writer._dump_state() + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + + io_parallel = args.io_parallel if args.io_parallel else 1 + aio_handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + + if args.gpu: + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}') + else: + buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, aio_handle) + + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + fast_io_buffer = create_page_locked_tensor(args.fast_io_size, args.use_accelerator_pin_memory, aio_handle) + + task_log(tid, 'created torch_fastio engine') + + ctxt = {} + ctxt[FILE] = filename + ctxt[NUM_BYTES] = args.io_size + ctxt[BUFFER] = buffer + ctxt[HANDLE] = aio_handle + ctxt[FAST_IO_BUFFER] = fast_io_buffer + ctxt[ELAPSED_SEC] = 0 + ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory + + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) + + return ctxt diff --git a/csrc/aio/py_test/torch_io.py b/csrc/aio/py_test/torch_io.py new file mode 100644 index 000000000000..1177f46724d5 --- /dev/null +++ b/csrc/aio/py_test/torch_io.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import os +import time +from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from ds_aio_constants import * + + +class TorchIO_Engine(object): + + def __init__(self, args, tid, read_op): + self.ctxt = self._create_context(args, tid, read_op) + self.zipfile_serialization = not args.torch_legacy_save + + def fini(self): + self.ctxt[BUFFER].detach() + self.ctxt[BUFFER] = None + + def read(self, args, tid): + start_time = time.time() + torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def write(self, args, tid): + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(self.ctxt[FILE]): + os.remove(self.ctxt[FILE]) + + start_time = time.time() + torch.save(obj=self.ctxt[BUFFER], f=self.ctxt[FILE], _use_new_zipfile_serialization=self.zipfile_serialization) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + if args.gpu: + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}') + else: + buffer = create_page_locked_tensor(args.io_size, True) + + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) + + task_log(tid, 'created torch_io engine') + + ctxt = {} + ctxt[FILE] = filename + ctxt[NUM_BYTES] = args.io_size + ctxt[BUFFER] = buffer + ctxt[ELAPSED_SEC] = 0 + return ctxt diff --git a/csrc/aio/py_test/validate_async_io.py b/csrc/aio/py_test/validate_async_io.py index 019ec05d49d3..10fb638347bc 100644 --- a/csrc/aio/py_test/validate_async_io.py +++ b/csrc/aio/py_test/validate_async_io.py @@ -7,3 +7,4 @@ """ from deepspeed.ops.op_builder import AsyncIOBuilder assert AsyncIOBuilder().is_compatible() +assert AsyncIOBuilder().load() diff --git a/csrc/aio/utils/dgx2_mount_nvme.sh b/csrc/aio/utils/dgx2_mount_nvme.sh new file mode 100755 index 000000000000..0ed12368d52b --- /dev/null +++ b/csrc/aio/utils/dgx2_mount_nvme.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +MOUNT_CMD="sudo mount -v -o data=ordered" + +for dir in nvme23 nvme45 nvme67 nvme89; do + mnt_point=/mnt/${dir} + sudo mkdir -p ${mnt_point} + sudo chmod -R a+rw ${mnt_point} +done +${MOUNT_CMD} /dev/md127 /mnt/nvme23 +${MOUNT_CMD} /dev/md126 /mnt/nvme45 +${MOUNT_CMD} /dev/md125 /mnt/nvme67 +${MOUNT_CMD} /dev/md124 /mnt/nvme89 + +lsblk -f diff --git a/csrc/aio/utils/dgx2_umount_nvme.sh b/csrc/aio/utils/dgx2_umount_nvme.sh new file mode 100755 index 000000000000..6b820ecf468a --- /dev/null +++ b/csrc/aio/utils/dgx2_umount_nvme.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +UMOUNT_CMD="sudo umount -v" + +for md in md127 md126 md125 md124; do + mnt_device=/dev/${md} + ${UMOUNT_CMD} ${mnt_device} +done + +lsblk -f diff --git a/csrc/common/custom_cuda_kernel.cu b/csrc/common/custom_cuda_kernel.cu deleted file mode 100644 index f46bf303125c..000000000000 --- a/csrc/common/custom_cuda_kernel.cu +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#include "custom_cuda_layers.h" - -__global__ void param_update_kernel(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - - if (id < size) { output[id] = (__half)input[id]; } -} - -void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel<<>>(input, output, size); -} - -__global__ void param_update_kernel_half(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - __half2* output_cast = reinterpret_cast<__half2*>(output); - if (id < size) { - float input_f = input[id]; - __half2* input_h = reinterpret_cast<__half2*>(&input_f); - output_cast[id] = *input_h; - } -} - -void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - size /= 2; - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel_half<<>>(input, output, size); -} diff --git a/csrc/compile/deepcompile.cpp b/csrc/compile/deepcompile.cpp new file mode 100644 index 000000000000..b09223b214c5 --- /dev/null +++ b/csrc/compile/deepcompile.cpp @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#define USE_C10D_NCCL + +namespace dc { + +std::shared_ptr param_registry; +std::unordered_map> executors; +std::shared_ptr reduce_buckets = nullptr; + +c10::intrusive_ptr process_group = nullptr; +c10::intrusive_ptr symm_mem = nullptr; +ncclComm_t nccl_comm; +bool use_symm_mem; +bool profile = false; +bool pre_div_reduce = true; + +int64_t free_activation_threshold; + +bool sync_before_reduce; // for debugging +bool sync_after_reduce; // for debugging +bool sync_before_allgather; // for debugging +bool sync_after_allgather; // for debugging + +std::vector sizes_to_int_vector(at::IntArrayRef sizes) +{ + std::vector result; + for (int i = 0; i < sizes.size(); i++) { result.push_back(sizes[i]); } + return result; +} + +void enable_profiling(bool enable) { profile = enable; } + +bool is_profiling() { return profile; } + +c10::intrusive_ptr getSymmMemWorkspace(int64_t size) +{ + c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device()); + std::vector sizes = {size}; + std::vector strides = {1}; + at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p( + {size}, {1}, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt); + return c10d::symmetric_memory::rendezvous(sym_mem_ws); +} + +void lazy_init_symm_memory() +{ + if (use_symm_mem && !symm_mem) { + int64_t max_param_size = 0; + for (const auto& it : param_registry->getParams()) { + int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size(); + if (size > max_param_size) { max_param_size = size; } + } + symm_mem = getSymmMemWorkspace(max_param_size); + } +} + +ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type) +{ + switch (scalar_type) { + case at::kFloat: return ncclFloat; + case at::kHalf: return ncclHalf; + case at::kDouble: return ncclDouble; + case at::kBFloat16: return ncclBfloat16; + case at::kLong: return ncclInt64; + case at::kInt: return ncclInt; + case at::kChar: return ncclInt8; + default: throw std::runtime_error("Unsupported scalar type"); + } +} + +void reset() +{ + executors.clear(); + // We keep the buckets for memory estimation + // reduce_buckets->clear(); +} + +void cleanup() +{ + reset(); + + ncclCommDestroy(nccl_comm); + process_group = nullptr; + symm_mem = nullptr; +} + +at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id) +{ + if (sync_before_reduce) { c10::cuda::device_synchronize(); } + + assert(hasKey(executors, graph_id)); + if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); } + + if (sync_after_reduce) { c10::cuda::device_synchronize(); } + + return torch::empty({0}, grad_tensor.options()); +} + +at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id) +{ + return torch::empty({0}, grad_tensor.options()); +} + +void free_tensors(std::vector tensors) +{ + if (!profile) { + for (auto& tensor : tensors) { + if (tensor.is_cuda() && tensor.numel() > free_activation_threshold) { + tensor.record_stream(at::cuda::getCurrentCUDAStream()); + tensor.set_data(torch::empty({0}, tensor.options())); + } + } + } +} + +void free_tensors_meta(std::vector tensors) {} + +template +static T get_config(pybind11::object& config, const char* name) +{ + return pybind11::getattr(config, name).cast(); +} + +void init(c10::intrusive_ptr pg, + pybind11::object& config, + int64_t initial_reduce_bucket_size) +{ + process_group = pg; + + ncclUniqueId ncclID; + ncclGetUniqueId(&ncclID); + + // ProcessGroup doesn't have an API to get the CUDA stream for comm calls. + // So we create a NCCL communicator and call NCCL APIs directly. + auto vec = std::vector(reinterpret_cast(&ncclID), + reinterpret_cast(&ncclID) + NCCL_UNIQUE_ID_BYTES); + auto device = torch::Device(torch::kCUDA); + at::Tensor tensor = torch::from_blob(vec.data(), {static_cast(vec.size())}, torch::kUInt8) + .to(torch::Device(torch::kCUDA)); + std::vector bcast_input = {tensor}; + + process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait(); + + // create a new nccl communicator + std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES); + ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank()); + + param_registry = std::make_shared(); + reduce_buckets = std::make_shared( + initial_reduce_bucket_size, get_config(config, "double_buffer")); + use_symm_mem = get_config(config, "symmetric_memory"); + free_activation_threshold = get_config(config, "free_activation_threshold"); + + sync_before_reduce = get_config(config, "sync_before_reduce"); + sync_after_reduce = get_config(config, "sync_after_reduce"); + sync_before_allgather = get_config(config, "sync_before_allgather"); + sync_after_allgather = get_config(config, "sync_after_allgather"); +} + +void start_forward() +{ + lazy_init_symm_memory(); + for (auto& it : executors) { it.second->startForward(); } +} + +void end_forward() +{ + for (auto& it : executors) { it.second->endForward(); } +} + +void start_backward(bool update) +{ + for (auto& it : executors) { it.second->startBackward(update); } +} + +void end_backward(const c10::IValue& deps, long graph_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->endBackward(); +} + +void end_backward_meta(const c10::IValue& deps, long graph_id) {} + +} // namespace dc diff --git a/csrc/compile/init.cpp b/csrc/compile/init.cpp new file mode 100644 index 000000000000..cbb03907d327 --- /dev/null +++ b/csrc/compile/init.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" +#include "z1.h" +#include "z2.h" +#include "z3.h" + +TORCH_LIBRARY(dc, m) +{ + m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor"); + m.def( + "prefetch_params_fused(int graph_id, Tensor[] params, int[] ids," + " ScalarType[]? dtypes = None) -> ()"); + m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)"); + m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)"); + m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor"); + m.def("free_tensors(Tensor[] a) -> ()"); + m.def("offload_tensor(Tensor a, int id, int id) -> Tensor"); + m.def("reload_tensor(Tensor a, int id, int id) -> Tensor"); + m.def("wait_offload(Tensor a, int id, int id) -> Tensor"); + m.def("wait_reload(Tensor a, int id, int id) -> Tensor"); + m.def("offload_parameter(Tensor a, int id, int id) -> ()"); + m.def("reload_parameter(Tensor a, int id, int id) -> ()"); + m.def("end_backward(Any deps, int graph_id) -> ()"); + + m.def("test_call(Tensor a) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(dc, CPU, m) +{ + m.impl("allgather_param", &dc::allgather_param); + m.impl("prefetch_params_fused", &dc::prefetch_params_fused); + m.impl("wait_allgather", &dc::wait_allgather); + m.impl("release_param", &dc::release_param); + m.impl("reduce_grad", &dc::reduce_grad); + m.impl("free_tensors", &dc::free_tensors); + m.impl("offload_tensor", &dc::offload_tensor); + m.impl("reload_tensor", &dc::reload_tensor); + m.impl("wait_offload", &dc::wait_offload); + m.impl("wait_reload", &dc::wait_reload); + m.impl("offload_parameter", &dc::offload_parameter); + m.impl("reload_parameter", &dc::reload_parameter); + m.impl("end_backward", &dc::end_backward); + + m.impl("test_call", &dc::test_call); +} + +TORCH_LIBRARY_IMPL(dc, CUDA, m) +{ + m.impl("allgather_param", &dc::allgather_param); + m.impl("prefetch_params_fused", &dc::prefetch_params_fused); + m.impl("wait_allgather", &dc::wait_allgather); + m.impl("release_param", &dc::release_param); + m.impl("reduce_grad", &dc::reduce_grad); + m.impl("free_tensors", &dc::free_tensors); + m.impl("offload_tensor", &dc::offload_tensor); + m.impl("reload_tensor", &dc::reload_tensor); + m.impl("wait_offload", &dc::wait_offload); + m.impl("wait_reload", &dc::wait_reload); + m.impl("offload_parameter", &dc::offload_parameter); + m.impl("reload_parameter", &dc::reload_parameter); + m.impl("end_backward", &dc::end_backward); + + m.impl("test_call", &dc::test_call); +} + +TORCH_LIBRARY_IMPL(dc, Meta, m) +{ + m.impl("allgather_param", &dc::allgather_param_meta); + m.impl("prefetch_params_fused", &dc::prefetch_params_fused_meta); + m.impl("release_param", &dc::release_param_meta); + m.impl("wait_allgather", &dc::wait_allgather_meta); + m.impl("reduce_grad", &dc::reduce_grad_meta); + m.impl("free_tensors", &dc::free_tensors_meta); + m.impl("reload_parameter", &dc::reload_parameter_meta); + m.impl("offload_parameter", &dc::offload_parameter_meta); + m.impl("end_backward", &dc::end_backward_meta); +} + +// end_backward may be invoked with dependency placeholders that have already +// become None, in which case the dispatcher sees no tensor arguments. +TORCH_LIBRARY_IMPL(dc, Undefined, m) { m.impl("end_backward", &dc::end_backward); } + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_persistent", &dc::set_persistent, "Set persistent flag for a parameter"); + m.def("enable_profiling", &dc::enable_profiling, "Enable profiling"); + m.def("is_profiling", &dc::is_profiling, "Check if profiling is enabled"); + m.def("init", &dc::init, "Set the process group"); + m.def("cleanup", &dc::cleanup, "Cleanup the process group"); + m.def("register_param", &dc::register_param, "Register a parameter"); + m.def("register_graph_z1", + &dc::register_graph_z1, + "Register graph with a list of ds parameter ids"); + m.def("register_graph_z2", + &dc::register_graph_z2, + "Register graph with a list of ds parameter ids"); + m.def("register_z3_param", &dc::register_z3_param, "Register a parameter"); + m.def("register_graph_z3", + &dc::register_graph_z3, + "Register graph with a list of ds parameter ids"); + m.def("start_forward", &dc::start_forward, "Start forward pass"); + m.def("end_forward", &dc::end_forward, "End forward pass"); + m.def("start_backward", &dc::start_backward, "Start backward pass"); + m.def("cleanup", &dc::cleanup, "Clean up DeepCompile"); + m.def("reset", &dc::reset, "Reset the state"); + m.def("invalidate_gathered_param", &dc::invalidate_gathered_param, "Invalidate gathered param"); + m.def("clear_all_gathered_params", &dc::clear_all_gathered_params, "Clear all gathered params"); +} diff --git a/csrc/compile/util.cpp b/csrc/compile/util.cpp new file mode 100644 index 000000000000..948338028059 --- /dev/null +++ b/csrc/compile/util.cpp @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#include + +namespace dc { + +std::string tensorToString(const at::Tensor& t, size_t max_elem, size_t max_str_len) +{ + auto t_cpu = t.flatten() + .slice(0, 0, std::min((int64_t)max_elem, t.numel())) + .to(c10::Device(c10::kCPU), false, true); + + size_t size = std::min(max_elem, productDim(t.sizes())); + + if (t.scalar_type() == c10::ScalarType::Half || t.scalar_type() == c10::ScalarType::BFloat16) { + auto float_ten = t_cpu.to(c10::ScalarType::Float, false, true).contiguous(); + return tensorPtrToString((float*)float_ten.data_ptr(), size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Float) { + return tensorPtrToString((float*)t_cpu.data_ptr(), size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Double) { + return tensorPtrToString((double*)t_cpu.data_ptr(), size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Int) { + int* ptr = static_cast(t_cpu.data_ptr()); + return tensorPtrToString(ptr, size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Long) { + long* ptr = static_cast(t_cpu.data_ptr()); + return tensorPtrToString(ptr, size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Byte) { + unsigned char* ptr = static_cast(t_cpu.data_ptr()); + std::vector vec; + vec.reserve(size); + for (size_t i = 0; i < size; i++) { + vec.push_back(*ptr); + ptr++; + } + return tensorPtrToString(&vec[0], size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Bool) { + bool* ptr = static_cast(t_cpu.data_ptr()); + std::vector vec; + vec.reserve(size); + for (size_t i = 0; i < size; i++) { + vec.push_back(*ptr); + ptr++; + } + return tensorPtrToString(&vec[0], size, max_str_len); + } + std::stringstream ss; + ss << "Failed to convert tensor to string. Invalid type of tensor: " + << toString(t.scalar_type()); + throw std::invalid_argument(ss.str()); +} + +std::string tensorPtrToString(void* ptr, + size_t size, + c10::ScalarType datatype, + size_t max_elem, + size_t max_str_len) +{ + int64_t elem_size = std::min((size_t)max_elem, size); + + if (datatype == c10::ScalarType::Long) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Int) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Double) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Float) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Half || datatype == c10::ScalarType::BFloat16) { + const auto ten = torch::from_blob(ptr, {(int64_t)elem_size}, datatype); + auto float_ten = ten.to(c10::ScalarType::Float, false, true).contiguous(); + return tensorPtrToString((float*)float_ten.data_ptr(), elem_size, max_str_len); + } + std::stringstream ss; + ss << "Failed to convert tensor ptr to string. Invalid type of tensor: " << toString(datatype); + throw std::invalid_argument(ss.str()); +} + +std::string tensorDimToString(const at::Tensor& t) +{ + const auto dim = t.sizes(); + return join_as_str(dim); +} +} // namespace dc diff --git a/csrc/compile/z1.cpp b/csrc/compile/z1.cpp new file mode 100644 index 000000000000..d0c804b7f1ad --- /dev/null +++ b/csrc/compile/z1.cpp @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "z1.h" +#include "deepcompile.h" + +namespace dc { + +class Z1CustomOpExecutor : public CustomOpExecutor { +public: + Z1CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + bool pre_div_reduce) + : CustomOpExecutor(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + rs_stream, + copy_stream, + pre_div_reduce) + { + } + ~Z1CustomOpExecutor() {} + + at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id) override + { + if (!hasKey(grad_tensors_, ds_id)) { + grad_tensors_[ds_id] = grad_tensor; + } else { + grad_tensors_[ds_id].add_(grad_tensor); + } + + if (param_updated_) { + CustomOpExecutor::reduceGrad(grad_tensors_[ds_id], ds_id); + grad_tensors_.erase(ds_id); + } + + return at::Tensor(); + } + + void flushReduceBucket(at::ScalarType scalar_type) override + { + if (!hasKey(reduce_tasks_, scalar_type)) { return; } + + blockCopyEvents(scalar_type); + applyPreDivision(scalar_type); + + // NCCL AllReduce operation + ncclGroupStart(); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + ncclResult_t result = ncclAllReduce(t.getSendBuf().data_ptr(), + t.getSendBuf().data_ptr(), + t.getSendBuf().numel(), + get_nccl_data_type(scalar_type), + getReductionOp(), + nccl_comm_, + rs_stream_); + if (result != ncclSuccess) { throw std::runtime_error("NCCL AllReduce failed"); } + } + ncclGroupEnd(); + + // Copy results to gradient buffers + { + at::cuda::CUDAStreamGuard guard(rs_stream_); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto param = param_registry_->getParam(t.getDSId()); + auto grad_buf = param.getGradBuffer().flatten(); + + if (grad_buf.numel() == 0) { continue; } + + int64_t offset = param.getOffset(); + auto recv_buf = t.getSendBuf().flatten().index( + {torch::indexing::Slice(offset, offset + grad_buf.numel())}); + grad_buf.copy_(recv_buf); + } + } + + performCleanup(scalar_type); + } + +protected: + std::unordered_map grad_tensors_; +}; + +namespace { + +at::cuda::CUDAStream get_rs_stream() +{ + static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); + return rs_stream; +} + +at::cuda::CUDAStream get_copy_stream() +{ + static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + return copy_stream; +} + +} // namespace + +void register_graph_z1(long graph_id, const std::vector& ds_ids) +{ + executors[graph_id] = std::make_shared(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + get_rs_stream(), + get_copy_stream(), + pre_div_reduce); +} + +void register_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + int64_t offset) +{ + param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false); +} + +} // namespace dc diff --git a/csrc/compile/z1.h b/csrc/compile/z1.h new file mode 100644 index 000000000000..1d3607a59b06 --- /dev/null +++ b/csrc/compile/z1.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#pragma once + +namespace dc { + +void register_graph_z1(long graph_id, const std::vector& ds_ids); +void register_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + int64_t offset); +} // namespace dc diff --git a/csrc/compile/z2.cpp b/csrc/compile/z2.cpp new file mode 100644 index 000000000000..09290174f146 --- /dev/null +++ b/csrc/compile/z2.cpp @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "z2.h" +#include "deepcompile.h" + +namespace dc { + +class Z2CustomOpExecutor : public CustomOpExecutor { +public: + Z2CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + bool pre_div_reduce) + : CustomOpExecutor(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + rs_stream, + copy_stream, + pre_div_reduce) + { + } + ~Z2CustomOpExecutor() {} + + void endBackward() override + { + CustomOpExecutor::endBackward(); + + if (param_updated_) { + for (auto& it : has_acc_grad_) { it.second = false; } + } + } + + void flushReduceBucket(at::ScalarType scalar_type) override + { + if (!hasKey(reduce_tasks_, scalar_type)) { return; } + + blockCopyEvents(scalar_type); + applyPreDivision(scalar_type); + + // NCCL AllReduce operation + ncclGroupStart(); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + ncclResult_t result = ncclAllReduce(t.getSendBuf().data_ptr(), + t.getSendBuf().data_ptr(), + t.getSendBuf().numel(), + get_nccl_data_type(scalar_type), + getReductionOp(), + nccl_comm_, + rs_stream_); + if (result != ncclSuccess) { throw std::runtime_error("NCCL AllReduce failed"); } + } + ncclGroupEnd(); + + // Copy or accumulate results to gradient buffers + { + at::cuda::CUDAStreamGuard guard(rs_stream_); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + bool acc_grad = has_acc_grad_.at(t.getDSId()); + auto param = param_registry_->getParam(t.getDSId()); + auto grad_buf = param.getGradBuffer().flatten(); + + if (grad_buf.numel() == 0) { continue; } + + int64_t offset = param.getOffset(); + auto recv_buf = t.getSendBuf().flatten().index( + {torch::indexing::Slice(offset, offset + grad_buf.numel())}); + if (acc_grad) { + grad_buf.add_(recv_buf); + } else { + grad_buf.copy_(recv_buf); + } + has_acc_grad_[t.getDSId()] = true; + } + } + + performCleanup(scalar_type); + } +}; + +namespace { + +at::cuda::CUDAStream get_rs_stream() +{ + static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); + return rs_stream; +} + +at::cuda::CUDAStream get_copy_stream() +{ + static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + return copy_stream; +} + +} // namespace + +void register_graph_z2(long graph_id, const std::vector& ds_ids) +{ + executors[graph_id] = std::make_shared(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + get_rs_stream(), + get_copy_stream(), + pre_div_reduce); +} + +} // namespace dc diff --git a/csrc/compile/z2.h b/csrc/compile/z2.h new file mode 100644 index 000000000000..cc6c3136c20c --- /dev/null +++ b/csrc/compile/z2.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#pragma once + +namespace dc { + +void register_graph_z2(long graph_id, const std::vector& ds_ids); + +} // namespace dc diff --git a/csrc/compile/z3.cpp b/csrc/compile/z3.cpp new file mode 100644 index 000000000000..fdc146b4ec02 --- /dev/null +++ b/csrc/compile/z3.cpp @@ -0,0 +1,651 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "z3.h" +#include "deepcompile.h" + +namespace dc { + +const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000; + +class Z3CustomOpExecutor : public CustomOpExecutor { +public: + Z3CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream ag_stream, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + at::cuda::CUDAStream offload_stream, + at::cuda::CUDAStream reload_stream, + bool pre_div_reduce) + : CustomOpExecutor(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + rs_stream, + copy_stream, + pre_div_reduce), + ag_stream_(ag_stream), + offload_stream_(offload_stream), + reload_stream_(reload_stream) + { + for (long ds_id : ds_ids_) { + ag_comm_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + ag_comp_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + + param_use_count_[ds_id] = 0; + } + } + ~Z3CustomOpExecutor() {} + + void endBackward() override + { + CustomOpExecutor::endBackward(); + + if (param_updated_) { + for (auto& it : has_acc_grad_) { + it.second = false; + param_registry_->setValid(it.first, false); + } + } + + for (auto& it : reload_buffers_) { + it.second.record_stream(at::cuda::getCurrentCUDAStream()); + } + reload_buffers_.clear(); + } + + void launchAllGather(at::Tensor output_buf, + long ds_id, + c10::intrusive_ptr symm_mem) + { + const DSParam& param = param_registry_->getParam(ds_id); + at::Tensor ds_tensor = param.getDSTensor(); + + if (ds_tensor.scalar_type() != output_buf.scalar_type()) { + at::cuda::CUDAStreamGuard guard(ag_stream_); + ds_tensor = ds_tensor.to(output_buf.scalar_type(), true, true); + } + + if (symm_mem == nullptr) { + // Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size) + const int world_size = process_group_->getSize(); + const int64_t shard_elems = ds_tensor.numel(); + + // Perform all-gather directly into the pre-allocated padded output buffer + // NCCL requires contiguous storage; use .contiguous() explicitly + ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(), + output_buf.data_ptr(), + shard_elems, + get_nccl_data_type(ds_tensor.scalar_type()), + nccl_comm_, + ag_stream_); + + if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); } + } else { + at::cuda::CUDAStreamGuard guard(ag_stream_); + int world_size = process_group_->getSize(); + int rank = process_group_->getRank(); + + at::Tensor local_buf = + symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0); + local_buf.copy_(ds_tensor, true); + + symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER); + auto chunks = output_buf.flatten().chunk(world_size); + for (int step = 0; step < world_size; step++) { + int remote_rank = (rank - step + world_size) % world_size; + auto src_buf = symm_mem->get_buffer( + remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0); + chunks[remote_rank].copy_(src_buf.flatten(), true); + } + symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER); + } + + param_registry_->registerGatheredParam(ds_id, output_buf); + param_registry_->setValid(ds_id, true); + } + + at::Tensor allgatherParam(long ds_id, + std::optional dtype, + c10::intrusive_ptr symm_mem) + { + const DSParam& param = param_registry_->getParam(ds_id); + const at::Tensor& ds_tensor = param.getDSTensor(); + const int world_size = process_group_->getSize(); + const int64_t true_numel = static_cast(productDim(param.getShape())); + const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size; + const int64_t padded_numel = static_cast(world_size) * padded_per_rank; + at::ScalarType target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type(); + + if (param_registry_->isValid(ds_id)) { + // Return a view sliced to the true size with the original shape + // + // Persistent params are gathered in their original dtype which may + // be different from the requested. + auto base = param_registry_->getGatheredParam(ds_id); + return base.flatten() + .to(target_dtype) + .index({torch::indexing::Slice(0, true_numel)}) + .view(param.getShape()); + } + + at::Tensor output_buf; + if (param_registry_->hasGatheredParam(ds_id)) { + auto existing = param_registry_->getGatheredParam(ds_id); + if (existing.defined() && existing.numel() == padded_numel) { output_buf = existing; } + } + if (!output_buf.defined()) { + at::cuda::CUDAStreamGuard guard(ag_stream_); + output_buf = torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype)); + } + + assert(hasKey(ag_comp_done_events_, ds_id)); + ag_comp_done_events_[ds_id]->record(); + ag_comp_done_events_[ds_id]->block(ag_stream_); + + launchAllGather(output_buf, ds_id, symm_mem); + + ag_comm_done_events_[ds_id]->record(ag_stream_); + // Return a view of the gathered padded buffer matching the true param shape + return output_buf.flatten() + .index({torch::indexing::Slice(0, true_numel)}) + .view(param.getShape()); + } + + void prefetchParamsFused(const std::vector& ds_ids, + const std::optional> dtypes, + c10::intrusive_ptr symm_mem) + { + std::vector>> invalid_params; + for (int i = 0; i < ds_ids.size(); i++) { + if (!param_registry_->isValid(ds_ids[i])) { + auto dtype = dtypes ? dtypes.value()[i] : std::optional(); + invalid_params.push_back(std::make_tuple(ds_ids[i], dtype)); + } + } + + std::unordered_map output_bufs; + for (const auto& [ds_id, dtype] : invalid_params) { + const DSParam& param = param_registry_->getParam(ds_id); + const at::Tensor& ds_tensor = param.getDSTensor(); + const int world_size = process_group_->getSize(); + const int64_t shard_elems = ds_tensor.numel(); + const int64_t padded_numel = static_cast(world_size) * shard_elems; + + if (param_registry_->hasGatheredParam(ds_id)) { + auto existing = param_registry_->getGatheredParam(ds_id); + if (existing.defined() && existing.numel() == padded_numel) { + output_bufs[ds_id] = existing; + continue; + } + } + auto target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type(); + output_bufs[ds_id] = + torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype)); + } + + for (const auto& [ds_id, _] : invalid_params) { + ag_comp_done_events_[ds_id]->record(); + ag_comp_done_events_[ds_id]->block(ag_stream_); + } + + ncclGroupStart(); + for (const auto& [ds_id, _] : invalid_params) { + assert(hasKey(output_bufs, ds_id)); + launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem); + } + ncclGroupEnd(); + + for (const auto& [ds_id, _] : invalid_params) { + ag_comm_done_events_[ds_id]->record(ag_stream_); + } + } + + void releaseParam(long ds_id, long n_users) + { + const DSParam& param = param_registry_->getParam(ds_id); + + assert(hasKey(param_use_count_, ds_id)); + if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; } + param_use_count_[ds_id]--; + + if (param_use_count_[ds_id] == 0 && !param.isPersistent()) { + at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id); + + if (gathered_param.defined()) { // gathered param is undefined while profiling + const auto options = gathered_param.options(); + at::Tensor empty_buffer = torch::empty({0}, options); + gathered_param.set_data(empty_buffer); + } + + param_registry_->unregisterGatheredParam(ds_id); + } + } + + at::Tensor waitAllgather(at::Tensor v, long ds_id) + { + assert(hasKey(ag_comm_done_events_, ds_id)); + ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream()); + return v; + } + + void flushReduceBucket(at::ScalarType scalar_type) override + { + if (!hasKey(reduce_tasks_, scalar_type)) { return; } + + blockCopyEvents(scalar_type); + + // Calculate temporary buffer size for accumulated gradients + int64_t tmp_recv_numel = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + if (has_acc_grad_.at(t.getDSId())) { + tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel(); + } + } + + // Allocate temporary buffer if needed + at::Tensor tmp_recv_buf = at::Tensor(); + if (tmp_recv_numel > 0) { + at::cuda::CUDAStreamGuard guard(rs_stream_); + tmp_recv_buf = torch::empty({tmp_recv_numel}, + at::TensorOptions().dtype(scalar_type).device(at::kCUDA)); + } + + applyPreDivision(scalar_type); + + // NCCL ReduceScatter operation + ncclGroupStart(); + int64_t offset = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); + bool acc_grad = has_acc_grad_.at(t.getDSId()); + + if (acc_grad) { + recv_buf = + tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())}); + } + + ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(), + recv_buf.data_ptr(), + recv_buf.numel(), + get_nccl_data_type(scalar_type), + getReductionOp(), + nccl_comm_, + rs_stream_); + if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); } + + if (acc_grad) { offset += recv_buf.numel(); } + } + ncclGroupEnd(); + + // Handle gradient accumulation with temporary buffer + { + at::cuda::CUDAStreamGuard guard(rs_stream_); + int64_t offset = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + bool acc_grad = has_acc_grad_.at(t.getDSId()); + + if (acc_grad) { + auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); + recv_buf.add_(tmp_recv_buf.index( + {torch::indexing::Slice(offset, offset + recv_buf.numel())})); + offset += recv_buf.numel(); + } + has_acc_grad_[t.getDSId()] = true; + } + } + + performCleanup(scalar_type); + + // Record stream for temporary buffer to prevent early deallocation + if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); } + } + + at::Tensor offloadTensor(at::Tensor tensor, long id) + { + if (!hasKey(offload_events_, id)) { + offload_events_[id] = std::make_shared(cudaEventDisableTiming); + offload_comp_done_events_[id] = + std::make_shared(cudaEventDisableTiming); + + const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU); + offload_buffers_[id] = at::empty_like(tensor, options); + } + + offload_comp_done_events_[id]->record(); + offload_comp_done_events_[id]->block(offload_stream_); + { + at::cuda::CUDAStreamGuard guard(offload_stream_); + offload_buffers_.at(id).copy_(tensor, true); + } + + tensor.record_stream(offload_stream_); + + offload_events_[id]->record(offload_stream_); + assert(hasKey(offload_buffers_, id)); + return offload_buffers_.at(id); + } + + at::Tensor reloadTensor(at::Tensor tensor, long id) + { + if (!hasKey(reload_events_, id)) { + reload_events_[id] = std::make_shared(cudaEventDisableTiming); + } + + assert(hasKey(offload_buffers_, id)); + offload_events_[id]->block(reload_stream_); + + at::Tensor ten; + { + at::cuda::CUDAStreamGuard guard(reload_stream_); + + assert(hasKey(offload_buffers_, id)); + at::Tensor buf = offload_buffers_.at(id); + const auto options = at::TensorOptions().device(torch::kCUDA); + ten = at::empty_like(buf, options); + ten.copy_(buf, true); + + reload_buffers_[id] = ten; + } + + reload_events_[id]->record(reload_stream_); + return ten; + } + + at::Tensor waitOffload(at::Tensor tensor, long id) + { + assert(hasKey(offload_events_, id)); + offload_events_[id]->block(at::cuda::getCurrentCUDAStream()); + + assert(hasKey(offload_buffers_, id)); + return offload_buffers_.at(id); + } + + at::Tensor waitReload(at::Tensor tensor, long id) + { + assert(hasKey(reload_events_, id)); + reload_events_[id]->block(at::cuda::getCurrentCUDAStream()); + + assert(hasKey(reload_buffers_, id)); + auto ten = reload_buffers_.at(id); + + // We can't release here because the tensor is still being used + // We will need "freeReloadedTensor" after the last user of the tensor to call + // ".record_stream". As it is a bit complicated, we clear the buffer and do at the end of + // the backward pass for now. reload_buffers_.erase(id); + return ten; + } + + void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); } + void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); } + + bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); } + + bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); } + +private: + at::cuda::CUDAStream ag_stream_; + at::cuda::CUDAStream offload_stream_; + at::cuda::CUDAStream reload_stream_; + + std::unordered_map> ag_comp_done_events_; + std::unordered_map> ag_comm_done_events_; + + std::unordered_map> offload_events_; + std::unordered_map> offload_comp_done_events_; + std::unordered_map> reload_events_; + std::unordered_map offload_buffers_; + std::unordered_map reload_buffers_; + + std::unordered_map param_use_count_; +}; + +namespace { + +at::cuda::CUDAStream get_ag_stream() +{ + static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true); + return ag_stream; +} + +at::cuda::CUDAStream get_rs_stream() +{ + static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); + return rs_stream; +} + +at::cuda::CUDAStream get_copy_stream() +{ + static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + return copy_stream; +} + +at::cuda::CUDAStream get_offload_stream() +{ + static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true); + return offload_stream; +} + +at::cuda::CUDAStream get_reload_stream() +{ + static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true); + return reload_stream; +} + +} // namespace + +void register_graph_z3(long graph_id, const std::vector& ds_ids) +{ + executors[graph_id] = std::make_shared(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + get_ag_stream(), + get_rs_stream(), + get_copy_stream(), + get_offload_stream(), + get_reload_stream(), + pre_div_reduce); +} + +void register_z3_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool persistent) +{ + param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent); + if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); } + + // Validate that padded shard sizes are uniform across ranks at registration time + // DeepSpeed pads parameters to ensure even division, so we check the padded size + // which should be uniform across all ranks for correct allgather behavior + const int64_t local_count = ds_tensor.numel(); + const int world_size = process_group->getSize(); + + // Calculate padded size (aligned to world_size) + // Use ds_shape to compute the full (unpartitioned) parameter size + int64_t total_numel = 1; + for (const auto dim : ds_shape) { total_numel *= dim; } + const int64_t padded_per_rank = (total_numel + world_size - 1) / world_size; + + // For verification: all ranks should have the same padded size + auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA); + at::Tensor local_padded_tensor = torch::tensor({padded_per_rank}, count_options); + std::vector all_padded_counts(world_size); + for (int i = 0; i < world_size; ++i) { + all_padded_counts[i] = torch::empty_like(local_padded_tensor); + } + + // Build lvalue buffers for output and input as required by ProcessGroup::allgather + // The first argument must be a single-element vector containing a vector of WORLD_SIZE tensors + std::vector> output_tensors(1); + output_tensors[0] = all_padded_counts; + std::vector input_tensors = {local_padded_tensor}; + process_group->allgather(output_tensors, input_tensors)->wait(); + + // Verify all ranks agree on the padded size + for (int i = 0; i < world_size; ++i) { + int64_t padded_count = all_padded_counts[i].to(torch::kCPU).item(); + if (padded_count != padded_per_rank) { + throw std::runtime_error( + "ZeRO-3 registration error: inconsistent padded shard sizes across ranks. " + "This is an internal error - please report this issue."); + } + } +} + +at::Tensor allgather_param(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype) +{ + auto executor = getExecutor(graph_id, executors); + + if (sync_before_allgather) { c10::cuda::device_synchronize(); } + auto ret = executor->allgatherParam(ds_id, dtype, symm_mem); + if (sync_after_allgather) { c10::cuda::device_synchronize(); } + return ret; +} + +void set_persistent(long ds_id) +{ + param_registry->setPersistent(ds_id, true); + + // Allocate buffer here + // Memory fragmentation will be more severe if we allocate in forward/backward + for (auto& it : executors) { + if (it.second->hasParam(ds_id)) { + auto executor = getExecutor(it.first, executors); + auto dtype = param_registry->getParam(ds_id).getDtype(); + executor->allgatherParam(ds_id, dtype, symm_mem); + } + } +} + +void prefetch_params_fused(long graph_id, + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes) +{ + auto executor = getExecutor(graph_id, executors); + executor->prefetchParamsFused(ds_ids, dtypes, symm_mem); +} + +void prefetch_params_fused_meta(long graph_id, + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes) +{ +} + +// for profiling +void invalidate_gathered_param(long ds_id) +{ + const DSParam& param = param_registry->getParam(ds_id); + if (param.isPersistent()) { return; } + + param_registry->unregisterGatheredParam(ds_id); + param_registry->registerGatheredParam(ds_id, at::Tensor()); +} + +void clear_all_gathered_params() +{ + for (const auto& it : param_registry->getParams()) { + long ds_id = it.first; + const DSParam& param = param_registry->getParam(ds_id); + if (param.isPersistent()) { continue; } + if (param_registry->hasGatheredParam(ds_id)) { + param_registry->unregisterGatheredParam(ds_id); + } + } +} + +at::Tensor allgather_param_meta(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype) +{ + const DSParam& param = param_registry->getParam(ds_id); + auto options = param.getDSTensor().options().device(c10::kMeta); + at::Tensor output_buf = torch::empty(param.getShape(), options.dtype(dtype)); + return output_buf; +} + +at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users) +{ + auto executor = getExecutor(graph_id, executors); + executor->releaseParam(ds_id, n_users); + return dummy; +} + +at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users) +{ + return dummy; +} + +at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->waitAllgather(v, ds_id); + return v; +} + +at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; } + +at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + return executor->offloadTensor(tensor, id); +} + +at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + return executor->reloadTensor(tensor, id); +} + +at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + return executor->waitOffload(tensor, id); +} + +at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + if (profile && !executor->hasReloadBuffer(id)) { return tensor; } + return executor->waitReload(tensor, id); +} + +at::Tensor test_call(at::Tensor a) +{ + std::cout << "test_call" << std::endl; + return a; +} + +void reload_parameter(at::Tensor tensor, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->reloadParameter(tensor, ds_id); +} + +void offload_parameter(at::Tensor tensor, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->offloadParameter(tensor, ds_id); +} +void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {} +void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {} + +} // namespace dc diff --git a/csrc/compile/z3.h b/csrc/compile/z3.h new file mode 100644 index 000000000000..bc095c86cfb6 --- /dev/null +++ b/csrc/compile/z3.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#pragma once + +namespace dc { + +void register_graph_z3(long graph_id, const std::vector& ds_ids); +void register_graph_ops_z3(long graph_id, + const std::vector& op_names, + const std::vector& n_args); +void register_bwd_graph_ops_z3(long graph_id, + const std::vector& op_names, + const std::vector& n_args); +void register_z3_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool persistent); +at::Tensor allgather_param(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype); +void set_persistent(long ds_id); +void prefetch_params_fused(long graph_id, + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes); +void prefetch_params_fused_meta(long graph_id, + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes); +// for profiling +void invalidate_gathered_param(long ds_id); +void clear_all_gathered_params(); +at::Tensor allgather_param_meta(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype); +at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users); +at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users); +at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id); +at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id); +at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id); +at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id); +at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id); +at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id); +void reload_parameter(at::Tensor tensor, long graph_id, long id); +void offload_parameter(at::Tensor tensor, long graph_id, long id); +void reload_parameter_meta(at::Tensor tensor, long graph_id, long id); +void offload_parameter_meta(at::Tensor tensor, long graph_id, long id); +void end_backward(const c10::IValue& deps, long graph_id); +void end_backward_meta(const c10::IValue& deps, long graph_id); +} // namespace dc diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp new file mode 100644 index 000000000000..d25578f410da --- /dev/null +++ b/csrc/cpu/adam/fused_adam.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cpu_adam.h" + +// C++ interface + +void multi_tensor_adam(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, /*gpmv*/ + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) +{ + static bool initialized = false; + if (!initialized) { + create_adam_optimizer(0); + initialized = true; + } + for (int i = 0; i < tensor_lists[0].size(); i++) { + ds_adam_step(0, + step, + lr, + beta1, + beta2, + epsilon, + weight_decay, + bias_correction, + tensor_lists[1][i], + tensor_lists[0][i], + tensor_lists[2][i], + tensor_lists[3][i]); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_adam", + &multi_tensor_adam, + "Compute and apply gradient update to parameters for Adam optimizer"); +} diff --git a/csrc/cpu/comm/arm64/shm.h b/csrc/cpu/comm/arm64/shm.h new file mode 100644 index 000000000000..f6bdc41c6d43 --- /dev/null +++ b/csrc/cpu/comm/arm64/shm.h @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// NOTE: +// This shared-memory implementation targets AArch64 CPUs. +// Minimum supported architecture is ARMv8-A with NEON (Advanced SIMD) support. +// Systems without NEON are not supported. + +#include +#include +#include +#include + +// 128 bits = 16 bytes -> fits 8 fp16/bf16 or 4 fp32 elements. +static int vector_length_in_bytes = 16; +// When widening fp16/bf16 -> fp32, 4 elements fit in one 128-bit register. +// Using 8 would require two 128-bit registers, so limit to 4. +static constexpr int full_precision_elements_in_fixed_vector = 4; + +static inline float32x4_t cvt_bf16_to_fp32(const uint16x4_t input) +{ + // Zero-extend 16-bit to 32-bit and shift left by 16 bits + // BF16 has the same exponent/sign bits as FP32, just missing lower mantissa bits + uint32x4_t result_32 = vshll_n_u16(input, 16); + return vreinterpretq_f32_u32(result_32); +} + +static inline float32x4_t cvt_fp16_to_fp32(float16x4_t input) +{ + // Converts 4 FP16 values to 4 FP32 values + return vcvt_f32_f16(input); +} + +// While converting fp32 to fp16, before truncating lsb, it should be rounded to nearest even and +// Converts 4 float32 -> 4 bfloat16 with round-to-nearest-even (RNE) and NaN handling +static inline uint16x4_t cvt_fp32_to_bf16(float32x4_t src) +{ + // Reinterpret float32 bits as uint32 + uint32x4_t u32 = vreinterpretq_u32_f32(src); + + const uint32x4_t ones = vdupq_n_u32(0x1); + const uint32x4_t vec_bias = + vdupq_n_u32(0x7FFF); // one less than half of the dropped bits range + const uint16x4_t nan_bf16 = vdup_n_u16(0xFFFF); + + // RNE: lsb = (input >> 16) & 1 + uint32x4_t lsb = vandq_u32(vshrq_n_u32(u32, 16), ones); + + // rounding_bias = 0x7FFF + lsb, lsb can be 0 or 1. + uint32x4_t bias = vaddq_u32(vec_bias, lsb); + + // input += rounding_bias + u32 = vaddq_u32(u32, bias); + + // >> 16 to get bfloat16 + // vshrq_n_u32 - keeps 32 bit width after shift + // vshrn_n_u32 - keeps 16 bits width after shift + uint16x4_t bf16 = vshrn_n_u32(u32, 16); + + // vmvnq_u32 is bitwise NOT + // NaN mask: ~(src == src) -> 1 if NaN + // for normal num, ~(src == src) -> 0 + uint32x4_t isnan = vmvnq_u32(vceqq_f32(src, src)); + + // Select nan_bf16 if isnan (use 16-bit mask) + uint16x4_t mask = vreinterpret_u16_u32(vget_low_u32(isnan)); + return vbsl_u16(mask, nan_bf16, bf16); +} + +// fp32 and fp16 are IEEE formats. +// converting fp32 to fp16 is handled by vcvt_f16_f32 internally without arbitrarily truncating the +// lsb but rounds to nearest. +static inline float16x4_t cvt_fp32_to_fp16(float32x4_t input) +{ + // Converts 4 FP32 values to 4 FP16 values with rounding + return vcvt_f16_f32(input); +} + +// Reduce functions down below use vectorized algorithm, the number of bytes processed each +// iteration depends on vector length. 128bit vector ==> 16 bytes. sticking to NEON 128 bit + +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers); +void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers); +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers); + +void parallel_memcpy(void* to, void* from, size_t n_bytes); + +#define VLOAD_U8(X) vld1q_u8((uint8_t*)(X)) +#define VLOAD_U16(X) vld1_u16((uint16_t*)(X)) +#define VLOAD_F16(X) vld1_f16((float16_t*)(X)) +#define VLOAD_F32(X) vld1q_f32((float32_t*)(X)) + +#define VSTORE_U8(A, B) vst1q_u8((uint8_t*)(A), B) +#define VSTORE_U16(A, B) vst1_u16((uint16_t*)(A), B) +#define VSTORE_F16(A, B) vst1_f16((float16_t*)(A), B) // fp16 supported from armv8.2-a+fp16 +#define VSTORE_F32(A, B) vst1q_f32((float32_t*)(A), B) + +#define VADD_F32(A, B) vaddq_f32(A, B) +#define VADD_F32_2VL(A, B) vaddq_f32(A, B) + +#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X) +#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X) +#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X) +#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X) diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp new file mode 100644 index 000000000000..d28509e59266 --- /dev/null +++ b/csrc/cpu/comm/ccl.cpp @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include +#include "shm.h" + +// #define DO_PROFILE +#ifdef DO_PROFILE +#include +#include +#endif + +// Communication settings +static int world_rank = -1; +static int world_size = -1; + +static std::set _comm_ids; +static std::set _colors; +static std::vector _ccl_comms; +static ccl::shared_ptr_class sub_kvs; +static std::map, int> group_to_comm_id; + +ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; } +ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; } +ccl::communicator& _get_comm_from_group(std::vector ranks) +{ + if (group_to_comm_id.find(ranks) != group_to_comm_id.end()) { + auto id = group_to_comm_id.find(ranks); + return _ccl_comms[id->second]; + } + return _ccl_comms[0]; +} + +#define CCLCHECK(cmd) \ + do { \ + cmd; \ + } while (0) + +#define KVS_CREATE_SUCCESS 0 +#define KVS_CREATE_FAILURE -1 + +static bool is_initialized = 0; + +static ccl::shared_ptr_class kvs; + +static bool all_ranks_local_p = false; + +void initialize(int size, int rank, torch::Tensor& kvs_data) +{ + if (is_initialized) return; + + // Check whether all ranks is on the same physical machine. + // If true, we will use an SHM based low latency allreduce + + auto ls_string = std::getenv("LOCAL_SIZE"); + int ls = 0; + if (ls_string != NULL) { ls = std::stoi(std::getenv("LOCAL_SIZE")); } + + if (size >= 1 && size == ls) { all_ranks_local_p = true; } + + world_size = size; + world_rank = rank; + is_initialized = 1; + + ccl::kvs::address_type main_addr; + + if (rank != 0) { + memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size()); + kvs = ccl::create_kvs(main_addr); + } + + _ccl_comms.emplace_back(ccl::create_communicator(size, rank, kvs)); + + auto addr_string = std::getenv("MASTER_ADDR"); + if (addr_string == NULL) { addr_string = ""; } + auto port_string = std::getenv("MASTER_PORT"); + if (port_string == NULL) { port_string = ""; } + + if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); } +} + +/* + rank == 0: create main kvs and return its address + rank == else: return an empty address +*/ +std::vector get_kvs_addr(int rank) +{ + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = std::vector(main_addr.begin(), main_addr.end()); + return ccl_kvs_addr; + } else { + ccl::kvs::address_type main_addr; + auto ccl_kvs_addr = std::vector(main_addr.begin(), main_addr.end()); + return ccl_kvs_addr; + } +} + +int get_rank(int group = 0) { return world_rank; } + +int get_world_size(int group = 0) { return world_size; } + +// Find the next ordered, unique value to a set. E.g. <0,1,2,7> --> 3 +int next_unique_val(std::set s) +{ + std::set::iterator itr; + // Base case. Add 0 to start of set. + if (s.empty() || *s.begin() != 0) { + return 0; + // second base case where s = {0} (the case of s = {n != 0} is caught above) + } else if (s.size() == 1) { + return 1; + } else { + int prev_val = *s.begin(); + for (itr = std::next(s.begin()); itr != s.end(); itr++) { + if (*itr != prev_val + 1) { return prev_val + 1; } + prev_val = *itr; + } + return *(s.end()) + 1; + } +} + +std::vector get_sub_kvs_addr(bool first) +{ + if (first) { + sub_kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = sub_kvs->get_address(); + auto ccl_kvs_addr = std::vector(main_addr.begin(), main_addr.end()); + return ccl_kvs_addr; + } else { + ccl::kvs::address_type main_addr; + auto ccl_kvs_addr = std::vector(main_addr.begin(), main_addr.end()); + return ccl_kvs_addr; + } +} + +void initialize_sub_comm(int size, int rank, torch::Tensor& kvs_data, std::vector ranks) +{ + ccl::kvs::address_type main_addr; + if (rank != 0) { + memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size()); + sub_kvs = ccl::create_kvs(main_addr); + } + _ccl_comms.push_back(ccl::create_communicator(size, rank, sub_kvs)); + group_to_comm_id[ranks] = _ccl_comms.size() - 1; +} + +ccl::datatype get_ccl_datatype(c10::ScalarType type) +{ + ccl::datatype ccl_type; + switch (type) { + case c10::ScalarType::Int: ccl_type = ccl::datatype::int32; break; + case c10::ScalarType::Long: ccl_type = ccl::datatype::int64; break; + case c10::ScalarType::Float: ccl_type = ccl::datatype::float32; break; + case c10::ScalarType::Double: ccl_type = ccl::datatype::float64; break; + case c10::ScalarType::BFloat16: ccl_type = ccl::datatype::bfloat16; break; + case c10::ScalarType::Half: ccl_type = ccl::datatype::float16; break; + default: ccl_type = ccl::datatype::int8; + } + return ccl_type; +} + +ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input) +{ + py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp"); + if (!py::isinstance(op, ReduceOp)) { + throw std::runtime_error("Error: Op must be of type ReduceOp"); + } + + int op_val = py::int_(op.attr("value")); + ccl::reduction ccl_op; + + if (input.scalar_type() == at::kBool) { + if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) { + // For bool tensors, map sum to max, which both represent a bitwise or. + // This is to prevent overflow issues with sum, since we use uint8 to + // represent a bool (see cclDataType mapping). + ccl_op = ccl::reduction::max; + } else if (op_val == (int)py::int_(ReduceOp.attr("AVG").attr("value"))) { + throw std::runtime_error("Error: For bool tensors, op must be of type ReduceOp"); + } + } + + if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) { + ccl_op = ccl::reduction::sum; + } else if (op_val == (int)py::int_(ReduceOp.attr("MIN").attr("value"))) { + ccl_op = ccl::reduction::min; + } else if (op_val == (int)py::int_(ReduceOp.attr("MAX").attr("value"))) { + ccl_op = ccl::reduction::max; + } else if (op_val == (int)py::int_(ReduceOp.attr("PRODUCT").attr("value"))) { + ccl_op = ccl::reduction::prod; + } else { + throw std::runtime_error("Error: Unrecognized ReduceOp type"); + } + return ccl_op; +} + +void broadcast(torch::Tensor& data, int src, std::vector group, bool async_op) +{ + CCLCHECK(ccl::broadcast(data.data_ptr(), + data.numel(), + get_ccl_datatype(data.scalar_type()), + src, + _get_comm_from_group(group)) + .wait()); +} + +// TODO: implement torch's async_op behavior, document it. +void all_reduce(torch::Tensor& data, py::object op, std::vector group, bool async_op) +{ + CCLCHECK(ccl::allreduce(data.data_ptr(), + data.data_ptr(), + data.numel(), + get_ccl_datatype(data.scalar_type()), + get_ccl_reduce_op(op, data), + _get_comm_from_group(group)) + .wait()); +} + +void all_reduce_caching(torch::Tensor& data, + py::object op, + std::string match_id, + std::vector group, + bool async_op) +{ + ccl::allreduce_attr attr = ccl::default_allreduce_attr; + auto match_str = ccl::v1::string(match_id); + attr.template set(true); + attr.template set(match_str); + // To control this, use operation attribute and set true value for to_cache field and unique + // string (for example, tensor name) for match_id field. Note that: + // match_id should be the same for a specific communication operation across all ranks. + // If the same tensor is a part of different communication operations, match_id should have + // different values for each of these operations. + CCLCHECK(ccl::allreduce(data.data_ptr(), + data.data_ptr(), + data.numel(), + get_ccl_datatype(data.scalar_type()), + get_ccl_reduce_op(op, data), + _get_comm_from_group(group), + attr) + .wait()); +} + +void inference_all_reduce(torch::Tensor& data, py::object op) +{ +#ifdef DO_PROFILE + static double total_time = 0.0; + static double total_time_sq = 0.0; + static int count = -16; // warmup + static double max_time = 0.0; + static double min_time = DBL_MAX; + // make sure all rank reach this point before measuring time + // turn on this if you suspect each rank didn't reach here at the same time (stragger) + // if (all_ranks_local_p) { + // barrier_wait(0, world_size); + //} + auto start = std::chrono::system_clock::now(); +#endif + + static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp"); + static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); + + assert(py::int_(op.attr("value")) == ReduceOpSum); + + auto numel = data.numel(); + + int data_size = 0; + bool data_type_fallback = false; + + switch (data.scalar_type()) { + case c10::ScalarType::BFloat16: data_size = numel * 2; break; + case c10::ScalarType::Float: data_size = numel * 4; break; + default: data_type_fallback = true; + } + + if (data_type_fallback || !all_ranks_local_p) { + // fallback to oneccl allreduce + CCLCHECK(ccl::allreduce(data.data_ptr(), + data.data_ptr(), + data.numel(), + get_ccl_datatype(data.scalar_type()), + get_ccl_reduce_op(op, data), + _get_comm_from_group()) + .wait()); + } else { + all_reduce_outer_loop(data, numel, data_size); + } + +#ifdef DO_PROFILE + auto end = std::chrono::system_clock::now(); + count++; + if (count > 0) { + double elapsed = std::chrono::duration_cast(end - start).count(); + if (elapsed > max_time) { max_time = elapsed; } + if (elapsed < min_time) { min_time = elapsed; } + total_time += elapsed; + total_time_sq += elapsed * elapsed; + if (world_rank == 0 && count == 1000) { + auto avg = total_time / count; + auto sd = + sqrt(total_time_sq / count - total_time * total_time / (count * count)) / avg * 100; + printf(" C++ kernel\t\t %.2f\t %.2f\t%.2f\t %.2f\n", + min_time, + max_time, + total_time / count, + sd); + } + } +#endif +} + +void barrier(std::vector group, bool async_op) +{ + CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait()); +} + +std::vector get_available_coll() +{ + std::vector colls{ + "broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"}; + return colls; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr"); + m.def("initialize", &initialize, "ccl initialize"); + m.def("get_rank", &get_rank, "get rank"); + m.def("get_world_size", &get_world_size, "get world size"); + m.def("broadcast", &broadcast, "ccl broadcast"); + m.def("all_reduce", &all_reduce, "ccl all_reduce"); + m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation"); + m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching"); + m.def("barrier", &barrier, "barrier"); + m.def("initialize_sub_comm", &initialize_sub_comm, "initialize_sub_comm"); + m.def("get_sub_kvs_addr", &get_sub_kvs_addr, "get_sub_kvs_addr"); + m.def("get_available_coll", &get_available_coll, "get_available_coll"); +} diff --git a/csrc/cpu/comm/riscv64/shm.h b/csrc/cpu/comm/riscv64/shm.h new file mode 100644 index 000000000000..475cbd71846b --- /dev/null +++ b/csrc/cpu/comm/riscv64/shm.h @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +using float16_t = _Float16; + +inline vfloat32m2_t cvt_bf16_to_fp32(vuint16m1_t src, size_t vl) __attribute__((target("arch=+v"))); +inline vfloat32m2_t cvt_bf16_to_fp32(vuint16m1_t src, size_t vl) +{ + vuint32m2_t widened = __riscv_vwcvtu_x_x_v_u32m2(src, vl); + return __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsll_vx_u32m2(widened, 16, vl)); +} + +inline vuint16m1_t cvt_fp32_to_bf16(vfloat32m2_t src, size_t vl) __attribute__((target("arch=+v"))); +inline vuint16m1_t cvt_fp32_to_bf16(vfloat32m2_t src, size_t vl) +{ + vuint32m2_t value = __riscv_vreinterpret_v_f32m2_u32m2(src); + vuint32m2_t nan = __riscv_vmv_v_x_u32m2(0xFFFF, vl); + vbool16_t mask_value = __riscv_vmfne_vv_f32m2_b16(src, src, vl); + vuint32m2_t ones = __riscv_vmv_v_x_u32m2(0x1, vl); + vuint32m2_t vec_bias = __riscv_vmv_v_x_u32m2(0x7FFF, vl); + // uint32_t lsb = (input >> 16) & 1; + vuint32m2_t t_value = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(value, 16, vl), 0x1, vl); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = __riscv_vadd_vv_u32m2(t_value, vec_bias, vl); + // input += rounding_bias; + t_value = __riscv_vadd_vv_u32m2(t_value, value, vl); + // input = input >> 16; + t_value = __riscv_vsrl_vx_u32m2(t_value, 16, vl); + // Check NaN before converting back to bf16 + t_value = __riscv_vmerge_vvm_u32m2(t_value, nan, mask_value, vl); + + return __riscv_vncvt_x_x_w_u16m1(t_value, vl); +} + +inline vfloat32m2_t cvt_fp16_to_fp32(vfloat16m1_t src, size_t vl) + __attribute__((target("arch=+v,+zvfh"))); +inline vfloat32m2_t cvt_fp16_to_fp32(vfloat16m1_t src, size_t vl) +{ + return __riscv_vfwcvt_f_f_v_f32m2(src, vl); +} + +inline vfloat16m1_t cvt_fp32_to_fp16(vfloat32m2_t src, size_t vl) + __attribute__((target("arch=+v,+zvfh"))); +inline vfloat16m1_t cvt_fp32_to_fp16(vfloat32m2_t src, size_t vl) +{ + return __riscv_vfncvt_rod_f_f_w_f16m1(src, vl); +} + +// Reduce functions down below use vectorized algorithm, the number of bytes processed each +// iteration depends on vector length. Dynamically acquired via the vsetvl instruction to +// compatible with different vector length. +static int vector_length_in_bytes = -1; + +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("arch=+v"))); +void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("arch=+v,+zvfh"))); +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("arch=+v"))); + +void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("arch=+v"))); + +#define VLOAD_U8(X) __riscv_vle8_v_u8m1((uint8_t*)(X), vl) +#define VLOAD_U16(X) __riscv_vle16_v_u16m1((uint16_t*)(X), vl) +#define VLOAD_F16(X) __riscv_vle16_v_f16m1((float16_t*)(X), vl) +#define VLOAD_F32(X) __riscv_vle32_v_f32m1((float*)(X), vl) + +#define VSTORE_U8(A, B) __riscv_vse8_v_u8m1((uint8_t*)(A), B, vl) +#define VSTORE_U16(A, B) __riscv_vse16_v_u16m1((uint16_t*)(A), B, vl) +#define VSTORE_F16(A, B) __riscv_vse16_v_f16m1((float16_t*)(A), B, vl) +#define VSTORE_F32(A, B) __riscv_vse32_v_f32m1((float*)(A), B, vl) + +#define VADD_F32(A, B) __riscv_vfadd_vv_f32m1(A, B, vl) +#define VADD_F32_2VL(A, B) __riscv_vfadd_vv_f32m2(A, B, vl) + +#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X, vl) +#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X, vl) +#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X, vl) +#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X, vl) diff --git a/csrc/cpu/comm/shm.cpp b/csrc/cpu/comm/shm.cpp new file mode 100644 index 000000000000..40073e6863f2 --- /dev/null +++ b/csrc/cpu/comm/shm.cpp @@ -0,0 +1,675 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include +#include +#include +#include +#include "shm.h" + +#if defined(__riscv) +#define TARGET_RISCV 1 +#include "riscv64/shm.h" +#elif defined(__aarch64__) +#define TARGET_ARM 1 +#include "arm64/shm.h" +#else +#include "x86_64/shm.h" +#endif + +// #define DO_PROFILE +#ifdef DO_PROFILE +#include +#include +#endif + +// states for collectives +enum coll_state { + coll_begin = 0, + coll_allreduce_naive__copy_in_done, + coll_allreduce_naive__reduce_done, + // alternative state when allreduce is working on alternative buffer + // of the double buffer. + coll_alt1_allreduce_naive__copy_in_done, + coll_alt2_allreduce_naive__copy_in_done, + coll_alt1_allreduce_naive__reduce_done, +}; + +// SHM building blocks +struct SharedData { + const char* name; + int descriptor; + void* bytes; + size_t nbytes; +}; + +void shared_open(SharedData* data, const char* name, size_t nbytes) +{ + int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0); + data->name = name; + data->descriptor = d; + data->bytes = bytes; + data->nbytes = nbytes; + } else { + if (errno != ENOENT) { + // don't print if shm can not be found because we want to loop over from + // caller again until the other ranks created the shm + printf("shared_open %s failed, errno=%d\n", name, errno); + } + data->descriptor = -1; + } +} + +void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) +{ + int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + if (nbytes = write(d, bytes, nbytes)) { shared_open(data, name, nbytes); } + } else { + printf("shared_create %s failed\n", name); + } +} + +void shared_close(SharedData* data) +{ + if (data->descriptor != -1) { + munmap(data->bytes, data->nbytes); + shm_unlink(data->name); + } +} + +static int world_size; + +// SHM based allreduce helper functions +// buffer that holds shm name +#define NAME_BUF_SIZE 1000 +#define MAX_BUF_SIZE 1048576 * 32 +#define NAIVE_ALLREDUCE_THRESHOLD 1048576 +#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" +struct allreduce_workspace { + enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce + // idx=1 -- state for distributed_naive_all_reduce + // double buffer to avoid syncing between rounds + // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for symmetric_naive_all_reduce + // after that : buffer for distributed_naive_all_reduce + char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE]; +}; + +#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD +#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE + +struct allreduce_workspace** workspace; + +// buffer for small messages, double buffer +char** symmetric_buffer[2]; +// buffer for large messages, double buffer +char** distributed_buffer[2]; + +void wait_buffer_state_until_2(int index, + enum coll_state state0, + enum coll_state state1, + int state_group) +{ + volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]); + + while (1) { + volatile enum coll_state cur_state = *state_ptr; + if (cur_state == state0 || cur_state == state1) break; + } +} + +void reduce_all_buffers(int start_elements, + int num_elements, + c10::ScalarType scalar_type, + int to_buffer_idx, + char* to_buffer, + char** buffers) +{ + switch (scalar_type) { + case c10::ScalarType::BFloat16: + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case c10::ScalarType::Half: + reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case c10::ScalarType::Float: + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); + break; + default: assert(!"Should not get here"); + } +} + +#define CVT_ADD_BF16(x) \ + do { \ + auto in##x##_val = CVT_BF16_TO_FP32(VLOAD_U16(buffers[x] + i)); \ + inout_val = VADD_F32_2VL(inout_val, in##x##_val); \ + } while (0) + +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) +{ + const int element_size = 2; +#if TARGET_RISCV + size_t vl = __riscv_vsetvl_e16m1(num_elements); + vector_length_in_bytes = vl * element_size; +#elif TARGET_ARM + const int vl = full_precision_elements_in_fixed_vector; + vector_length_in_bytes = vl * element_size; +#else // x86_64 + const int vl = vector_length_in_bytes / element_size; +#endif + int main_elements = num_elements - (num_elements % vl); + int remain_elements = num_elements % vl; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += vector_length_in_bytes) { + auto inout_val = CVT_BF16_TO_FP32(VLOAD_U16(buffers[0] + i)); + switch (world_size) { + case 16: CVT_ADD_BF16(15); + case 15: CVT_ADD_BF16(14); + case 14: CVT_ADD_BF16(13); + case 13: CVT_ADD_BF16(12); + case 12: CVT_ADD_BF16(11); + case 11: CVT_ADD_BF16(10); + case 10: CVT_ADD_BF16(9); + case 9: CVT_ADD_BF16(8); + case 8: CVT_ADD_BF16(7); + case 7: CVT_ADD_BF16(6); + case 6: CVT_ADD_BF16(5); + case 5: CVT_ADD_BF16(4); + case 4: CVT_ADD_BF16(3); + case 3: CVT_ADD_BF16(2); + case 2: CVT_ADD_BF16(1); + case 1: break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = CVT_BF16_TO_FP32(VLOAD_U16(buffers[j] + i)); + inout_val = VADD_F32_2VL(inout_val, in_val); + } + } + VSTORE_U16(to_buffer + i, CVT_FP32_TO_BF16(inout_val)); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { val += *(at::BFloat16*)(buffers[j] + i); } + *(at::BFloat16*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_FP16(x) \ + do { \ + auto in##x##_val = CVT_FP16_TO_FP32(VLOAD_F16(buffers[x] + i)); \ + inout_val = VADD_F32_2VL(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) +{ + const int element_size = 2; +#if TARGET_RISCV + size_t vl = __riscv_vsetvl_e16m1(num_elements); + vector_length_in_bytes = vl * element_size; +#elif TARGET_ARM + const int vl = full_precision_elements_in_fixed_vector; + vector_length_in_bytes = vl * element_size; +#else // x86_64 + const int vl = vector_length_in_bytes / element_size; +#endif + int main_elements = num_elements - (num_elements % vl); + int remain_elements = num_elements % vl; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += vector_length_in_bytes) { + auto inout_val = CVT_FP16_TO_FP32(VLOAD_F16(buffers[0] + i)); + switch (world_size) { + case 16: CVT_ADD_FP16(15); + case 15: CVT_ADD_FP16(14); + case 14: CVT_ADD_FP16(13); + case 13: CVT_ADD_FP16(12); + case 12: CVT_ADD_FP16(11); + case 11: CVT_ADD_FP16(10); + case 10: CVT_ADD_FP16(9); + case 9: CVT_ADD_FP16(8); + case 8: CVT_ADD_FP16(7); + case 7: CVT_ADD_FP16(6); + case 6: CVT_ADD_FP16(5); + case 5: CVT_ADD_FP16(4); + case 4: CVT_ADD_FP16(3); + case 3: CVT_ADD_FP16(2); + case 2: CVT_ADD_FP16(1); + case 1: break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = CVT_FP16_TO_FP32(VLOAD_F16(buffers[j] + i)); + inout_val = VADD_F32_2VL(inout_val, in_val); + } + } + VSTORE_F16(to_buffer + i, CVT_FP32_TO_FP16(inout_val)); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { val += *(at::Half*)(buffers[j] + i); } + *(at::Half*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_F32(x) \ + do { \ + auto in##x##_val = VLOAD_F32(buffers[x] + i); \ + inout_val = VADD_F32(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) +{ + const int element_size = 4; +#if TARGET_RISCV + size_t vl = __riscv_vsetvl_e32m1(num_elements); + vector_length_in_bytes = vl * element_size; +#elif TARGET_ARM + const int vl = full_precision_elements_in_fixed_vector; + vector_length_in_bytes = vl * element_size; +#else // x86_64 + const int vl = vector_length_in_bytes / element_size; +#endif + int main_elements = num_elements - (num_elements % vl); + int remain_elements = num_elements % vl; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += vector_length_in_bytes) { + auto inout_val = VLOAD_F32(buffers[0] + i); + switch (world_size) { + case 16: CVT_ADD_F32(15); + case 15: CVT_ADD_F32(14); + case 14: CVT_ADD_F32(13); + case 13: CVT_ADD_F32(12); + case 12: CVT_ADD_F32(11); + case 11: CVT_ADD_F32(10); + case 10: CVT_ADD_F32(9); + case 9: CVT_ADD_F32(8); + case 8: CVT_ADD_F32(7); + case 7: CVT_ADD_F32(6); + case 6: CVT_ADD_F32(5); + case 5: CVT_ADD_F32(4); + case 4: CVT_ADD_F32(3); + case 3: CVT_ADD_F32(2); + case 2: CVT_ADD_F32(1); + case 1: break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = VLOAD_F32(buffers[j] + i); + inout_val = VADD_F32(inout_val, in_val); + } + } + VSTORE_F32(to_buffer + i, inout_val); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { val += *(float*)(buffers[j] + i); } + *(float*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +static bool is_initialized = 0; +static int world_rank; + +void shm_initialize(int size, int rank, char* addr_string, char* port_string) +{ + if (is_initialized) return; + is_initialized = 1; + + world_size = size; + world_rank = rank; + + char shm_name_prefix[NAME_BUF_SIZE]; + char shm_name[NAME_BUF_SIZE]; + snprintf(shm_name_prefix, + NAME_BUF_SIZE, + "%s_%d_%s_%s", + SHM_BUFFER_NAME, + getuid(), + addr_string, + port_string); + // create shared workspace for SHM based allreduce + SharedData allreduce_buffer; + // allocate workspace_buf for current rank + struct allreduce_workspace* workspace_buf; + struct allreduce_workspace* workspace_buf_other; + workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); + workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; + workspace_buf->states[1] = coll_begin; + + // create the workspace pointer list + workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*)); + symmetric_buffer[0] = (char**)malloc(size * sizeof(char**)); + symmetric_buffer[1] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[0] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[1] = (char**)malloc(size * sizeof(char**)); + + // map shm of all ranks + for (int i = 0; i < size; i++) { + if (i != rank) { + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + // printf("open %s, %d\n", shm_name, rank); + do { + shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); + } while (allreduce_buffer.descriptor == -1 && errno == ENOENT); + workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace[i] = workspace_buf_other; + } else { + workspace[i] = workspace_buf; + } + symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0); + symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1); + distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0); + distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1); + } +} + +void parallel_memcpy(void* to, void* from, size_t n_bytes) +{ +#if TARGET_RISCV + size_t vl = __riscv_vsetvl_e8m1(n_bytes); + vector_length_in_bytes = vl; +#endif + auto aligned_bytes = n_bytes - (n_bytes % vector_length_in_bytes); + // process aligned part +#pragma omp parallel for + for (int i = 0; i < aligned_bytes; i += vector_length_in_bytes) { + auto val = VLOAD_U8((char*)from + i); + VSTORE_U8((char*)to + i, val); + } + + // process remaining part + for (int i = aligned_bytes; i < n_bytes; i++) { *((char*)to + i) = *((char*)from + i); } +} + +#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod)) +#define rank_mod(rank) positive_mod(rank, world_size) +size_t slice_size(size_t chunk_el, int slice_idx) +{ + size_t slice_size = chunk_el / world_size; + return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size; +} + +char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) +{ + size_t slice_size = chunk_el / world_size; + size_t el_offset = slice_size * slice_idx; + return data_ptr + el_offset * el_size; +} + +size_t slice_el_start(size_t chunk_el, int slice_idx) +{ + size_t slice_size = chunk_el / world_size; + return slice_size * slice_idx; +} + +/* + Symmetrical naive all_reduce + step 0: before enter the function ith times, state is copy(i-1) + step 1: each rank copy data from input (data_ptr) to SHM buffer[i] + step 2: set own state to copy(i) + step 3: wait each other rank's state equal or later than copy(i) + step 4: reduce across SHM buffer(ith) directly into output (data_ptr) +*/ +void symmetric_naive_all_reduce(char* data_ptr, + c10::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) +{ +#ifdef DO_PROFILE + static double total_t1_t0 = 0.0; + static double total_t2_t1 = 0.0; + static double total_t3_t2 = 0.0; + static int count = -16; // warmup + auto t0 = std::chrono::system_clock::now(); +#endif + + /* + We can't have infinite number of buffers and states. 2 sets of buffer + and 3 sets of states is just enough. Consider current rank is in step 3, + with it's own state set to copy(i), the other rank will them have the + following situations: + ------------------------------------------------ + my state | can I proceed? | the other rank state + ================================================ + | N | copy(i-1) + |----------------|--------------------- + copy(i) | Y | copy(i) + |----------------|--------------------- + | Y | copy(i+1) + ------------------------------------------------ + * When I have state as copy(i), the other rank cannot have state + copy(i-2) or before. In that case I'll be in state copy(i-1) and cannot + proceed to copy(i). + * The other rank cannot have state copy(i+2) or beyond because my + state is still copy(i), copy(i+1) is as far as the other rank could go. + * From a rank's POV, all the other ranks can be divided into three sets: + - Lagging ranks: ranks that are still working on previous iteration + - Syncing ranks: ranks that are working on current iteration + - Leading ranks: ranks that are working on next iteration + * We can have 3 sets of states, one set for syncing ranks; one set for + lagging ranks; one set of leading ranks. With 3 sets of states, we can + distinguish between lagging and leading ranks. + * Note from any rank's POV, leading ranks and lagging ranks does not + appear at the same time. Either all other ranks are syncing or + lagging, or all other ranks are syncing or leading. Otherwise leading + and lagging ranks will be 2 iterations apart and this should not happen. + * So we have 2 sets of buffers, one buffer is used by current iter; + one buffer used by either lagging ranks or leading ranks. + */ + const int state_group = 0; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next; + + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + copy_next = coll_alt2_allreduce_naive__copy_in_done; + break; + case 2: + copy_current = coll_alt2_allreduce_naive__copy_in_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 3; + + parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + +#ifdef DO_PROFILE + auto t1 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + // wait until the other rank copy the buffer + if (i != world_rank) { wait_buffer_state_until_2(i, copy_current, copy_next, state_group); } + } +#ifdef DO_PROFILE + auto t2 = std::chrono::system_clock::now(); +#endif + + // each rank reduce the buffer independently so therre is no need for synchronization afterward + reduce_all_buffers( + 0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]); + + // switch buffer + current_buffer = 1 - current_buffer; + +#ifdef DO_PROFILE + auto t3 = std::chrono::system_clock::now(); + + count++; + if (count > 0) { + total_t1_t0 += std::chrono::duration_cast(t1 - t0).count(); + total_t2_t1 += std::chrono::duration_cast(t2 - t1).count(); + total_t3_t2 += std::chrono::duration_cast(t3 - t2).count(); + if (world_rank == 0 && count == 1000) { + printf("symmetric_naive_all_reduce time breakdown:\n"); + printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count); + printf("\twait for copy: %.2f\n", total_t2_t1 / count); + printf("\treduce: %.2f\n", total_t3_t2 / count); + } + } +#endif +} + +// naive allreduce distributed, each rank do naive reduce on its slice +void distributed_naive_reduce(char* data_ptr, + c10::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) +{ +#ifdef DO_PROFILE + static double total_t1_t0 = 0.0; + static double total_t2_t1 = 0.0; + static double total_t3_t2 = 0.0; + static double total_t4_t3 = 0.0; + static double total_t5_t4 = 0.0; + static int count = -16; // warmup + auto t0 = std::chrono::system_clock::now(); +#endif + + const int state_group = 1; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next, reduce_current; + + // similar to symmetric_naive_allreduce, but here we only need two sets of + // states, because distributed naive reduce has two barriers in the algorithm + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + reduce_current = coll_allreduce_naive__reduce_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + reduce_current = coll_alt1_allreduce_naive__reduce_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 2; + + int data_size = chunk_size / chunk_el; + parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + +#ifdef DO_PROFILE + auto t1 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks copy the buffer + if (i != world_rank) + wait_buffer_state_until_2(i, copy_current, reduce_current, state_group); + } + +#ifdef DO_PROFILE + auto t2 = std::chrono::system_clock::now(); +#endif + + // reduce scatter + reduce_all_buffers(slice_el_start(chunk_el, world_rank), + slice_size(chunk_el, world_rank), + scalar_type, + world_rank, + distributed_buffer[current_buffer][world_rank], + distributed_buffer[current_buffer]); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = reduce_current; + +#ifdef DO_PROFILE + auto t3 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks reduce the buffer + if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); + } + + auto t4 = std::chrono::system_clock::now(); + + for (int i = 0; i < world_size; i++) { + int rank = (i + world_rank) % world_size; + parallel_memcpy( + slice_data(data_ptr, chunk_el, data_size, rank), + slice_data( + distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank), + slice_size(chunk_el, rank) * data_size); + } + + current_buffer = 1 - current_buffer; + +#ifdef DO_PROFILE + auto t5 = std::chrono::system_clock::now(); + count++; + if (count > 0) { + total_t1_t0 += std::chrono::duration_cast(t1 - t0).count(); + total_t2_t1 += std::chrono::duration_cast(t2 - t1).count(); + total_t3_t2 += std::chrono::duration_cast(t3 - t2).count(); + total_t4_t3 += std::chrono::duration_cast(t4 - t3).count(); + total_t5_t4 += std::chrono::duration_cast(t5 - t4).count(); + if (world_rank == 0 && count == 1000) { + printf("distributed_naive_reduce time breakdown:\n"); + printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count); + printf("\twait for copy: %.2f\n", total_t2_t1 / count); + printf("\treduce: %.2f\n", total_t3_t2 / count); + printf("\twait for reduce finish: %.2f\n", total_t4_t3 / count); + printf("\tcopy out: %.2f\n", total_t5_t4 / count); + } + } +#endif +} + +void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) +{ + for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + auto data_ptr = ((char*)(data.data_ptr()) + offset); + size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; + size_t chunk_el = chunk_size / (data_size / numel); + if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) + symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + else + distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + } +} diff --git a/csrc/cpu/comm/shm.h b/csrc/cpu/comm/shm.h new file mode 100644 index 000000000000..4aed6ecda0de --- /dev/null +++ b/csrc/cpu/comm/shm.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#ifndef __SHM_COLLECTIVES__ +#define __SHM_COLLECTIVES__ +void shm_initialize(int size, int rank, char* addr_string, char* port_string); +void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size); +void barrier_wait(int root_idx, int num_ranks); +#endif diff --git a/csrc/cpu/comm/shm_interface.cpp b/csrc/cpu/comm/shm_interface.cpp new file mode 100644 index 000000000000..5be5cb799a7b --- /dev/null +++ b/csrc/cpu/comm/shm_interface.cpp @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "shm.h" + +// #define DO_PROFILE +#ifdef DO_PROFILE +#include +#include +#endif + +// Communication settings +static int world_rank = -1; +static int world_size = -1; + +static bool is_initialized = 0; + +static bool all_ranks_local_p = false; + +void initialize(int size, int rank) +{ + if (is_initialized) return; + + // Check whether all ranks is on the same physical machine. + // If true, we will use an SHM based low latency allreduce + + auto ls_string = std::getenv("LOCAL_SIZE"); + int ls = 0; + if (ls_string != NULL) { ls = std::stoi(std::getenv("LOCAL_SIZE")); } + + if (size >= 1 && size == ls) { all_ranks_local_p = true; } + + world_size = size; + world_rank = rank; + is_initialized = 1; + + auto addr_string = std::getenv("MASTER_ADDR"); + if (addr_string == NULL) { addr_string = ""; } + auto port_string = std::getenv("MASTER_PORT"); + if (port_string == NULL) { port_string = ""; } + + if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); } +} + +void inference_all_reduce_(torch::Tensor& data, int op); + +// Success - return 0 +// Fail (cannot hornor the request and need to fall back) - return -1 +void inference_all_reduce_(torch::Tensor& data, int op) +{ + assert(op == 0); +#ifdef DO_PROFILE + static double total_time = 0.0; + static double total_time_sq = 0.0; + static int count = -16; // warmup + static double max_time = 0.0; + static double min_time = DBL_MAX; + // make sure all rank reach this point before measuring time + // turn on this if you suspect each rank didn't reach here at the same time (stragger) + // if (all_ranks_local_p) { barrier_wait(0, world_size); } + auto start = std::chrono::system_clock::now(); +#endif + + auto numel = data.numel(); + + int data_size = 0; + bool data_type_fallback = false; + + switch (data.scalar_type()) { + case c10::ScalarType::BFloat16: data_size = numel * 2; break; + case c10::ScalarType::Half: data_size = numel * 2; break; + case c10::ScalarType::Float: data_size = numel * 4; break; + default: data_type_fallback = true; + } + + if (data_type_fallback) return; + + all_reduce_outer_loop(data, numel, data_size); + +#ifdef DO_PROFILE + auto end = std::chrono::system_clock::now(); + count++; + if (count > 0) { + double elapsed = std::chrono::duration_cast(end - start).count(); + if (elapsed > max_time) { max_time = elapsed; } + if (elapsed < min_time) { min_time = elapsed; } + total_time += elapsed; + total_time_sq += elapsed * elapsed; + if (world_rank == 0 && count == 1000) { + auto avg = total_time / count; + auto sd = + sqrt(total_time_sq / count - total_time * total_time / (count * count)) / avg * 100; + printf(" C++ kernel\t\t %.2f\t %.2f\t%.2f\t %.2f\n", + min_time, + max_time, + total_time / count, + sd); + } + } +#endif + return; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); } + +TORCH_LIBRARY(deepspeed, m) +{ + m.def("inference_all_reduce(Tensor self) -> Tensor"); + m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)"); +} + +torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_) +{ + torch::Tensor result_ = torch::empty_like(self_); + return result_; +} + +torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; } + +torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_) +{ + TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU); + torch::Tensor self_tensor = self_.contiguous(); + inference_all_reduce_(self_tensor, 0); + return self_; +} + +torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_) +{ + torch::Tensor result = self_.clone(); + inference_all_reduce__cpu(result); + return result; +} + +#include +// The boilerplate functionalization logic, that teaches functionalization +// how to map x_() calls into x() calls. +// Long term, we'd like to not require users to write this logic. +// HOWEVER, if you have a custom op that is mutable, +// You will still need to write an out-of-place version of that op! +at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x) +{ + // We expect all tensor inputs to our op to be "functional tensors" + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x)); + // First, sync and unwrap and functional tensors + at::functionalization::impl::sync(x); + auto x_ = at::functionalization::impl::from_functional_tensor(x); + // Grab the dispatcher entry corresponding to the out-of-place op, "x" + static auto op_handle = c10::Dispatcher::singleton() + // specify namespace::op_name, op_overload_name + .findSchemaOrThrow("deepspeed::inference_all_reduce", "") + // Specify the C++ schema of the out-of-place op. + .typed(); + // Next, redispatch to the out-of-place op, x() (user called x_, we call x) + at::Tensor tmp_output; + { + at::AutoDispatchSkipFunctionalize guard; + tmp_output = op_handle.call(x_); + } + // Finally, tell functionalization about this mutation. + at::functionalization::impl::replace_(x, tmp_output); + at::functionalization::impl::commit_update(x); + at::functionalization::impl::sync(x); + return x; +} + +TORCH_LIBRARY_IMPL(deepspeed, CPU, m) +{ + m.impl("inference_all_reduce", inference_all_reduce_cpu); + m.impl("inference_all_reduce_", inference_all_reduce__cpu); +} + +TORCH_LIBRARY_IMPL(deepspeed, Meta, m) +{ + m.impl("inference_all_reduce", inference_all_reduce_meta); + m.impl("inference_all_reduce_", inference_all_reduce__meta); +} + +TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m) +{ + m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue); +} diff --git a/csrc/cpu/comm/x86_64/shm.h b/csrc/cpu/comm/x86_64/shm.h new file mode 100644 index 000000000000..9b02eb3779cd --- /dev/null +++ b/csrc/cpu/comm/x86_64/shm.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +inline __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_bf16_to_fp32(const __m256i src) +{ + auto y = _mm512_cvtepu16_epi32(src); + return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); +} + +inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_bf16(const __m512 src) +{ + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +inline __m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_fp16_to_fp32(const __m256i src) { return _mm512_cvtph_ps(src); } + +inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_fp16(const __m512 src) +{ + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +// Reduce functions down below use vectorized algorithm, the number of bytes processed each +// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes +// If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs +// to be changed +static int vector_length_in_bytes = 32; + +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("avx512bw"))); +void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("avx512bw"))); +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("avx512bw"))); + +void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw"))); + +#define VLOAD_U8(X) _mm256_loadu_si256((__m256i*)(X)) +#define VLOAD_U16(X) _mm256_loadu_si256((__m256i*)(X)) +#define VLOAD_F16(X) _mm256_loadu_si256((__m256i*)(X)) +#define VLOAD_F32(X) _mm256_loadu_ps((float*)(X)) + +#define VSTORE_U8(A, B) _mm256_storeu_si256((__m256i*)(A), B) +#define VSTORE_U16(A, B) _mm256_storeu_si256((__m256i*)(A), B) +#define VSTORE_F16(A, B) _mm256_storeu_si256((__m256i*)(A), B) +#define VSTORE_F32(A, B) _mm256_storeu_ps((float*)(A), B) + +#define VADD_F32(A, B) _mm256_add_ps(A, B) +#define VADD_F32_2VL(A, B) _mm512_add_ps(A, B) + +#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X) +#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X) +#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X) +#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X) diff --git a/csrc/cpu/lion/fused_lion.cpp b/csrc/cpu/lion/fused_lion.cpp new file mode 100644 index 000000000000..708df7f0146a --- /dev/null +++ b/csrc/cpu/lion/fused_lion.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cpu_lion.h" + +// C++ interface + +void multi_tensor_lion(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, /*gpmv*/ + const float lr, + const float beta1, + const float beta2, + const int step, + const int mode, + const float weight_decay) +{ + static bool initialized = false; + if (!initialized) { + create_lion_optimizer(0); + initialized = true; + } + for (int i = 0; i < tensor_lists[0].size(); i++) { + ds_lion_step(0, + step, + lr, + beta1, + beta2, + weight_decay, + tensor_lists[1][i], + tensor_lists[0][i], + tensor_lists[2][i]); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_lion", + &multi_tensor_lion, + "Compute and apply gradient update to parameters for Lion optimizer"); +} diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cpp b/csrc/deepspeed4science/evoformer_attn/attention.cpp new file mode 100644 index 000000000000..ac3364539ff1 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/attention.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +void attention_impl(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + torch::Tensor& lse); +void attention(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + torch::Tensor& lse) +{ + attention_impl(q, k, v, bias1, bias2, o, lse); +} + +void attention_back_impl(torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2); +void attention_bwd(torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("attention", &attention, ""); + m.def("attention_bwd", &attention_bwd, ""); +} diff --git a/csrc/deepspeed4science/evoformer_attn/attention_back.cu b/csrc/deepspeed4science/evoformer_attn/attention_back.cu new file mode 100644 index 000000000000..a82c4ec68a13 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/attention_back.cu @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include "gemm_kernel_utils.h" +#include "kernel_backward.h" +#include "transform/bias_broadcast.h" + +constexpr auto kBlockSizeI = 64; +constexpr auto kBlockSizeJ = 64; + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_back_impl_template( + torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + EVOFORMER_CHECK(false, "Unsupported GPU and data type combination") +} + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_back_impl_template( + torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + constexpr bool kPreload_ = arch::kMinComputeCapability >= 80; + using Kernel = AttentionBackwardKernel; + int head_size = q.size(-1); + int head_number = q.size(-2); + int seq_length = q.size(-3); + auto q_view = q.view({-1, seq_length, head_number, head_size}); + auto k_view = k.view({-1, seq_length, head_number, head_size}); + auto v_view = v.view({-1, seq_length, head_number, head_size}); + auto o_view = o.view({-1, seq_length, head_number, head_size}); + auto do_view = go.view({-1, seq_length, head_number, head_size}); + auto dk_view = gk.view({-1, seq_length, head_number, head_size}); + auto dv_view = gv.view({-1, seq_length, head_number, head_size}); + auto dq_view = gq.view({-1, seq_length, head_number, head_size}); + auto q_ptr = reinterpret_cast(q.data_ptr()); + auto k_ptr = reinterpret_cast(k.data_ptr()); + auto v_ptr = reinterpret_cast(v.data_ptr()); + auto o_ptr = reinterpret_cast(o.data_ptr()); + auto do_ptr = reinterpret_cast(go.data_ptr()); + auto dk_ptr = reinterpret_cast(gk.data_ptr()); + auto dv_ptr = reinterpret_cast(gv.data_ptr()); + auto dq_ptr = reinterpret_cast(gq.data_ptr()); + auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast(gb1.data_ptr()) : nullptr; + auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast(gb2.data_ptr()) : nullptr; + auto lse_ptr = reinterpret_cast(lse.data_ptr()); + auto delta_ptr = reinterpret_cast(delta.data_ptr()); + auto bias1_ptr = reinterpret_cast(bias1.data_ptr()); + auto bias2_ptr = reinterpret_cast(bias2.data_ptr()); + static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta"); + + typename Kernel::Params p; + p.query_ptr = q_ptr; + p.key_ptr = k_ptr; + p.value_ptr = v_ptr; + p.logsumexp_ptr = lse_ptr; + p.output_ptr = o_ptr; + p.grad_output_ptr = do_ptr; + p.delta_ptr = delta_ptr; + p.grad_query_ptr = dq_ptr; + p.grad_key_ptr = dk_ptr; + p.grad_value_ptr = dv_ptr; + + p.grad_bias1_ptr = db1_ptr; + p.grad_bias2_ptr = db2_ptr; + p.B = q.size(0); + p.N = q.size(1); + p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr; + p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr; + + p.scale = 1.0f / sqrtf(head_size); + + p.head_dim = head_size; + p.head_dim_value = head_size; + p.num_queries = seq_length; + p.num_keys = seq_length; + p.num_heads = head_number; + + p.q_strideM = q_view.stride(-3); + p.k_strideM = k_view.stride(-3); + p.v_strideM = v_view.stride(-3); + p.gO_strideM = do_view.stride(-3); + p.o_strideH = o_view.stride(-2); + p.q_strideH = q_view.stride(-2); + p.k_strideH = k_view.stride(-2); + p.v_strideH = v_view.stride(-2); + p.o_strideB = o_view.stride(-4); + p.q_strideB = q_view.stride(-4); + p.k_strideB = k_view.stride(-4); + p.v_strideB = v_view.stride(-4); + p.lse_strideB = lse.stride(-3); + p.lse_strideH = lse.stride(-2); + p.delta_strideB = delta.stride(-3); + p.delta_strideH = delta.stride(-2); + p.num_batches = q_view.size(-4); + + p.gO_strideB = do_view.stride(-4); + p.gQ_strideB = dq_view.stride(-4); + p.gK_strideB = dk_view.stride(-4); + p.gV_strideB = dv_view.stride(-4); + p.gO_strideH = do_view.stride(-2); + p.gQ_strideH = dq_view.stride(-2); + p.gK_strideH = dk_view.stride(-2); + p.gV_strideH = dv_view.stride(-2); + + torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options()); + p.workspace = workspace.data_ptr(); + + auto kernel_fn = attention_kernel_backward_batched_impl; + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes)); + if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); } + kernel_fn<<>>(p); +} + +#define CODE(scalar_t, torch_scalar_t) \ + do { \ + if (bias1.size(0) == 0 && bias2.size(0) == 0) { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } else if (bias1.size(0) > 0 && bias2.size(0) > 0) { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } else if (bias1.size(0) > 0) { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } else { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } \ + } while (0) + +void attention_back_impl(torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + DISPATCH_ARCHTAG(prop->major * 10 + prop->minor, + DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); })); +} diff --git a/csrc/deepspeed4science/evoformer_attn/attention_cu.cu b/csrc/deepspeed4science/evoformer_attn/attention_cu.cu new file mode 100644 index 000000000000..37636c4bf988 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/attention_cu.cu @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include "gemm_kernel_utils.h" +#include "kernel_forward.h" +#include "transform/bias_broadcast.h" + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_impl_template( + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + float* lse_ptr) +{ + EVOFORMER_CHECK(false, "Unsupported GPU and data type combination") +} + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_impl_template( + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + float* lse_ptr) +{ + // Attention definition goes here, replaced with BroadcastType1 and + // BroadcastType2 + using Attention = AttentionKernel; + + static_assert(!Attention::kNeedsOutputAccumulatorBuffer, + "This test does not support output accumulator buffer"); + int head_size = q.size(-1); + int head_number = q.size(-2); + int seq_length = q.size(-3); + auto q_view = q.view({-1, seq_length, head_number, head_size}); + auto k_view = k.view({-1, seq_length, head_number, head_size}); + auto v_view = v.view({-1, seq_length, head_number, head_size}); + auto o_view = o.view({-1, seq_length, head_number, head_size}); + int batch_size = q_view.size(0); + auto q_ptr = reinterpret_cast(q.data_ptr()); + auto k_ptr = reinterpret_cast(k.data_ptr()); + auto v_ptr = reinterpret_cast(v.data_ptr()); + auto o_ptr = reinterpret_cast(o.data_ptr()); + + auto bias1_ptr = reinterpret_cast(bias1.data_ptr()); + auto bias2_ptr = reinterpret_cast(bias2.data_ptr()); + + typename Attention::Params p; + { // set parameters + p.query_ptr = q_ptr; + p.key_ptr = k_ptr; + p.value_ptr = v_ptr; + p.logsumexp_ptr = lse_ptr; // Only needed for bw + p.output_accum_ptr = nullptr; + p.output_ptr = o_ptr; + p.scale = 1.0f / sqrt(float(head_size)); + + p.bias1_ptr = bias1_ptr; + p.bias2_ptr = bias2_ptr; + p.B = q.size(0); + p.N = q.size(1); + + p.num_heads = head_number; + p.num_batches = batch_size; + p.head_dim = head_size; + p.head_dim_value = head_size; + p.num_queries = seq_length; + p.num_keys = seq_length; + + // All tensors are in BMHK shapes + p.q_strideH = q_view.stride(-2); + p.k_strideH = k_view.stride(-2); + p.v_strideH = v_view.stride(-2); + p.q_strideM = q_view.stride(-3); + p.k_strideM = k_view.stride(-3); + p.v_strideM = v_view.stride(-3); + p.o_strideM = o_view.stride(-3); + p.q_strideB = q_view.stride(-4); + p.k_strideB = k_view.stride(-4); + p.v_strideB = v_view.stride(-4); + } + + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); } + kernel_fn<<>>(p); +} + +#define CODE(scalar_t, torch_scalar_t) \ + do { \ + if (bias1.size(0) == 0 && bias2.size(0) == 0) { \ + attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \ + } else if (bias1.size(0) == 0) { \ + attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \ + } else if (bias2.size(0) == 0) { \ + attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \ + } else { \ + attention_impl_template( \ + q, k, v, bias1, bias2, o, lse_ptr); \ + } \ + } while (0) + +// Function to select and call the correct template based on biases sizes +void attention_impl(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + torch::Tensor& lse) +{ + auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast(lse.data_ptr()); + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + DISPATCH_ARCHTAG(prop->major * 10 + prop->minor, + DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); })); +} diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h new file mode 100644 index 000000000000..17b6479ed8c5 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#include +#include +#include "../iterators/predicated_tile_iterator_atomic.h" +#include "cutlass/epilogue/threadblock/epilogue.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { +template +struct EpilogueTensorOpAffineRankN : public DefaultEpilogueTensorOpAffineRankN { + using Base = DefaultEpilogueTensorOpAffineRankN; + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + Rank>; + + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; + +template +struct EpilogueVoltaTensorOpAffineRankN + : public DefaultEpilogueVoltaTensorOpAffineRankN { + using Base = DefaultEpilogueVoltaTensorOpAffineRankN; + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + Rank>; + + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; + +template +struct EpilogueTensorOp : public DefaultEpilogueTensorOp { + using Base = DefaultEpilogueTensorOp; + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + ScatterD, + PermuteDLayout>; + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; + +template +struct EpilogueVoltaTensorOp : public DefaultEpilogueVoltaTensorOp { + using Base = DefaultEpilogueVoltaTensorOp; + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + ScatterD, + PermuteDLayout>; + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +template +struct BiasGradEpilogue { + using Epilogue = + typename cutlass::epilogue::threadblock::EpilogueTensorOp::Epilogue; +}; + +template +struct BiasGradEpilogue { + using Epilogue = + typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOp::Epilogue; +}; + +template +struct BiasGradEpilogueAffineRankN { + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueTensorOpAffineRankN< + Rank, + Shape_, + WarpMmaTensorOp_, + PartitionsK, + OutputOp_, + ElementsPerAccess>::Epilogue; +}; + +template +struct BiasGradEpilogueAffineRankN { + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOpAffineRankN< + Rank, + Shape_, + WarpMmaTensorOp_, + PartitionsK, + OutputOp_, + ElementsPerAccess>::Epilogue; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h new file mode 100644 index 000000000000..3b7b32d61452 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) + { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput + apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum) + { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template ::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase { +public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = + Array; + using SourceAccessType = Array; + + /// Array type used by output functor + using AccumulatorAccessType = + Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = + Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + static_assert(OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert(OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert(SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + +public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined(typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) + { + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()(OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator source_iterator) + { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()(OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators) + { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + +private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators ///< Complete warp-level accumulator tile + ) + { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = + add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_(destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + ) + { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { __syncthreads(); } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = + add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_(destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_(int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) + { + OutputAccessType* output_frag_ptr = reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) + { + OutputAccessType* output_frag_ptr = reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) + { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { return row_offset; } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h new file mode 100644 index 000000000000..f81a09f74f1e --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template , + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { +public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + +private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + +public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return !isFirst; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const + { + assert(!isFirst); + + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) const + { + assert(isFirst); + + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator(alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template +struct ApplyEpilogueOp< + thread::MemoryEfficientAttentionNormalize> { + using Op = thread::MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) + { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput + apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum) + { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 000000000000..46fb2bf17c1c --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const + { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { result[i] = expf(input[i]); } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()(Array const& input) const + { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { res_ptr[i] = h2exp(input_ptr[i]); } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template +class ApplyLogSumExp { +public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + +public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const + { + FragmentCompute frag_AB = + NumericArrayConverter()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()(bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter()( + frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h new file mode 100644 index 000000000000..75833bbfe7d2 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h @@ -0,0 +1,119 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template +struct MakeCustomMma, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage; +}; + +template +struct MakeCustomMma, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h new file mode 100644 index 000000000000..bbf91240b900 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = + GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() + { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { return TensorRef{buffer.data(), Layout()}; } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + using SharedStorageA = + OperandSharedStorage; + using SharedStorageB = + OperandSharedStorage; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h new file mode 100644 index 000000000000..3760ccab852a --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h @@ -0,0 +1,714 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper boundon the K dimension + int kMaxK = cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { +public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireMat ? Stages : Stages - 1; + +private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) + { + } + + CUTLASS_DEVICE + bool set_prologue_done(bool value) + { + prologue_done_ = value; + return true; + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) + { + zero_outside_bounds_ = value; + return true; + } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + prologue(shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue(iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + template + CUTLASS_DEVICE static void _prologue(IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) + { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same::value || + platform::is_same::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. + if (!kSmemContainsEntireMat || gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same::value || + platform::is_same::value) { + warp_mma(tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma(accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same::value || + platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h new file mode 100644 index 000000000000..07b26ca31299 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { +public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) + { + } + + CUTLASS_DEVICE + bool set_prologue_done(bool value) + { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) + { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + prologue(shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) + { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h new file mode 100644 index 000000000000..163dcbf85259 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Cutlass provides helper template functions to figure out the right + datastructures to instantiate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our usecase. + + This is really only for the FastF32 case - aka using TensorCores with fp32. +*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template +struct FindDefaultMma 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 000000000000..5e2f0cf681bf --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,347 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. +*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) + { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) + { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile + + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col + + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = + typename cutlass::platform::conditional::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) + { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + static_assert(cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) + { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) + { + static_assert(cutlass::platform::is_same>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaSimtTileIterator; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h b/csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h new file mode 100644 index 000000000000..40d3265c7a63 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h @@ -0,0 +1,1939 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template +class AccumulatorSharedStorage { +public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = + cutlass::MatrixShape; + +public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + +public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { return TensorRefAccum{accum.data(), LayoutAccum()}; } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum value for K + int kMaxK, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = + GemmShape; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + +protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { +public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset(typename TensorRef::TensorCoord const&) { return *this; } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { return *this; } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { +public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) + { + Fragment converted_scale_frag = + cutlass::NumericArrayConverter()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { +public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { return frag; } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + // END smem + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory + : public MmaBaseFromSharedMemory { +public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. + using FragmentAScaler = + FragmentElementwiseScaler; + +protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + +public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp iterator over A tile held in shared memory + WarpIteratorA warp_iter_a, + // warp iterator over A_scale tile held in shared memory + WarpIteratorAScale warp_iter_a_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(warp_iter_a), + warp_tile_iterator_A_scale_(warp_iter_a_scale), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::SharedStorage& shared_storage, ///< Shared storage needed for internal use + ///< by threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async transfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue(typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) + { + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = TransformB()) + { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma(accum, + FragmentAScaler::apply(warp_frag_A[warp_mma_k % 2], + warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory + : public MmaBaseFromSharedMemory { +public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireB ? Base::kStages + : Base::kStages - 1; + +private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = + FragmentElementwiseScaler; + +private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + +public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp level iterator over operand A tile kept in shared memory + WarpIteratorA1 warp_tile_iterator_A1, + // warp level iterator over operand A elementwise scale tile kept in + // shared memory. + WarpIteratorAScale warp_tile_iterator_A1_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(warp_tile_iterator_A1), + warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::SharedStorage& shared_storage, ///< Shared storage needed for internal use + ///< by threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { prologue_done_ = value; } + + CUTLASS_DEVICE + static void prologue(typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) + { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue(iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() + { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1& iterator_B1, int group_start_B1 = 0) + { + iterator_B1.set_iteration_index(group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue(IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast(smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) + { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { ++iterator_B1; } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply(warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same::value || + platform::is_same::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply(warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same::value || + platform::is_same::value) { + warp_mma1(tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1(accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same::value || + platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +template +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<(sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + + using WarpIterator = + cutlass::gemm::warp::WarpIteratorFromSmem; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<(sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + using RegularMma = MmaPipelined; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = + typename DefaultWarpIteratorAFromSharedMemory::WarpIterator; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast::Iterator; + + using Mma = + typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + + using RegularMma = MmaMultistage; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorA_ = + typename DefaultWarpIteratorAFromSharedMemory::WarpIterator; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform:: + conditional::type; + + static int constexpr kMaxK = kIsTransposedA ? AccumulatorSharedStorage_::Shape::kM + : AccumulatorSharedStorage_::Shape::kN; + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast::Iterator; + using Mma = typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + AccumulatorSharedStorage_, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = typename cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = + typename cutlass::epilogue::warp::TileIteratorTensorOp; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage; + // We need to provide an operation for the epilogue. Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator + // - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? + using IteratorAccumulatorLSE = cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator; + + static void CUTLASS_DEVICE accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue(minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorVoltaTensorOp, + scalar_t, + SmemAccumulatorLayout>; + + // // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + 16, + 32>, // typename SmemIteratorD0::TensorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + + using OutputLayout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + // Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = + typename cutlass::platform::conditional::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset(tile_coords * cutlass::MatrixCoord({IteratorC::Shape::kRow, + IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast(ref_.data() + ref_.offset({r, c})) = + to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + auto lane_offset = AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template +struct B2bGemm, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage // Padding + >; + + static void CUTLASS_DEVICE accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset(tile_coords * cutlass::MatrixCoord({IteratorC::Shape::kRow, + IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + auto lane_offset = AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h new file mode 100644 index 000000000000..95b6e8ad214e --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "cutlass/arch/mma.h" + +template +struct CheckArch { + static constexpr bool isPreVolta = arch::kMinComputeCapability < 70; + // DISPATCH_ARCHTAG only binds Sm70/Sm75/Sm80+, so overlap with isPreVolta is unreachable. + static constexpr bool isPreAmpere = arch::kMinComputeCapability < 80; + static constexpr bool isAmpere = arch::kMinComputeCapability >= 80; + static constexpr bool value = (isPreVolta && std::is_same_v) || + (isPreAmpere && !std::is_same_v) || + isAmpere; +}; + +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if ((CC) >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func; \ + } else if ((CC) >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func; \ + } else if ((CC) >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func; \ + } else { \ + EVOFORMER_CHECK(false, "Only GPUs with Tensor Core (SM >= 70) are supported"); \ + } \ + } + +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (tensor.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + using torch_scalar_t = at::Half; \ + func; \ + } else if (tensor.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + using torch_scalar_t = at::BFloat16; \ + func; \ + } else { \ + EVOFORMER_CHECK(false, "Only fp16 and bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + EVOFORMER_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") +#define EVOFORMER_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { return false; } +#define EVOFORMER_CHECK(COND, ERR) \ + if (!(COND)) { return false; } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define EVOFORMER_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "[Evoformer Attention]" << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ + } +#endif + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) +{ + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) +{ + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) +// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(ta(arg)) + { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(tb(arg)) + { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_DEVICE int32_t warp_uniform(int32_t value) +{ + return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) +{ + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/epilogue_predicated_tile_iterator.h b/csrc/deepspeed4science/evoformer_attn/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 000000000000..667f1982d30d --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,691 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Epilogue iterator that supports prefetching + + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template +class PredicatedTileIteratorPrefetch { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = false; } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = true; } + } + }; + +private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + // + // Methods + // + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = + reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() + { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() + { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = + (uint64_t)((void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const + { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/make_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/make_residual_last.h new file mode 100644 index 000000000000..ff0e324c3a6c --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/make_residual_last.h @@ -0,0 +1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template +struct MakeIteratorResidualLast< + PredicatedTileIterator> { + using Iterator = PredicatedTileIteratorResidualLast; +}; + +template +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 000000000000..7f6a2430845a --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,1964 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + +private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) + { + the_predicates.compute_predicates_(extent, is_steady_state); + } + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast(const_cast(pointer))), + the_predicates(extent), + indices_(indices) + { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset(layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) + { + if (is_residual_tile) { the_predicates.set_mask(residual_tile_mask); } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const + { + if (Gather) { + assert(indices_); + + if (!valid()) { return nullptr; } + + LongIndex contiguous_offset = + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * LongIndex(params_.stride_) * + sizeof_bits::value / 8; + + return reinterpret_cast(pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { return *this; } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { pointer_ += params_.inc_strided_; } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return the_predicates.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), + indices) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), + indices) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. + LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() : stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : stride_({layout.stride(0), layout.stride(1)}) + { + inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = + inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + +private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) + { + the_predicates.compute_predicates_(extent, is_steady_state); + } + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast(const_cast(pointer))), + the_predicates(extent) + { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) + { + if (is_residual_tile) { the_predicates.set_mask(residual_tile_mask); } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const + { + return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { return *this; } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return the_predicates.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_atomic.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_atomic.h new file mode 100644 index 000000000000..8d4173f1a6a2 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_atomic.h @@ -0,0 +1,886 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#include +#include +#include +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct atomic_store {}; + +template +struct atomic_store::value>::type> { + using Element = typename AccessType::Element; + static const int kCount = AccessType::kElements; + + CUTLASS_DEVICE + atomic_store(AccessType const& D, void* ptr, bool pred_guard) + { + static_assert(!(kCount % 2), "kCount must be even"); + half2* p = reinterpret_cast(ptr); + uint const* data = reinterpret_cast(&D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + : + : "r"((int)pred_guard)); + for (int i = 0; i < kCount / 2; i++) { + asm volatile(" @p red.relaxed.global.add.noftz.f16x2 [%0], %1;\n" + : + : "l"(p + i), "r"(data[i])); + } + asm volatile("}\n" ::); + } +}; + +template +struct atomic_store::value>::type> { + using Element = typename AccessType::Element; + static const int kCount = AccessType::kElements; + + CUTLASS_DEVICE + atomic_store(AccessType const& D, void* ptr, bool pred_guard) + { + Element* p = reinterpret_cast(ptr); + uint const* data = reinterpret_cast(&D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + : + : "r"((int)pred_guard)); + for (int i = 0; i < kCount; i++) { + asm volatile(" @p red.relaxed.global.add.f32 [%0], %1;\n" + : + : "l"(p + i), "r"(data[i])); + } + asm volatile("}\n" ::); + } +}; + +template +class PredicatedTileIteratorAffineRankNAtomic { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::AffineRankN; + using TensorRef = TensorRef; + using TensorView = TensorView; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = typename Layout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + static_assert(!(Layout::kRank % 2), + "Layout rank must be even. This assumes the first half of the " + "modes correspond to the 'row' " + "and the second half of the modes correspond to the 'column'"); + + static bool const kBigEndian = false; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Parameters structure + struct Params { + // + // Data members + // + + Layout layout; + + /// Stride in units of bytes along M modes + Coord stride_m; + + /// Stride in units of bytes along N modes + Coord stride_n; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank / 2 - 1)]; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank / 2 - 1)]; + + int64_t rank2_inc_col; + int64_t rank2_inc_row; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(TensorCoord const& extent, Layout const& layout_) : layout(layout_) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + if (kBigEndian) { + // "Big Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i + 1]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); + } + } else { + // "Little Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); + } + } + } + + CUTLASS_HOST_DEVICE + Params(Layout const& layout_) : layout(layout_) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0]; + rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0]; + } + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = false; } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = true; } + } + }; + +private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in columns + Index extent_col_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have + /// been computed) + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Offsets in columns, cached for performance + int64_t offset_modes_n_[ThreadMap::Iterations::kColumn]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + +private: + // + // Methods + // + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorAffineRankNAtomic( + Params const& params, + Element* pointer, + MatrixCoord extent, + int thread_idx, + MatrixCoord threadblock_offset = MatrixCoord(), + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params) + { + MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_col_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + if (Layout::kRank > 2) { + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + // + // Compute coordinate and decompose into N modes + // + + int coord_n = thread_start_column_ + c * ThreadMap::Delta::kColumn; + + mask_.predicates[c] = coord_n < extent.column(); + + Coord modes_n; + + int64_t offset_modes_n = 0; + + if (kBigEndian) { + modes_n = CoordinateDecomposition(coord_n, params_.divmod_n); + + offset_modes_n = dot(modes_n, params_.stride_n); + } else { + modes_n = CoordinateDecompositionLittleEndian( + coord_n, params_.divmod_n); + + offset_modes_n = dot(modes_n, params_.stride_n); + } + + offset_modes_n_[c] = offset_modes_n; + } + + if (!pointer) { mask_.clear(); } + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int64_t offset_modes_m = row_begin * params_.stride_m[0]; + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + // + // Compute coordinate and decompose into M modes + // + + int coord_m = row * ThreadMap::Delta::kRow + row_begin; + + Coord modes_m; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_m = CoordinateDecomposition(coord_m, + params_.divmod_m); + } else { + modes_m = CoordinateDecompositionLittleEndian( + coord_m, params_.divmod_m); + } + + offset_modes_m = dot(modes_m, params_.stride_m); + } + + // + // Compute the offset due to modes M + // + + bool row_guard = (coord_m < extent_row_); + int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + // + // Compute coordinate and decompose into N modes + // + + if (Layout::kRank > 2) { offset_modes_n = offset_modes_n_[column]; } + + // + // Compute the pointer and access + // + bool guard; + if (Layout::kRank > 2) { + guard = row_guard && mask_.predicates[column]; + } else { + guard = (coord_m < extent_row_) && + ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < + extent_col_); + } + + atomic_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), + guard); + + if (Layout::kRank == 2) { offset_modes_n += params_.rank2_inc_col; } + } + + if (Layout::kRank == 2) { offset_modes_m += params_.rank2_inc_row; } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + void load(Fragment& frag) {} + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineRankNAtomic& operator++() + { + ++state_[0]; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +template +class PredicatedTileIteratorAtomic { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static bool constexpr PermuteD = !layout::is_trivial_permute; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = false; } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = true; } + } + }; + +private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer. This pointer is usually for both load() and store(), + /// unless PermuteD is performed. When having PermuteD, byte_pointer_ is only + /// for load(). + uint8_t* byte_pointer_; + + /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ + /// may be with different address computation compared to byte_pointer_. + uint8_t* store_byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + /// PermuteDLayout + PermuteDLayout permute_layout_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + // + // Methods + // + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorAtomic(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), + indices_(indices), + permute_layout_(PitchLinearCoord(extent.column(), extent.row()), + params_.stride * kElementsPerAccess / sizeof(AccessType)) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize byte_pointer_ + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = + reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // store_byte_pointer_ is set to be the same with byte_pointer_ unless + // PermuteD is used. + store_byte_pointer_ = PermuteD ? reinterpret_cast(pointer) : byte_pointer_; + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = store_byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (PermuteD) { + int col_offset = column * ThreadMap::Delta::kColumn; + + int col = col_offset + thread_start_column_; + int row = row_offset + thread_start_row_; + + // Locate memory_pointer + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + permute_layout_(PitchLinearCoord(col, row)) * sizeof(AccessType) / + kElementsPerAccess); + } + atomic_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + + if (!PermuteD) { + memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD && !PermuteD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + void load(Fragment& frag) {} + + CUTLASS_DEVICE + MatrixCoord thread_start() const + { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAtomic& operator++() + { + ++state_[0]; + + if (!ScatterD && !PermuteD) { store_byte_pointer_ += params_.advance_row; } + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + store_byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + store_byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + store_byte_pointer_ += params_.advance_tile; + + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; + } + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAtomic& operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + byte_pointer_ += (params_.advance_row * increment); + store_byte_pointer_ += (params_.advance_row * increment); + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + byte_pointer_ += (params_.advance_group * increment_row); + store_byte_pointer_ += (params_.advance_group * increment_row); + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * increment_row; + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + byte_pointer_ += (params_.advance_cluster * increment_group); + store_byte_pointer_ += (params_.advance_cluster * increment_group); + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow * increment_group; + + // Tile + byte_pointer_ += (params_.advance_tile * increment_cluster); + store_byte_pointer_ += (params_.advance_tile * increment_cluster); + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 000000000000..629047dbb057 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,1938 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = + AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_(params.params_, pointer, extent, thread_id, threadblock_offset, indices) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { address_iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const* access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { *access_ptr = frag_ptr[idx]; } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), + indices) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), + indices) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = + AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_(params.params_, pointer, extent, thread_id, threadblock_offset) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { address_iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const* access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { *access_ptr = frag_ptr[idx]; } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) + { + } + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) + { + } + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/transpose_warp_iterator.h b/csrc/deepspeed4science/evoformer_attn/iterators/transpose_warp_iterator.h new file mode 100644 index 000000000000..2435c07f8989 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/transpose_warp_iterator.h @@ -0,0 +1,57 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp::WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp::WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h b/csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h new file mode 100644 index 000000000000..7dd59832b4b0 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + bool kTranspose = false> +class WarpIteratorFromSmem { +public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = cutlass::MatrixShape<16, 8>; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape; + + static int const kIterations = (kOperand == Operand::kA) ? InstructionCount::kColumn + : InstructionCount::kRow; + +public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow); + static int constexpr kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + +private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + +public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) + { + } + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) + { + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert(InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; + ++access_m_idx) { + int access_idx = + access_m_idx + + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); + + MatrixCoord offset(access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); } + origin_ += offset; + } + } + } + } + } else { + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset(inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) + { + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn); + if (kTranspose) { coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() + { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() + { + iterations_++; + + if (iterations_ >= kIterations) advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_DEVICE + void load(Fragment& frag) const + { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = + typename platform::conditional::type; + + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); } + cutlass::arch::ldsm(access_ptr[0], ref_.data() + ref_.offset(offset)); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/kernel_backward.h b/csrc/deepspeed4science/evoformer_attn/kernel_backward.h new file mode 100644 index 000000000000..87e6df18bb04 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/kernel_backward.h @@ -0,0 +1,1965 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "gemm_kernel_utils.h" + +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "iterators/epilogue_predicated_tile_iterator.h" + +#include "epilogue/epilogue_grad_bias.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "gemm/mma_from_smem.h" +#include "transform/bias_broadcast.h" +#include "transform/tile_smem_loader.h" + +#include + +using namespace gemm_kernel_utils; + +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... + } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements; + static_assert(FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load(sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store(sub_fragment, gmem_ptr, true); + } + } +}; + +template +constexpr int getWarpsPerSm() +{ + constexpr bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { return is_half ? 12 : 8; } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + template class Broadcast1_ = BroadcastNoLoad, + template class Broadcast2_ = BroadcastNoLoad> +struct AttentionBackwardKernel { + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [Mq, nH, K] + scalar_t* key_ptr; // [Mk, nH, K] + scalar_t* value_ptr; // [Mk, nH, Kv] + lse_scalar_t* logsumexp_ptr; // [nH, Mq] + scalar_t* output_ptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr; // [Mq, nH, Kv] + accum_t* delta_ptr; // [nH, Mq] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr; // [Mq, nH, K] + output_t* grad_key_ptr; // [Mk, nH, K] + output_t* grad_value_ptr; // [Mk, nH, Kv] + + accum_t* grad_bias1_ptr = nullptr; + accum_t* grad_bias2_ptr = nullptr; + int32_t B = 0; + int32_t N = 0; + scalar_t* bias1_ptr = nullptr; + scalar_t* bias2_ptr = nullptr; + + // Accumulators + union { + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gk; + }; + output_accum_t* workspace_gv; // (will be calculated by the kernel) + output_accum_t* workspace_gq; // (will be calculated by the kernel) + + // Scale + accum_t scale; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t gO_strideM; + int32_t gB_strideM; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const + { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const + { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const + { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH; + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int64_t o_strideB; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int64_t lse_strideB; + int64_t lse_strideH; + int64_t delta_strideB; + int64_t delta_strideH; + int32_t num_batches; + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_DEVICE bool advance_to_block() + { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = workspace_gv + workspace_elements_gv(); + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + using broadcast_1 = Broadcast1_; + using broadcast_2 = Broadcast2_; + + if (broadcast_1::kEnable && grad_bias1_ptr) { + grad_bias1_ptr += batch_id * num_queries; + } + if (broadcast_2::kEnable && grad_bias2_ptr) { + auto strideB = num_heads * num_queries * num_keys; + auto strideH = num_queries * num_keys; + grad_bias2_ptr += (batch_id / N) * strideB + head_id * strideH; + } + if (broadcast_1::kEnable && bias1_ptr) { + bias1_ptr = broadcast_1::advance(bias1_ptr, + batch_id / N, + batch_id % N, + head_id, + num_queries * N, + num_queries, + 0); + } + if (broadcast_2::kEnable && bias2_ptr) { + auto strideB = num_heads * num_queries * num_keys; + auto strideH = num_queries * num_keys; + bias2_ptr = broadcast_2::advance( + bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH); + } + + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); + if (broadcast_1::kEnable) { + grad_bias1_ptr = warp_uniform(grad_bias1_ptr); + bias1_ptr = warp_uniform(bias1_ptr); + } + if (broadcast_2::kEnable) { + grad_bias2_ptr = warp_uniform(grad_bias2_ptr); + bias2_ptr = warp_uniform(bias2_ptr); + } + + return true; + } + + __host__ dim3 getBlocksGrid() const { return dim3(1, num_heads, num_batches); } + __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const + { + if (!kNeedsAccumGradK) { return 0; } + return align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const + { + if (!kNeedsAccumGradV) { return 0; } + return align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim_value, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const + { + if (!kNeedsAccumGradQ) { return 0; } + if (num_keys <= kBlockSizeJ) { return 0; } + return align_up(num_queries, (int32_t)kBlockSizeI) * + align_up(head_dim, (int32_t)kBlockSizeJ); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const + { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const + { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + }; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem every time + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert(!kPreload || (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) + // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + static constexpr bool kNeedsAccumGradQ = + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = + !kOutputInRF && !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = + !kOutputInRF && !cutlass::platform::is_same::value; + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSm() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration; + static constexpr auto kOptimalAlignement = + cutlass::platform::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = + TileSmemLoader, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = + typename cutlass::gemm::threadblock::B2bGemm; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulQK::AccumulatorSharedStorage, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = accum_t; // CSY: Change it for better accuracy + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = + typename cutlass::gemm::threadblock::B2bGemm; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulDOIVJ::AccumulatorSharedStorage, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulQK::AccumulatorSharedStorage, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulDOIVJ::AccumulatorSharedStorage, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = + typename cutlass::platform::conditional::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + using broadcast_1 = Broadcast1_; + using broadcast_2 = Broadcast2_; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. + typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + // typename MatmulQK::BiasLoader::SmemTile bias; + cutlass::AlignedBuffer bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed: + // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij + // are loaded for the computation of dVj. + // - to compute dPij = (dOi @ Vj.T) * Zij + // 6. used in dVj += (Pij.T * Zij) @ dOi + // 9. used in dPij = dPij_dropped * Zij + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final; + } part4; + }; +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; } + + FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) + FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + cutlass::AlignedBuffer bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed: + // - in this step, where it is used to compute Pij_dropped = Pij * Zij + // on the + // fly as fragments of Pij are loaded for the computation of dVj. + // - later to compute dPij = (dOi @ Vj.T) * Zij + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // (from part2) - Zij for computing dPij = dPij_dropped * Zij + ZijSharedStorage zij; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 - store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final; + } part6; + }; +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform:: + conditional::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() + { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) + { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + EVOFORMER_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); + EVOFORMER_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + EVOFORMER_CHECK(p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + EVOFORMER_CHECK(p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + EVOFORMER_CHECK(p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + EVOFORMER_CHECK(kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + EVOFORMER_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + EVOFORMER_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + EVOFORMER_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + EVOFORMER_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + EVOFORMER_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + EVOFORMER_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + EVOFORMER_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + EVOFORMER_CHECK(p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) + { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + if (kPrologueQK) { + prologueQkNextIteration(shared_storage, p, 0, 0, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + } + __syncthreads(); + } + + OutputFragments output_frags; + + int32_t key_start = 0; + int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; + for (; key_start < key_end; key_start += kBlockSizeJ) { + output_frags.clear(); + int32_t query_start = getQueryStart(p, key_start); + int32_t query_end = + query_start + (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI; + for (; query_start < query_end; query_start += kBlockSizeI) { + processBlockIJ( + shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); + } + // last (partial) query + if (query_start < p.num_queries) { + processBlockIJ( + shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV(p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + // Last (partial) key + if (key_start != p.num_keys) { + output_frags.clear(); + int32_t query_start = getQueryStart(p, key_start); + for (; query_start < p.num_queries; query_start += kBlockSizeI) { + warp_id = warp_uniform(warp_id); + processBlockIJ( + shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV(p, key_start, warp_id, lane_id); + } + } + } + + static CUTLASS_DEVICE void loadDi(cutlass::Array& di, + Params const& p, + int32_t query_start) + { + int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + if (thread_id < kBlockSizeI) { + accum_t di_rf = accum_t(0); + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + di[thread_id] = di_rf; + } + } + + template + static CUTLASS_DEVICE void zfillGradKV(Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { continue; } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { gk_ptr[k] = scalar_t(0); } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ(SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + __syncthreads(); + loadDi(shared_storage.di(), p, query_start); + + int32_t num_queries_in_block = + skipBoundsChecks ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kN, + p.num_queries - query_start)); + int32_t num_keys_in_block = + skipBoundsChecks ? MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, + p.num_keys - key_start)); + + auto prologueGradV = [&](int col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), iterator_dO, thread_id, num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), iterator_Q, thread_id, num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B({int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), iterator_A, iterator_B, thread_id, p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size(num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A({int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), thread_id, warp_id, lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + if (broadcast_1::kEnable || broadcast_2::kEnable) { + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + using Shape = cutlass::MatrixShape; + AttentionBiasEpilogue + bias_epilogue; + bias_epilogue(bias_tensor_ref, + p.bias1_ptr + key_start, + p.bias2_ptr + query_start * p.num_keys + key_start, + thread_id, + {num_queries_in_block, num_keys_in_block}, + p.num_keys); + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] = accum[idx] * scale + bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } else { + accum = cutlass::multiplies()(scale, accum); + } + + __syncthreads(); + if (kPrologueGV) { prologueGradV(0); } + if (kPrologueDOV) { prologueDOV(); } + + MatmulQK::B2bGemm::accumApplyLSEToSmem(shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); + + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B({int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma(shared_storage.mm_gradV(), + // operand A: Pij + typename MatmulGradV::WarpIteratorA( + shared_storage.attn_shared_storage().accum_ref(), lane_id), + // if we're using dropout, operand A is Pij_dropped = Pij * Zij + // which is computed on the fly as fragments of Pij are loaded in + typename Mma::WarpIteratorAScale(shared_storage.zij().accum_ref(), lane_id), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, output_frags.gradV, iterator_B, output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV && !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem(shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A({int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B({int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { prologueGradQ(0); } + if (kPrologueGK) { prologueGradK(0); } + + int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, output_tile_coords); + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + if (skipBoundsChecks || + (accum_m < num_queries_in_block && accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + using DefaultGemm = typename MatmulDOIVJ::DefaultGemm; + using OutputOp = typename MatmulDOIVJ::BiasGradEpilogueOutputOp; + if (broadcast_1::kEnable && p.grad_bias1_ptr) { + using Epilogue = + typename BiasGradEpilogueAffineRankN::Epilogue; + cutlass::layout::AffineRankN<2> layout({0, 1}); + auto dst_ptr = p.grad_bias1_ptr + key_start; + typename Epilogue::OutputTileIterator output_iter( + {layout}, + dst_ptr, + {num_queries_in_block, num_keys_in_block}, + (int)thread_id); + Epilogue epilogue(shared_storage.gradB_epilogue(), + (int)thread_id, + (int)warp_id, + (int)lane_id); + epilogue(OutputOp(1), output_iter, accum); + } + + if (broadcast_2::kEnable && p.grad_bias2_ptr) { + if (broadcast_1::kEnable) { __syncthreads(); } + using Epilogue = + typename BiasGradEpilogue::Epilogue; + typename Epilogue::OutputTileIterator::Params params{p.num_keys}; + auto dst_ptr = p.grad_bias2_ptr + query_start * p.num_keys + key_start; + typename Epilogue::OutputTileIterator output_iter( + params, dst_ptr, {num_queries_in_block, num_keys_in_block}, (int)thread_id); + Epilogue epilogue(shared_storage.gradB_epilogue(), + (int)thread_id, + (int)warp_id, + (int)lane_id); + epilogue(OutputOp(1), output_iter, accum); + } + + accum = accum * scale; + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), accum, lane_id, output_tile_coords); + __syncthreads(); + } + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B({int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma(shared_storage.mm_gradQ(), + shared_storage.tmp_shared_storage(), + thread_id, + warp_id, + lane_id, + problem_size.k()); + + typename Mma::FragmentC accum; + + bool isFirst = key_start == 0; + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = + kSingleIterationGradQ ? 1 : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + AccumTileGmem gmem_tile{p.workspace_gq + storage_id * AccumTileGmem::kElementsStored}; + if (isFirst || !kNeedsAccumGradQ) { + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + // Output results + int32_t next_query, next_key; + incrIteration(p, p.num_queries, key_start, next_query, next_key); + bool isLast = next_query > query_start || next_key >= p.num_keys; + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + } else { + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + accumulateInGmem(isLastColumn + ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + isFirst || kNeedsAccumGradQ, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? getTmp() : getTmpT(); + bool constexpr kIsTransposedA = MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = + *call_conditional::apply( + getTmp, getTmpT, 0); + Mma mma(shared_storage.mm_gradK(), opA, thread_id, warp_id, lane_id, problem_size.k()); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{p.workspace_gk + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, output_frags.gradK, iterator_B, output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem(isLastColumn + ? shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return 0; }; + + static CUTLASS_DEVICE void incrIteration(Params const& p, + int32_t query_start, + int32_t key_start, + int32_t& next_query, + int32_t& next_key) + { + next_query = query_start + kBlockSizeI; + next_key = key_start; + if (next_query >= p.num_queries) { + next_key = key_start + kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration(SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + if (query_start >= p.num_queries || key_start >= p.num_keys) { return; } + + static constexpr bool kReloadK = kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A({int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {p.num_keys - key_start, p.head_dim}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::prologue(shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem(SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = + skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + accumulateInGmem(shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem(shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) + { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = + kIsFirst ? cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta(Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) + { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); + const AccessType* __restrict__ output_ptr = reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); + + static constexpr int64_t kMaxIters = kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && (laneFirstCol + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + + // If we have a small lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { columnIteration(iter); } + } else { + int num_iters = ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { columnIteration(iter); } + } + + // Reduce between workers + static_assert(kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { p.delta_ptr[query_start + laneRow] = delta_value; } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) +{ + if (!p.advance_to_block()) { return; } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); diff --git a/csrc/deepspeed4science/evoformer_attn/kernel_forward.h b/csrc/deepspeed4science/evoformer_attn/kernel_forward.h new file mode 100644 index 000000000000..e3b11ebcc661 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/kernel_forward.h @@ -0,0 +1,986 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/bias_broadcast.h" +#include "transform/tile_smem_loader.h" + +#include + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSm() +{ + return (Arch::kMinComputeCapability >= 80 && !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) +{ + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock_, + bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock` + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsBias_ = false, + template class Broadcast1_ = BroadcastNoLoad, + template class Broadcast2_ = BroadcastNoLoad> +struct AttentionKernel { + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kSingleValueIteration_; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && cutlass::sizeof_bits::value == 16; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = + !kKeepOutputInRF && !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = getWarpsPerSm() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + + // Output tensors + output_t* output_ptr; // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr; // [num_queries, num_heads, head_dim_value] + lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + + // Scale + accum_t scale; + + // Dimensions/strides + int32_t head_dim; + int32_t head_dim_value; + int32_t num_queries; + int32_t num_keys; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + // int32_t bias_strideM = 0; + + int32_t o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + // int32_t bias_strideH = 0; + + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + // int32_t bias_strideB = 0; + + int32_t num_batches; + int32_t num_heads; + + // Parameters for biases + scalar_t* bias1_ptr = nullptr; + scalar_t* bias2_ptr = nullptr; + int32_t B = 0; + int32_t N = 0; + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() + { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + + int64_t q_start = 0, k_start = 0; + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (output_accum_ptr != nullptr) { + output_accum_ptr += int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + using broadcast_1 = Broadcast1_; + if (kSupportsBias && broadcast_1::kEnable && bias1_ptr) { + bias1_ptr = broadcast_1::advance(bias1_ptr, + batch_id / N, + batch_id % N, + head_id, + num_queries * N, + num_queries, + 0); + } + using broadcast_2 = Broadcast2_; + if (kSupportsBias && broadcast_2::kEnable && bias2_ptr) { + auto strideB = num_heads * num_queries * num_keys; + auto strideH = num_queries * num_keys; + bias2_ptr = broadcast_2::advance( + bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + o_strideM = warp_uniform(o_strideM); + if (kSupportsBias && broadcast_1::kEnable) { bias1_ptr = warp_uniform(bias1_ptr); } + if (kSupportsBias && broadcast_2::kEnable) { bias2_ptr = warp_uniform(bias2_ptr); } + return true; + } + + __host__ dim3 getBlocksGrid() const + { + return dim3(ceil_div(num_queries, (int32_t)kQueriesPerBlock), num_heads, num_batches); + } + + __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration; + static constexpr int kAlignmentA = kIsAligned ? DefaultConfig::kAlignmentA + : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = kIsAligned ? DefaultConfig::kAlignmentB + : GemmType::kMinimumAlignment; + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that + // uses too much smem + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + static_assert(MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = + TileSmemLoader, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = + typename cutlass::gemm::threadblock::B2bGemm; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = kIsAligned ? DefaultConfig::kAlignmentB + : GemmType::kMinimumAlignment; + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + // typename MM0::BiasLoader::SmemTile bias; + cutlass::AlignedBuffer bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage() + { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + // typename MM0::BiasLoader::SmemTile bias; + cutlass::AlignedBuffer bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage() + { + return after_mm0.epilogue; + } + }; + + using SharedStorage = + typename cutlass::platform::conditional::type; + + static bool __host__ check_supported(Params const& p) + { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + EVOFORMER_CHECK(p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.k_strideM % kAlignmentK == 0, "key is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.v_strideM % kAlignmentV == 0, "value is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) + { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + m_prime[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = + cutlass::fast_min(int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1.mm, iterator_V, thread_id(), problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params(typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params(typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_id(); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma(shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + // if (kSupportsBias) { + // accum = + // cutlass::multiplies()(p.scale, + // accum); + // } + + if (kSupportsBias) { + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + using Shape = + cutlass::MatrixShape; + AttentionBiasEpilogue + bias_epilogue; + bias_epilogue(bias_tensor_ref, + p.bias1_ptr + iter_key_start, + p.bias2_ptr + query_start * p.num_keys + iter_key_start, + thread_id(), + {problem_size_0_m, problem_size_0_n}, + p.num_keys); + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] = + accum[idx] * p.scale + bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] { + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax(accum_o, + accum, + mi, + m_prime, + s_prime, + lane_id(), + thread_id(), + warp_id(), + p.num_keys - iter_key_start, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = + my_warp_id % (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = + cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = + kSingleValueIteration + ? 1 + : ceil_div((int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { accum_o.clear(); } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, kIsLast, ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue::thread:: + MemoryEfficientAttentionNormalize< + typename cutlass::platform:: + conditional:: + type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = + call_conditional:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue(shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { __syncthreads(); } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = + accum_t(mi[thread_id()]) + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) + { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { m_prime[thread_id] = mi[thread_id]; } + __syncthreads(); + } + + auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { max = -cutlass::platform::numeric_limits::infinity(); }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) ? exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { return a + b; })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } + + static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; } + static CUTLASS_DEVICE int8_t warp_id() { return threadIdx.y; } + static CUTLASS_DEVICE int16_t thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) +{ + if (!p.advance_to_block()) { return; } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h b/csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h new file mode 100644 index 000000000000..0f15a43574cf --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +// This does nothing. +template +struct BroadcastNoLoad { + using Fragment = + cutlass::Array; + static const bool kEnable = false; + CUTLASS_DEVICE static void load(Fragment& frag, + scalar_t* ptr, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + } + CUTLASS_DEVICE static scalar_t* + advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH) + { + return ptr; + } +}; + +// This is to load the bias matrix from the global memory with on-the-fly +// broadcast. The shape in global memory is [B, N, 1, 1, L]. Each time we load +// the last dimension as a L row vector, and we further broadcast the L vector +// to a tile of size [L, L] by repeating the L vector L times +template +struct BroadcastA : public BroadcastNoLoad { + using Base = BroadcastNoLoad; + static const bool kEnable = true; + using layout = cutlass::layout::AffineRank2RowMajor; + + using GmemTileIterator = cutlass::transform::threadblock:: + PredicatedTileIterator; + using Fragment = typename GmemTileIterator::Fragment; + + CUTLASS_DEVICE static void load(Fragment& frag, + scalar_t* ptr, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + GmemTileIterator iter({layout(0, 1)}, ptr, extent, thread_id); + iter.load(frag); + } + + CUTLASS_DEVICE static scalar_t* + advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH) + { + return ptr + B_id * strideB + N_id * strideN; + } +}; + +// This is to load the bias matrix from the global memory with on-the-fly +// broadcast. The shape in global memory is [B, 1, H, L, L]. Each time we load +// a [L, L] matrix. Different N use the same bias matrix when B and H are the +// same. +template +struct BroadcastB : public BroadcastNoLoad { + using Base = BroadcastNoLoad; + static const bool kEnable = true; + using layout = cutlass::layout::RowMajor; + + using GmemTileIterator = cutlass::transform::threadblock:: + PredicatedTileIterator; + using Fragment = typename GmemTileIterator::Fragment; + + CUTLASS_DEVICE static void load(Fragment& frag, + scalar_t* ptr, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + GmemTileIterator iter({layout(stride)}, ptr, extent, thread_id); + iter.load(frag); + } + + CUTLASS_DEVICE static scalar_t* + advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH) + { + return ptr + B_id * strideB + H_id * strideH; + } +}; + +template + class Broadcast1_, + template + class Broadcast2_> +struct AttentionBiasEpilogue { + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape, + kThreads, + 1>; + + using Broadcast1 = Broadcast1_; + using Broadcast2 = Broadcast2_; + + Broadcast1 broadcast1; + Broadcast2 broadcast2; + + using Ref = cutlass::TensorRef; + using SmemTileIterator = cutlass::transform::threadblock:: + RegularTileIterator; + + CUTLASS_DEVICE void operator()(const Ref& ref, + scalar_t* ptr1, + scalar_t* ptr2, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + static_assert(Broadcast1::Fragment::kElements == Broadcast2::Fragment::kElements, + "The two broadcast fragments must have the same number of " + "elements"); + typename SmemTileIterator::Fragment frag; + frag.clear(); + float* frag_ptr = reinterpret_cast(&frag); + if (Broadcast1::kEnable) { + typename Broadcast1::Fragment frag1; + frag1.clear(); + broadcast1.load(frag1, ptr1, thread_id, extent, stride); + scalar_t* frag1_ptr = reinterpret_cast(&frag1); + for (int i = 0; i < Broadcast1::Fragment::kElements; ++i) { + frag_ptr[i] += static_cast(frag1_ptr[i]); + } + } + if (Broadcast2::kEnable) { + typename Broadcast2::Fragment frag2; + frag2.clear(); + broadcast2.load(frag2, ptr2, thread_id, extent, stride); + scalar_t* frag2_ptr = reinterpret_cast(&frag2); + for (int i = 0; i < Broadcast2::Fragment::kElements; ++i) { + frag_ptr[i] += static_cast(frag2_ptr[i]); + } + } + SmemTileIterator iter(ref, thread_id); + iter.store(frag); + __syncthreads(); + } +}; diff --git a/csrc/deepspeed4science/evoformer_attn/transform/tile_smem_loader.h b/csrc/deepspeed4science/evoformer_attn/transform/tile_smem_loader.h new file mode 100644 index 000000000000..5f3f0ee507e5 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/transform/tile_smem_loader.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template // thread access width in elements +class TileSmemLoader { +public: + using Shape = ThreadblockTileShape; + using SmemTile = cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = + cutlass::transform::threadblock::RegularTileIterator; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load(GmemTileIterator tile_load_iter, SmemTileIterator tile_store_iter) + { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp new file mode 100644 index 000000000000..1a887b50e1a3 --- /dev/null +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "fp_quantize.h" + +#include +#include +#include + +#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_quantization((C_TYPE*)val.data_ptr(), \ + (uint8_t*)out.data_ptr(), \ + num_groups, \ + group_size, \ + at::cuda::getCurrentCUDAStream(), \ + q_range, \ + q_bits, \ + q_mantisa_bits, \ + stochastic_rounding); \ + } + +at::Tensor quantize(torch::Tensor& out, + torch::Tensor& val, + int group_size, + int stochastic_rounding, + int q_bits, + int q_mantisa_bits) +{ + int total_elems = at::numel(val); + float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges + (q_bits == 12 ? 510.0 : // fp12 range + (q_bits == 6 ? 28.0 : // fp6 range + 6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4 + // in case accuracy is not matching! + int num_groups = total_elems / group_size; + + DISPATCH_QUANTIZE(kHalf, __half, 23, 8); +#ifdef BF16_AVAILABLE + DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8); +#endif + + return out; +} + +#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + num_groups, \ + group_size, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void dequantize(torch::Tensor& val, + torch::Tensor& val_q, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7); +#endif +} + +#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_selective_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + (int32_t*)indexes.data_ptr(), \ + num_groups, \ + group_size, \ + num_indexes, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } +void selective_dequantize(torch::Tensor& val, + torch::Tensor& val_q, + torch::Tensor& indexes, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + int num_indexes = indexes.size(0); + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7); +#endif +} + +at::Tensor get_scales(torch::Tensor& out, int num_groups) +{ + auto options = at::TensorOptions() + .dtype(torch::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto scales = + torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options); + return scales; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("quantize", &quantize, "quantize function"); + m.def("dequantize", &dequantize, "dequantize function"); + m.def("get_scales", &get_scales, "get scales function"); + m.def("selective_dequantize", &selective_dequantize, "selective dequantize function"); +} diff --git a/csrc/fp_quantizer/fp_quantize_impl.cu b/csrc/fp_quantizer/fp_quantize_impl.cu new file mode 100644 index 000000000000..8b1913e1588f --- /dev/null +++ b/csrc/fp_quantizer/fp_quantize_impl.cu @@ -0,0 +1,526 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "fp_context.h" +#include "fp_quantize.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +#include +#include + +#include +#include + +#ifdef BF16_AVAILABLE +#include +#endif +#include + +using ROp = reduce::ROpType; + +namespace quantization { + +constexpr int access_granularity = 16; +constexpr int quanitzed_access_granularity = 4; +constexpr int quanitzed_access_granularity_6bits = 2; +constexpr int threads = 256; +constexpr int warps = threads / 32; + +} // namespace quantization + +template +__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state) +{ + constexpr uint32_t mantisa_mask = (1U << (_mantisa_bits - q_mantisa_bits)) - 1; + uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask) + : 1U << (_mantisa_bits - q_mantisa_bits - 1); + mantisa += offset; + dst_exponent += (((mantisa & ~mantisa_mask) == (1U << _mantisa_bits)) ? 1 : 0); +} + +template +__device__ void clip(uint32_t& exponent, uint32_t& mantisa) +{ + constexpr uint32_t max_exponent = (1 << (q_exponent_bits - 1)) + (1 << (_exponent_bits - 1)); + constexpr uint32_t min_exponent = + (1 << (_exponent_bits - 1)) - ((1 << (q_exponent_bits - 1)) - 1); + if (exponent > max_exponent) { + exponent = max_exponent; + mantisa = (((uint32_t)-1) >> (32 - q_mantisa_bits)) << 1; //.11 .. 10 + } + if (exponent < min_exponent) { + exponent = min_exponent; + mantisa = 0; + } +} + +template +__global__ void apply_quantization(T* val, + uint8_t* q_val, + int group_size, + std::pair seed, + float q_range) +{ + int tidx = threadIdx.x; + int wid = tidx >> 5; + int lane = tidx & 0x1f; + int gid = blockIdx.x * quantization::warps + wid; + + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint32_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits); + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + constexpr uint32_t load_stride = vector_size * hw_warp_size; + constexpr uint32_t store_stride = (total_q_bits * vector_size / 8) * hw_warp_size; + const uint32_t thread_offset = lane * vector_size; + const uint32_t store_thread_offset = lane * (total_q_bits * vector_size / 8); + const uint32_t base_load_offset = gid * group_size + thread_offset; + const uint32_t base_store_offset = + gid * ((group_size * total_q_bits / 8) + 4) + + store_thread_offset; // 4-byte for saving the scale per group + const T* load_base_ptr = val + base_load_offset; + T tmp_buf[unroll * vector_size]; + T cur_max; + reduce::init(&cur_max); + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + mem_access::load_global( + &tmp_buf[vector_size * i], load_base_ptr + i * load_stride); + for (int j = 0; j < vector_size; j++) + cur_max = reduce::element(cur_max, __habs(tmp_buf[i * vector_size + j])); + } + } + reduce::_block(tb, warp, &cur_max); + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + uint8_t* store_base_ptr = q_val + base_store_offset; + float scale = (float)q_range / conversion::to(cur_max); +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + uint64_t q_buf = 0; + uint64_t q_buf1 = 0; +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float val_f = conversion::to(tmp_buf[i * vector_size + j]) * scale; + uint32_t* data = reinterpret_cast(&val_f); + uint32_t sign = (data[0] & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint32_t cur_exponent = (data[0] & _exponent_mask) >> _mantisa_bits; + uint32_t dst_mantisa = (data[0] & _mantisa_mask); + + uint32_t dst_exponent = cur_exponent; + + round<_mantisa_bits, q_mantisa_bits, stochastic_rounding>( + dst_mantisa, dst_exponent, &state); + if (cur_exponent != 0) + clip<_mantisa_bits, _exponent_bits, q_mantisa_bits, q_exponent_bits>( + dst_exponent, dst_mantisa); + + dst_mantisa = (dst_mantisa & mantisa_mask) >> (_mantisa_bits - q_mantisa_bits); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + if (total_q_bits == 8 || total_q_bits == 4 || total_q_bits == 6) + q_buf = q_buf | + ((uint64_t)((uint8_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << j * total_q_bits); + else if (total_q_bits == 12) { + if (j < 5) + q_buf = + q_buf | + ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << j * total_q_bits); + else + q_buf1 = + q_buf1 | + ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << (j - 5) * total_q_bits); + } + } + if (total_q_bits == 12) { + uint64_t last_nibble_mask = 0xf; + last_nibble_mask = q_buf1 & last_nibble_mask; + q_buf = (last_nibble_mask << 60) | q_buf; + q_buf1 >>= 4; + } + uint8_t* int8_data = reinterpret_cast(&q_buf); + uint8_t* int8_data1 = reinterpret_cast(&q_buf1); + if (total_q_bits == 6) { + mem_access::store_global( + store_base_ptr + i * store_stride, int8_data); + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity_6bits, + int8_data + quantization::quanitzed_access_granularity_6bits); + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity_6bits * 2, + int8_data + 2 * quantization::quanitzed_access_granularity_6bits); + } else { + mem_access::store_global( + store_base_ptr + i * store_stride, int8_data); + + if (total_q_bits > 4) { + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity, + int8_data + quantization::quanitzed_access_granularity); + if (total_q_bits == 12) { + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity * 2, + int8_data1); + } + } + } + } + } + if (lane == 0) { + float q_scale = conversion::to(cur_max) / (float)q_range; + uint8_t* scale_as_int8 = reinterpret_cast(&q_scale); + uint32_t scale_offset = + gid * ((group_size * total_q_bits / 8) + 4) + (group_size * total_q_bits / 8); + if (total_q_bits != 6) + mem_access::store_global( + q_val + scale_offset, scale_as_int8); + else { + mem_access::store_global( + q_val + scale_offset, scale_as_int8); + mem_access::store_global( + q_val + scale_offset + quantization::quanitzed_access_granularity_6bits, + scale_as_int8 + quantization::quanitzed_access_granularity_6bits); + } + } +} + +template +__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements) +{ + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size; + + constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits; + constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits); + const uint32_t g_index = (tidx / group_size); + const uint32_t group_size_bytes = (group_size * quantized_bits / 8); + const uint8_t* load_base_ptr = + val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8; + + T* store_base_ptr = q_val + tidx; + float scale; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + g_index * (group_size_bytes + 4) + group_size_bytes + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + + if (tidx < total_num_elements) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + + } else { + mem_access::load_global(int8_data, + load_base_ptr); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (q_mantisa_bits + q_exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> q_mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (q_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = + ((sign << (q_exponent_bits + mantisa_bits)) | (dst_exponent << mantisa_bits) | + (dst_mantisa << (mantisa_bits - q_mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global(store_base_ptr, store_buf); + } +} + +#define LAUNCH_FOR_QUANTIZATION_UNROLL(COUNT) \ + case COUNT: \ + apply_quantization \ + <<>>(val, q_val, group_size, seed, q_range); \ + break; + +template +void launch_quantization(T* val, + uint8_t* q_val, + int num_groups, + int group_size, + cudaStream_t stream, + float q_range, + int q_bits, + int q_mantisa_bits, + int stochastic_rounding) +{ + const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps); + const dim3 block(quantization::threads); + + std::pair seed = FPContext::Instance().IncrementOffset(16); + + constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T); + + const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll; + QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] { + switch (copy_unroll) { + LAUNCH_FOR_QUANTIZATION_UNROLL(1) + LAUNCH_FOR_QUANTIZATION_UNROLL(2) + LAUNCH_FOR_QUANTIZATION_UNROLL(3) + LAUNCH_FOR_QUANTIZATION_UNROLL(4) + LAUNCH_FOR_QUANTIZATION_UNROLL(5) + LAUNCH_FOR_QUANTIZATION_UNROLL(6) + } + }); +} +#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \ + template void launch_quantization( \ + T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int); +// fp8(E4M3), nearest-rounding +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8); +#endif +INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8); + +template +void launch_dequantization(uint8_t* val, + T* q_val, + int num_groups, + int group_size, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream) +{ + int blocks = ((num_groups * group_size) - 1) / + (quantization::threads * (quantization::access_granularity / sizeof(T))) + + 1; + const dim3 grid(blocks); + const dim3 block(quantization::threads); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_dequantization + <<>>(val, q_val, group_size, (num_groups * group_size)); + }); +} +#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \ + template void launch_dequantization(uint8_t*, T*, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); + +template +__global__ void apply_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int group_size, + int total_num_elements) +{ + int index = indexes[blockIdx.x]; + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size; + int input_index = index * total_num_elements + tidx; + constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits; + constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits); + const uint32_t g_index = (input_index / group_size); + const uint32_t group_size_bytes = (group_size * quantized_bits / 8); + const uint8_t* load_base_ptr = + val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8; + + T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements; + float scale; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + g_index * (group_size_bytes + 4) + group_size_bytes + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + + if (tidx < total_num_elements) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { + mem_access::load_global(int8_data, + load_base_ptr); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (q_mantisa_bits + q_exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> q_mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (q_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = + ((sign << (q_exponent_bits + mantisa_bits)) | (dst_exponent << mantisa_bits) | + (dst_mantisa << (mantisa_bits - q_mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global(store_base_ptr, store_buf); + } +} + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream) +{ + int total_elements_per_index = (num_groups / num_indexes) * group_size; + int blocks = (total_elements_per_index - 1) / + (quantization::threads * (quantization::access_granularity / sizeof(T))) + + 1; + const dim3 grid(num_indexes, blocks); + const dim3 block(quantization::threads); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_selective_dequantization + <<>>(val, q_val, indexes, group_size, total_elements_per_index); + }); +} +#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \ + template void launch_selective_dequantization( \ + uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10); diff --git a/csrc/fp_quantizer/includes/fp_context.h b/csrc/fp_quantizer/includes/fp_context.h new file mode 100644 index 000000000000..5bd9badbcb4f --- /dev/null +++ b/csrc/fp_quantizer/includes/fp_context.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#define WARP_SIZE 32 + +class FPContext { +public: + FPContext() : _seed(42) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + } + + virtual ~FPContext() {} + + static FPContext& Instance() + { + static FPContext _ctx; + return _ctx; + } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + cudaStream_t GetCurrentStream() + { + // get current pytorch stream. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + +private: + curandGenerator_t _gen; + cublasHandle_t _cublasHandle; + uint64_t _seed; + uint64_t _curr_offset; +}; diff --git a/csrc/fp_quantizer/includes/fp_quantize.h b/csrc/fp_quantizer/includes/fp_quantize.h new file mode 100644 index 000000000000..a15b8ddf5a22 --- /dev/null +++ b/csrc/fp_quantizer/includes/fp_quantize.h @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include + +#include +// Note: BF16 support on AMD but we have to exclude here cuda_bf16.h (which turn to +// after hipifying), because this header is pulled into .cpp translation units +// that are compiled by a host-only compiler, which triggers build errors. Added forward declaration +// instead, see code block below +#if defined(BF16_AVAILABLE) +#if !defined(__HIP_PLATFORM_AMD__) +#include +#else +struct __hip_bfloat16; +#endif +#endif + +#include +#include + +#define QUANT_SWITCH(Q_BITS, ...) \ + [&] { \ + if (12 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + __VA_ARGS__(); \ + } else if (13 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + __VA_ARGS__(); \ + } else if (10 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (11 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (28 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 12; \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + __VA_ARGS__(); \ + } else if (29 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 12; \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + __VA_ARGS__(); \ + } else if (6 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 6; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (7 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 6; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (2 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 4; \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + __VA_ARGS__(); \ + } else { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 4; \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + __VA_ARGS__(); \ + } \ + }() + +#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \ + [&] { \ + if (12 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + constexpr int CONST_Q_EXPONENT_BITS = 4; \ + __VA_ARGS__(); \ + } else if (10 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + constexpr int CONST_Q_EXPONENT_BITS = 5; \ + __VA_ARGS__(); \ + } else if (28 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + constexpr int CONST_Q_EXPONENT_BITS = 4; \ + __VA_ARGS__(); \ + } else if (6 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + constexpr int CONST_Q_EXPONENT_BITS = 3; \ + __VA_ARGS__(); \ + } else { \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + constexpr int CONST_Q_EXPONENT_BITS = 2; \ + __VA_ARGS__(); \ + } \ + }() + +template +void launch_quantization(T* val, + uint8_t* q_val, + int num_groups, + int group_size, + cudaStream_t stream, + float q_range, + int q_bits, + int q_mantisa_bits, + int stochastic_rounding); + +template +void launch_dequantization(uint8_t* val, + T* q_val, + int num_groups, + int group_size, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp new file mode 100644 index 000000000000..460330b93d31 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_gds_op.h" + +using namespace std; + +// For when there is more than 1 device +static std::map> base_ptr_registry; + +static void _safe_handle_register(const int fd, CUfileDescr_t& cf_descr, CUfileHandle_t& cf_handle) +{ + memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); + cf_descr.handle.fd = fd; + cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr); + if (status.err != CU_FILE_SUCCESS) { + std::cerr << "file register error:" << cuFileGetErrorString(status) << std::endl; + close(fd); + exit(EXIT_FAILURE); + } +} + +static void* _find_base_ptr(const int64_t device, char* buf_ptr) +{ + void* base_ptr = nullptr; + int64_t last = -1; + int64_t ptr_diff; + for (const auto& value : base_ptr_registry[device]) { + ptr_diff = buf_ptr - (char*)value; + if (last == -1 && ptr_diff >= 0) { + last = ptr_diff; + base_ptr = value; + } else if (ptr_diff < last && ptr_diff >= 0) { + last = ptr_diff; + base_ptr = value; + } + } + if (!base_ptr || buf_ptr < base_ptr) { + std::cerr << "BASE PTR ERROR :" << base_ptr << " BUF PTR " << (void*)buf_ptr << std::endl; + for (const auto& value : base_ptr_registry[device]) { + std::cerr << "BASE PTR AVAIL :" << value << std::endl; + } + exit(EXIT_FAILURE); + } + + return base_ptr; +} + +void gds_op_desc_t::add_buffer_to_registry(const torch::Tensor& buffer) +{ + const int64_t device = buffer.get_device(); + void* reg_ptr = buffer.data_ptr(); + + // TODO: add checking to make sure pointer isn't already in set + const auto it = base_ptr_registry.find(device); + if (it == base_ptr_registry.end()) { + std::set new_ptr_set; + new_ptr_set.insert(reg_ptr); + base_ptr_registry.insert(std::pair>(device, new_ptr_set)); + } else { + base_ptr_registry[device].insert(reg_ptr); + } + + check_cudaruntimecall(cudaSetDevice(device)); + CUfileError_t status = cuFileBufRegister(reg_ptr, buffer.nbytes(), 0); + if (status.err != CU_FILE_SUCCESS) { + std::cerr << "buffer register failed:" << cuFileGetErrorString(status) << std::endl; + exit(EXIT_FAILURE); + } +} + +void gds_op_desc_t::remove_buffer_from_registry(const torch::Tensor& buffer) +{ + const int64_t device = buffer.get_device(); + void* reg_ptr = buffer.data_ptr(); + + // std::cout << "DEREG PTR " << reg_ptr << std::endl; + check_cudaruntimecall(cudaSetDevice(device)); + cuFileBufDeregister(reg_ptr); + + // Remove from tracked registry + base_ptr_registry[device].erase(reg_ptr); +} + +gds_op_desc_t::gds_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int intra_op_parallelism, + const bool validate, + const int64_t file_offset) + : io_op_desc_t(read_op, buffer, fd, filename, intra_op_parallelism, validate, file_offset) +{ + _contiguous_buffer = _buffer.contiguous(); + const int64_t device = _buffer.get_device(); + check_cudaruntimecall(cudaSetDevice(device)); + _base_ptr = _find_base_ptr(device, (char*)_contiguous_buffer.data_ptr()); + + _safe_handle_register(fd, _cf_descr, _cf_handle); +} + +char* gds_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void gds_op_desc_t::finish() { cuFileHandleDeregister(_cf_handle); } + +void gds_op_desc_t::validate() +{ + check_cudaruntimecall(cudaSetDevice(_buffer.get_device())); + const auto cpu_buffer = _buffer.to(torch::kCPU); + const auto num_io_bytes = static_cast(_contiguous_buffer.nbytes()); + validate_aio_operation( + _read_op, _filename.c_str(), (char*)(cpu_buffer.data_ptr()), num_io_bytes); +} + +void gds_op_desc_t::run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config) +{ + assert(tid < _intra_op_parallelism); + check_cudaruntimecall(cudaSetDevice(_buffer.get_device())); + const auto buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr; + const auto tid_file_offset = _file_offset + (_num_bytes_per_thread * tid); + + if (_read_op) { + auto ret = + cuFileRead(_cf_handle, _base_ptr, _num_bytes_per_thread, tid_file_offset, buf_offset); + if (ret < 0) { _report_error(ret, errno, tid_file_offset); } + } else { + auto ret = + cuFileWrite(_cf_handle, _base_ptr, _num_bytes_per_thread, tid_file_offset, buf_offset); + if (ret < 0) { _report_error(ret, errno, tid_file_offset); } + } +} + +void gds_op_desc_t::_report_error(const ssize_t return_code, + const int error_num, + const off_t offset) +{ + const auto op_string = _read_op ? "read failed with " : "write failed with "; + const auto error_string = IS_CUFILE_ERR(return_code) ? "cuFile error: " : "posix error: "; + const auto error_code = IS_CUFILE_ERR(return_code) ? cuFileGetErrorString(return_code) + : cuFileGetErrorString(error_num); + std::cerr << op_string << error_string << error_code << " return code = " << return_code + << " filename = " << _filename << " num bytes = " << _num_bytes_per_thread + << " offset = " << offset << std::endl; + exit(EXIT_FAILURE); +} diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h new file mode 100644 index 000000000000..fe2d3cafb8ef --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_gds_op.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include +#include + +#include "deepspeed_aio_op_desc.h" +#include "deepspeed_gds_utils.h" + +struct gds_op_desc_t : io_op_desc_t { + CUfileDescr_t _cf_descr; + CUfileHandle_t _cf_handle; + void* _base_ptr; + + gds_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const int intra_op_parallelism, + const bool validate, + const int64_t file_offset); + + void run(const int tid, + std::unique_ptr& aio_ctxt, + deepspeed_aio_config_t* aio_config); + + char* data_ptr() const; + + void validate(); + + void finish(); + + void _report_error(const ssize_t return_code, const int error_num, const off_t offset); + + static void add_buffer_to_registry(const torch::Tensor& buffer); + + static void remove_buffer_from_registry(const torch::Tensor& buffer); +}; diff --git a/csrc/gds/py_lib/deepspeed_gds_utils.h b/csrc/gds/py_lib/deepspeed_gds_utils.h new file mode 100644 index 000000000000..12b014d90988 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_gds_utils.h @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +// CUDA/cuFile includes +#include +#include +#include "cufile.h" + +// Macro for checking cuda errors following a cuda launch or api call +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } + +#define check_cudadrivercall(fn) \ + do { \ + CUresult res = fn; \ + if (res != CUDA_SUCCESS) { \ + const char* str = nullptr; \ + cuGetErrorName(res, &str); \ + std::cerr << "cuda driver api call failed " << #fn << " res : " << res << ", " \ + << __LINE__ << ":" << str << std::endl; \ + std::cerr << "EXITING program!!!" << std::endl; \ + exit(1); \ + } \ + } while (0) + +#define check_cudaruntimecall(fn) \ + do { \ + cudaError_t res = fn; \ + if (res != cudaSuccess) { \ + const char* str = cudaGetErrorName(res); \ + std::cerr << "cuda runtime api call failed " << #fn << __LINE__ << ":" << str \ + << std::endl; \ + std::cerr << "EXITING program!!!" << std::endl; \ + exit(1); \ + } \ + } while (0) + +#define check_cuFileCall(fn, api_msg) \ + do { \ + CUfileError_t status = fn; \ + if (status.err != CU_FILE_SUCCESS) { \ + std::cout << api_msg << " failed with error " << CUFILE_ERRSTR(status.err) \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// +// cuda driver error description +// +static inline const char* GetCuErrorString(CUresult curesult) +{ + const char* descp; + if (cuGetErrorName(curesult, &descp) != CUDA_SUCCESS) descp = "unknown cuda error"; + return descp; +} + +// +// cuFile APIs return both cuFile specific error codes as well as POSIX error codes +// for ease, the below template can be used for getting the error description depending +// on its type. + +// POSIX +template ::value, std::nullptr_t>::type = nullptr> +std::string cuFileGetErrorString(T status) +{ + status = std::abs(status); + return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) + : std::string(std::strerror(status)); +} + +// CUfileError_t +template ::value, std::nullptr_t>::type = nullptr> +std::string cuFileGetErrorString(T status) +{ + std::string errStr = cuFileGetErrorString(static_cast(status.err)); + if (IS_CUDA_ERR(status)) errStr.append(".").append(GetCuErrorString(status.cu_err)); + return errStr; +} diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp new file mode 100644 index 000000000000..4d9dc5445043 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* + GPUDirect Storage functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_gds_handle.h" +#include +#include "deepspeed_gds_op.h" + +using namespace std; + +int deepspeed_gds_handle_t::s_cuFile_init = 0; + +deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int intra_op_parallelism) + : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, 1), + _intra_gds_op_parallelism(intra_op_parallelism) +{ + _init_cuFile(block_size, queue_depth); +} + +deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); } + +const int deepspeed_gds_handle_t::get_intra_op_parallelism() const +{ + return _intra_gds_op_parallelism; +} + +void deepspeed_gds_handle_t::_init_cuFile(const int block_size, const int queue_depth) +{ + if (deepspeed_gds_handle_t::s_cuFile_init == 0) { + std::string depthStr = std::to_string(queue_depth); + std::string threadsStr = std::to_string(_intra_gds_op_parallelism); + std::string json1 = R"({"execution": {"max_io_queue_depth": )" + depthStr + ", "; + std::string json2 = R"("max_request_parallelism": )" + threadsStr + ", "; + std::string json3 = R"("max_io_threads": )" + threadsStr + ", "; + std::string json4 = R"("parallel_io": true, "min_io_threshold_size_kb": 8192}})"; + std::ofstream outFile("local_cufile.json"); + if (outFile.is_open()) { + outFile << json1 + json2 + json3 + json4; + outFile.close(); + } else { + std::cerr << "Can't open local cufile" << std::endl; + exit(EXIT_FAILURE); + } + // TODO: Address the following issues with this code + // (1) Fix C++14 warning + // (2) Create file in a different location than PWD + // (3) Handle multi-GPU/multi-rank scenarios: should cufile be shared, is per-rank cufile + // safe? + putenv("CUFILE_ENV_PATH_JSON=$PWD/local_cufile.json"); + cuFileDriverOpen(); + cudaCheckError(); + size_t direct_io_size = (size_t)block_size / 1024; + CUfileError_t status = cuFileDriverSetMaxDirectIOSize(direct_io_size); + if (status.err != CU_FILE_SUCCESS) { + std::cerr << "file register error:" << cuFileGetErrorString(status) << std::endl; + exit(EXIT_FAILURE); + } + } + deepspeed_gds_handle_t::s_cuFile_init++; +} + +void deepspeed_gds_handle_t::_close_cuFile() +{ + deepspeed_gds_handle_t::s_cuFile_init--; + if (deepspeed_gds_handle_t::s_cuFile_init == 0) { cuFileDriverClose(); } +} + +torch::Tensor deepspeed_gds_handle_t::new_pinned_device_tensor(const size_t num_elem, + const torch::Tensor& example_tensor) +{ + auto options = torch::TensorOptions().dtype(example_tensor.scalar_type()).device(torch::kCUDA); + auto dev_tensor = torch::empty(num_elem, options); + pin_device_tensor(dev_tensor); + return dev_tensor; +} + +bool deepspeed_gds_handle_t::free_pinned_device_tensor(torch::Tensor& buffer) +{ + unpin_device_tensor(buffer); + return true; +} + +bool deepspeed_gds_handle_t::pin_device_tensor(const torch::Tensor& buffer) +{ + gds_op_desc_t::add_buffer_to_registry(buffer); + return true; +} + +bool deepspeed_gds_handle_t::unpin_device_tensor(const torch::Tensor& buffer) +{ + gds_op_desc_t::remove_buffer_from_registry(buffer); + return true; +} + +std::shared_ptr deepspeed_gds_handle_t::_create_io_op_desc( + const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const int64_t file_offset) +{ + if (buffer.is_cuda()) { + return std::make_shared( + read_op, buffer, fd, filename, _intra_op_parallelism, validate, file_offset); + } + return deepspeed_io_handle_t::_create_io_op_desc( + read_op, buffer, fd, filename, validate, file_offset); +} diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h new file mode 100644 index 000000000000..f8090fc6fde4 --- /dev/null +++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include "deepspeed_py_io_handle.h" + +struct deepspeed_gds_handle_t : deepspeed_io_handle_t { + const int _intra_gds_op_parallelism; + + deepspeed_gds_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int intra_op_parallelism); + + ~deepspeed_gds_handle_t(); + + torch::Tensor new_pinned_device_tensor(const size_t num_elem, + const torch::Tensor& example_tensor); + + bool free_pinned_device_tensor(torch::Tensor&); + + bool pin_device_tensor(const torch::Tensor& buffer); + + bool unpin_device_tensor(const torch::Tensor& buffer); + + void _init_cuFile(const int block_size, const int queue_depth); + + void _close_cuFile(); + + const int get_intra_op_parallelism() const; + + std::shared_ptr _create_io_op_desc(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const bool validate, + const int64_t file_offset); + + static int s_cuFile_init; +}; diff --git a/csrc/gds/py_lib/py_ds_gds.cpp b/csrc/gds/py_lib/py_ds_gds.cpp new file mode 100644 index 000000000000..bc6327bb3806 --- /dev/null +++ b/csrc/gds/py_lib/py_ds_gds.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include "deepspeed_py_gds_handle.h" +using namespace pybind11::literals; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + py::class_(m, "gds_handle") + .def(py::init(), + "GDS handle constructor", + "block_size"_a = 1024 * 1024, + "queue_depth"_a = 128, + "single_submit"_a = false, + "overlap_events"_a = false, + "intra_op_parallelism"_a = 1) + + .def("get_block_size", &deepspeed_gds_handle_t::get_block_size) + .def("get_queue_depth", &deepspeed_gds_handle_t::get_queue_depth) + .def("get_single_submit", &deepspeed_gds_handle_t::get_single_submit) + .def("get_overlap_events", &deepspeed_gds_handle_t::get_overlap_events) + .def("get_intra_op_parallelism", &deepspeed_gds_handle_t::get_intra_op_parallelism) + .def("get_alignment", &deepspeed_gds_handle_t::get_alignment) + + .def("read", + &deepspeed_gds_handle_t::read, + "Synchronous and non-parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "file_offset"_a = 0) + + .def("write", + &deepspeed_gds_handle_t::write, + "Synchronous and non-parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "file_offset"_a = 0) + + .def("pread", + &deepspeed_gds_handle_t::pread, + "Parallel file read with option of parallelism. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a, + "file_offset"_a = 0) + + .def("pwrite", + &deepspeed_gds_handle_t::pwrite, + "Parallel file write with option of parallelism. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "validate"_a, + "async"_a, + "file_offset"_a = 0) + + .def("sync_pread", + &deepspeed_gds_handle_t::sync_pread, + "Synchrononous parallel file read. Returns count of completed read ops", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) + + .def("sync_pwrite", + &deepspeed_gds_handle_t::sync_pwrite, + "Synchronous parallel file write. Returns count of completed write ops", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) + + .def("async_pread", + &deepspeed_gds_handle_t::async_pread, + "Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and " + "following wait() returns count of completed ops.", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) + + .def( + "async_pwrite", + py::overload_cast( + &deepspeed_gds_handle_t::async_pwrite), + "Asynchronous parallel file write. Returns 0 on success, and subsequent wait() returns " + "count of completed ops.", + "buffer"_a, + "filename"_a, + "file_offset"_a = 0) + + .def("async_pwrite", + py::overload_cast( + &deepspeed_gds_handle_t::async_pwrite), + "Asynchronous parallel file write using opened python file object.", + "buffer"_a, + "fd"_a, + "file_offset"_a = 0) + + .def("new_cpu_locked_tensor", + &deepspeed_gds_handle_t::new_cpu_locked_tensor, + "Allocate pinned CPU tensor.", + "num_elem"_a, + "example_tenosr"_a) + + .def("free_cpu_locked_tensor", + &deepspeed_gds_handle_t::free_cpu_locked_tensor, + "Free pinned CPU tensor.", + "tensor"_a) + + .def("new_pinned_device_tensor", + &deepspeed_gds_handle_t::new_pinned_device_tensor, + "Allocate pinned device tensor.", + "num_elem"_a, + "example_tenosr"_a) + + .def("free_pinned_device_tensor", + &deepspeed_gds_handle_t::free_pinned_device_tensor, + "Free pinned device tensor.", + "tensor"_a) + + .def("pin_device_tensor", + &deepspeed_gds_handle_t::pin_device_tensor, + "Pin device tensor.", + "tensor"_a) + + .def("unpin_device_tensor", + &deepspeed_gds_handle_t::unpin_device_tensor, + "Unpin device tensor.", + "tensor"_a) + + .def("wait", + &deepspeed_gds_handle_t::wait, + "Wait for (ongoing) asynchronous operations to complete"); +} diff --git a/csrc/gds/py_test/validate_gds.py b/csrc/gds/py_test/validate_gds.py new file mode 100644 index 000000000000..b34b1194f582 --- /dev/null +++ b/csrc/gds/py_test/validate_gds.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" +from deepspeed.ops.op_builder import GDSBuilder +assert GDSBuilder().is_compatible(True) +assert GDSBuilder().load(True) diff --git a/csrc/includes/activation_type.h b/csrc/includes/activation_type.h new file mode 100644 index 000000000000..a44921d5d650 --- /dev/null +++ b/csrc/includes/activation_type.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +enum ActivationType { + GELU = 0, + RELU = 1, + SILU = 2, + GEGLU = 3, + ReGLU = 4, + SiGLU = 5, + IDENTITY = 6, + InvalidType = -1 +}; diff --git a/csrc/includes/context.h b/csrc/includes/context.h index 3a9067dc3b9f..cd80f8fbeebe 100644 --- a/csrc/includes/context.h +++ b/csrc/includes/context.h @@ -50,8 +50,12 @@ class TrainingContext { { curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); curandSetPseudoRandomGeneratorSeed(_gen, 123); - if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { - auto message = std::string("Fail to create cublas handle."); + cublasStatus_t stat = cublasCreate(&_cublasHandle); + if (stat != CUBLAS_STATUS_SUCCESS) { + // It would be nice to use cublasGetStatusName and + // cublasGetStatusString, but they were only added in CUDA 11.4.2. + auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") + + std::to_string(stat); std::cerr << message << std::endl; throw std::runtime_error(message); } diff --git a/csrc/includes/conversion_utils.h b/csrc/includes/conversion_utils.h index 27600b83d2b1..d6d8f11e0854 100644 --- a/csrc/includes/conversion_utils.h +++ b/csrc/includes/conversion_utils.h @@ -7,7 +7,6 @@ #include "ds_kernel_utils.h" -#include #include #ifdef BF16_AVAILABLE @@ -266,7 +265,12 @@ DS_D_INLINE float2 to(__nv_bfloat162 val) template <> DS_D_INLINE __half to(double val) { +#ifdef __HIP_PLATFORM_AMD__ + float val_f = __double2float_rn(val); + return __float2half(val_f); +#else return __double2half(val); +#endif } template <> DS_D_INLINE __half to(float val) @@ -329,6 +333,11 @@ DS_D_INLINE __half2 to(float2 val) { return __float22half2_rn(val); } +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} #ifdef BF16_AVAILABLE // No direct conversion @@ -354,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val) template <> DS_D_INLINE __nv_bfloat16 to(int64_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ll2double_rn(val)); +#else return __ll2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int32_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __int2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int16_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __short2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int8_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __int2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint64_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ull2double_rn(val)); +#else return __ull2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint32_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __uint2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint16_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __ushort2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint8_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __uint2bfloat16_rn(val); +#endif } #endif @@ -401,6 +442,15 @@ DS_D_INLINE __nv_bfloat162 to(float2 val) return __float22bfloat162_rn(val); } template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __bfloat162bfloat162(__float2bfloat16(val)); +#else + return __float2bfloat162_rn(val); +#endif +} +template <> DS_D_INLINE __nv_bfloat162 to(__half2 val) { return to<__nv_bfloat162>(to(val)); @@ -430,7 +480,11 @@ DS_D_INLINE int64_t to(__half val) template <> DS_D_INLINE int64_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2ll_rn(__bfloat162float(val)); +#else return __bfloat162ll_rn(val); +#endif } #endif @@ -457,7 +511,11 @@ DS_D_INLINE int32_t to(__half val) template <> DS_D_INLINE int32_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -484,7 +542,11 @@ DS_D_INLINE int16_t to(__half val) template <> DS_D_INLINE int16_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -511,7 +573,11 @@ DS_D_INLINE int8_t to(__half val) template <> DS_D_INLINE int8_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -538,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val) template <> DS_D_INLINE uint64_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2ull_rn(__bfloat162float(val)); +#else return __bfloat162ull_rn(val); +#endif } #endif @@ -565,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val) template <> DS_D_INLINE uint32_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif @@ -592,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val) template <> DS_D_INLINE uint16_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif @@ -619,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val) template <> DS_D_INLINE uint8_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index ba40fcf7b62a..6f500250f033 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -9,67 +9,35 @@ // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c #include +#include #include #include "simd.h" -#if defined(__ENABLE_CUDA__) -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; -#else -typedef unsigned short ds_half_precision_t; -#endif - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - ds_half_precision_t* dev_param = nullptr, \ - bool half_precision = false); +#define STEP(SPAN) \ + template \ + void Step_##SPAN(ds_params_precision_t* _params, \ + ds_params_precision_t* grads, \ + ds_state_precision_t* _exp_avg_sq, \ + size_t _param_size); class Adagrad_Optimizer { public: Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) : _alpha(alpha), _eps(eps), _weight_decay(weight_decay) { -#if defined(__ENABLE_CUDA__) - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = TrainingContext::Instance().GetCurrentStream(); - _streams[1] = TrainingContext::Instance().GetNewStream(); - _buf_index = false; -#endif - } - ~Adagrad_Optimizer() - { -#if defined(__ENABLE_CUDA__) - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); -#endif } + ~Adagrad_Optimizer() {} #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t param_size, - ds_half_precision_t* dev_param = nullptr, - bool half_precision = false); + ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t param_size); #endif STEP(1) STEP(4) STEP(8) -#if defined(__ENABLE_CUDA__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } -#endif inline void IncrementStep(size_t step) { _step++; @@ -90,24 +58,22 @@ class Adagrad_Optimizer { float _betta1_t; float _betta2_t; size_t _step; - -#if defined(__ENABLE_CUDA__) - bool _buf_index; - float* _doubled_buffer[2]; - cudaStream_t _streams[2]; -#endif }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) + ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { +#if !defined(__AVX512__) + if (std::is_same_v || + std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; AVX_Data eps_4; eps_4.data = SIMD_SET(_eps); @@ -123,22 +89,19 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#endif #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); + simd_load(grad_4, grads + i); AVX_Data momentum_4[span]; - simd_load(momentum_4, grads + i, false); + simd_load(momentum_4, grads + i); AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); + simd_load(variance_4, _exp_avg_sq + i); AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); + simd_load(param_4, _params + i); if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } @@ -148,26 +111,9 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, simd_div(grad_4, momentum_4, grad_4); simd_fma(param_4, grad_4, step_size_4, param_4); - simd_store(_params + i, param_4, half_precision); -#if defined(__ENABLE_CUDA__) - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } -#endif - simd_store(_exp_avg_sq + i, variance_4, false); + simd_store(_params + i, param_4); + simd_store(_exp_avg_sq + i, variance_4); } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#endif } *rounded_size = new_rounded_size; } diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 4648aede93ee..f07a14e08438 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -9,28 +9,17 @@ // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c #include +#include #include #include "simd.h" -#if defined(__ENABLE_CUDA__) -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -typedef __half ds_half_precision_t; -#else -#include -typedef unsigned short ds_half_precision_t; -#endif - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - ds_half_precision_t* dev_param = nullptr, \ - bool half_precision = false); +#define STEP(SPAN) \ + template \ + void Step_##SPAN(ds_params_precision_t* _params, \ + ds_params_precision_t* grads, \ + ds_state_precision_t* _exp_avg, \ + ds_state_precision_t* _exp_avg_sq, \ + size_t _param_size); class Adam_Optimizer { public: @@ -50,43 +39,21 @@ class Adam_Optimizer { _step(0), _adamw_mode(adamw_mode) { -#if defined(__ENABLE_CUDA__) - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = TrainingContext::Instance().GetCurrentStream(); - _streams[1] = TrainingContext::Instance().GetNewStream(); - _buf_index = false; -#endif - } - ~Adam_Optimizer() - { -#if defined(__ENABLE_CUDA__) - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); -#endif } + ~Adam_Optimizer() {} #if defined(__AVX512__) or defined(__AVX256__) - template + template void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t param_size, - ds_half_precision_t* dev_param = nullptr, - bool half_precision = false); + ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t param_size); #endif STEP(1) STEP(4) STEP(8) -#if defined(__ENABLE_CUDA__) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } -#endif inline void IncrementStep(size_t step, float beta1, float beta2) { if (beta1 != _betta1 || beta2 != _betta2) { @@ -96,14 +63,17 @@ class Adam_Optimizer { _betta1_t = std::pow(_betta1, step); _betta2_t = std::pow(_betta2, step); } else { - _step++; - if (_step != step) { + if (step == _step + 1) { // first optimizer step increase + _step++; + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } else if (step == + _step) { // no need to update step; beta1_t and beta2_t already updated + return; + } else { // support step increase not equal to 1 _betta1_t = std::pow(_betta1, step); _betta2_t = std::pow(_betta2, step); _step = step; - } else { - _betta1_t *= _betta1; - _betta2_t *= _betta2; } } } @@ -136,27 +106,24 @@ class Adam_Optimizer { float _bias_correction2; bool _adamw_mode; - -#if defined(__ENABLE_CUDA__) - float* _doubled_buffer[2]; - cudaStream_t _streams[2]; - bool _buf_index; -#endif }; #if defined(__AVX512__) or defined(__AVX256__) -template +template void Adam_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) + ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + ds_state_precision_t* _exp_avg_sq, + size_t _param_size) { +#if !defined(__AVX512__) + if (std::is_same_v || + std::is_same_v) { + return; + } +#endif size_t new_rounded_size = 0; - int rshft = half_precision ? 1 : 0; AVX_Data betta1_4; betta1_4.data = SIMD_SET(_betta1); @@ -189,22 +156,19 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, size_t copy_size = TILE; if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#endif #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { AVX_Data grad_4[span]; - simd_load(grad_4, grads + (i >> rshft), half_precision); + simd_load(grad_4, grads + i); AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); + simd_load(momentum_4, _exp_avg + i); AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); + simd_load(variance_4, _exp_avg_sq + i); AVX_Data param_4[span]; - simd_load(param_4, _params + (i >> rshft), half_precision); + simd_load(param_4, _params + i); if (_weight_decay > 0 && !_adamw_mode) { simd_fma(grad_4, param_4, weight_decay4, grad_4); @@ -225,28 +189,48 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, simd_fma(param_4, grad_4, step_size_4, param_4); - simd_store(_params + (i >> rshft), param_4, half_precision); -#if defined(__ENABLE_CUDA__) - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } -#endif - simd_store(_exp_avg + i, momentum_4, false); - simd_store(_exp_avg_sq + i, variance_4, false); + simd_store(_params + i, param_4); + simd_store(_exp_avg + i, momentum_4); + simd_store(_exp_avg_sq + i, variance_4); } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#endif } *rounded_size = new_rounded_size; } #endif + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false); + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq); + +int ds_adam_rollback(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq); + +int destroy_adam_optimizer(int optimizer_id); diff --git a/csrc/includes/cpu_lion.h b/csrc/includes/cpu_lion.h new file mode 100644 index 000000000000..beaf357a3211 --- /dev/null +++ b/csrc/includes/cpu_lion.h @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#define NOMINMAX // Windows idiosyncrasy + // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c + +#include +#include +#include +#include "simd.h" + +#define STEP(SPAN) \ + template \ + void Step_##SPAN(ds_params_precision_t* _params, \ + ds_params_precision_t* grads, \ + ds_state_precision_t* _exp_avg, \ + size_t _param_size); + +class Lion_Optimizer { +public: + Lion_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float weight_decay = 0) + : _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0) + { + } + ~Lion_Optimizer() {} + +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + size_t param_size); +#endif + STEP(1) + STEP(4) + STEP(8) + + inline void IncrementStep(size_t step, float beta1, float beta2) + { + _step++; + if (_step != step || beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + } + } + inline void update_state(float lr, float weight_decay) + { + _alpha = lr; + _weight_decay = weight_decay; + } + +private: + float _alpha; + float _betta1; + float _betta2; + float _weight_decay; + size_t _step; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Lion_Optimizer::Step_AVX(size_t* rounded_size, + ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) +{ +#if !defined(__AVX512__) + if (std::is_same_v || + std::is_same_v) { + return; + } +#endif + size_t new_rounded_size = 0; + + constexpr float neg1 = -1.0f; + AVX_Data neg1_4; + neg1_4.data = SIMD_SET(neg1); + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + float step_size = -_alpha; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float after_decay = 1.0f - _alpha * _weight_decay; + AVX_Data after_decay_4; + if (_weight_decay > 0) after_decay_4.data = SIMD_SET(after_decay); + + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + i); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, _exp_avg + i); + + AVX_Data param_4[span]; + simd_load(param_4, _params + i); + + AVX_Data tmp_4[span]; + + simd_mul(tmp_4, momentum_4, betta1_4); + simd_fma(tmp_4, grad_4, betta1_minus1_4, tmp_4); + // We already used intrinsics, so consider the machine representation fixed. + simd_and(tmp_4, tmp_4, neg1_4); + simd_xor(tmp_4, tmp_4, step_size_4); + if (_weight_decay > 0) { + simd_fma(param_4, param_4, after_decay_4, tmp_4); + } else { + simd_add(param_4, param_4, tmp_4); + } + + simd_mul(momentum_4, momentum_4, betta2_4); + simd_fma(momentum_4, grad_4, betta2_minus1_4, momentum_4); + + simd_store(_params + i, param_4); + simd_store(_exp_avg + i, momentum_4); + } + } + *rounded_size = new_rounded_size; +} +#endif + +int create_lion_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float weight_decay = 0, + bool should_log = false); + +int ds_lion_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg); + +int destroy_lion_optimizer(int optimizer_id); diff --git a/csrc/includes/cublas_wrappers.h b/csrc/includes/cublas_wrappers.h index b016832dc9b3..2721fb990c7e 100644 --- a/csrc/includes/cublas_wrappers.h +++ b/csrc/includes/cublas_wrappers.h @@ -10,10 +10,14 @@ #include #include #include -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ #include #endif +#ifdef __HIP_PLATFORM_AMD__ +#include +#endif #include +#include int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, @@ -26,7 +30,9 @@ int cublas_gemm_ex(cublasHandle_t handle, const float* A, const float* B, float* C, -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); @@ -43,7 +49,8 @@ int cublas_gemm_ex(cublasHandle_t handle, const __half* A, const __half* B, __half* C, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -64,7 +71,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, int stride_B, int stride_C, int batch, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); @@ -85,7 +93,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, int stride_B, int stride_C, int batch, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo algo = rocblas_gemm_algo_standard); #else cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h index 265eb7b12444..21f19749d4cf 100644 --- a/csrc/includes/custom_cuda_layers.h +++ b/csrc/includes/custom_cuda_layers.h @@ -272,9 +272,6 @@ void launch_fuse_transpose_bias_kernel(const T* inp, int cols, cudaStream_t stream); -void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream); -void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream); - void launch_token_sort(int32_t* indices, int layers, int batch_size, diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h new file mode 100644 index 000000000000..7016d4a99310 --- /dev/null +++ b/csrc/includes/deepcompile.h @@ -0,0 +1,617 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#define NOMINMAX // Windows idiosyncrasy + // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c + +#define USE_C10D_NCCL + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +namespace dc { + +template +static bool hasKey(const std::unordered_map& map, const K& key) +{ + return map.find(key) != map.end(); +} + +template +inline std::string to_string(const T& v) +{ + std::stringstream ss; + ss << v; + return ss.str(); +} + +template +size_t productDim(const L& dim) +{ + size_t prod = 1; + for (auto d : dim) { prod *= d; } + return prod; +} + +template +std::string join_as_str(const T& v, const char* delim = ",", const size_t maxlen = 0) +{ + std::stringstream ss; + + if (!v.empty()) { + auto it = v.begin(); + ss << to_string(*it); + it++; + for (; it != v.end(); ++it) { + if (delim) ss << delim; + ss << to_string(*it); + } + } + + std::string s = ss.str(); + if (maxlen > 0 && s.length() > maxlen) { s = s.substr(0, maxlen) + " ..."; } + + return "[" + s + "]"; +} + +template +std::string tensorPtrToString(T* ptr, size_t size, size_t str_len = 100) +{ + std::vector vals; + for (size_t i = 0; i < size; i++) { + vals.push_back(*ptr); + ptr++; + } + return join_as_str(vals, ",", str_len); +} + +std::string tensorPtrToString(void* ptr, + size_t size, + c10::ScalarType datatype, + size_t max_elem = 20, + size_t max_str_len = 100); + +std::string tensorToString(const at::Tensor& t, size_t max_elem = 20, size_t max_str_len = 100); + +std::string tensorDimToString(const at::Tensor& t); + +at::Tensor test_call(at::Tensor param); + +extern c10::intrusive_ptr process_group; +extern c10::intrusive_ptr symm_mem; +extern ncclComm_t nccl_comm; +extern bool use_symm_mem; +extern bool profile; +extern bool pre_div_reduce; + +extern bool sync_before_reduce; // for debugging +extern bool sync_after_reduce; // for debugging +extern bool sync_before_allgather; // for debugging +extern bool sync_after_allgather; // for debugging + +std::vector sizes_to_int_vector(at::IntArrayRef sizes); +void enable_profiling(bool enable); +bool is_profiling(); + +c10::intrusive_ptr getSymmMemWorkspace(int64_t size); +void lazy_init_symm_memory(); +ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type); +void cleanup(); + +class ReduceTask { +public: + ReduceTask(long ds_id, at::Tensor grad, at::Tensor send_buf) + : ds_id_(ds_id), grad_(std::move(grad)), send_buf_(std::move(send_buf)) + { + } + + long getDSId() const { return ds_id_; } + at::Tensor getSendBuf() const { return send_buf_; } + +private: + long ds_id_; + at::Tensor grad_; + at::Tensor send_buf_; +}; + +class ReduceBucket { +public: + ReduceBucket(int64_t size, at::ScalarType scalar_type) : size_(size), scalar_type_(scalar_type) + { + buffer_ = torch::empty({size}, at::TensorOptions().dtype(scalar_type).device(at::kCUDA)); + offset_ = 0; + } + + int64_t getSize() const { return size_; } + int64_t getOffset() const { return offset_; } + at::Tensor getBuffer() const { return buffer_; } + at::ScalarType getScalarType() const { return scalar_type_; } + + void reserve(int64_t size) + { + if (size > size_) { + buffer_ = + torch::empty({size}, at::TensorOptions().dtype(scalar_type_).device(at::kCUDA)); + size_ = size; + } + } + + at::Tensor allocate(int64_t numel) + { + if (offset_ + numel > size_) { + throw std::runtime_error("Buffer size exceeds the reduce bucket size"); + } + + at::Tensor result = buffer_.index({torch::indexing::Slice(offset_, offset_ + numel)}); + offset_ += numel; + return result; + } + + bool shouldFlush(int64_t numel) { return offset_ > 0 && offset_ + numel > size_; } + + void reset() { offset_ = 0; } + +private: + int64_t size_; + int64_t offset_; + at::Tensor buffer_; + at::ScalarType scalar_type_; +}; + +class DoubleBufferedReduceBucket { +public: + DoubleBufferedReduceBucket(int64_t initial_bucket_size, bool enable_double_buffer) + : initial_bucket_size_(initial_bucket_size), enable_double_buffer_(enable_double_buffer) + { + } + + void swap(at::ScalarType scalar_type, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream) + { + assert(hasKey(current_buffer_, scalar_type)); + assert(hasKey(current_buffer_events_, scalar_type)); + + current_buffer_.at(scalar_type)->reset(); + current_buffer_events_.at(scalar_type)->record(rs_stream); + + if (enable_double_buffer_) { + assert(hasKey(shadow_buffer_, scalar_type)); + assert(hasKey(shadow_buffer_events_, scalar_type)); + + auto tmp = current_buffer_.at(scalar_type); + current_buffer_[scalar_type] = shadow_buffer_.at(scalar_type); + shadow_buffer_[scalar_type] = tmp; + + auto tmp_event = current_buffer_events_.at(scalar_type); + current_buffer_events_[scalar_type] = shadow_buffer_events_.at(scalar_type); + shadow_buffer_events_[scalar_type] = tmp_event; + } + } + + std::shared_ptr getBuffer(at::ScalarType scalar_type) + { + if (!hasKey(current_buffer_, scalar_type)) { + current_buffer_[scalar_type] = + std::make_shared(initial_bucket_size_, scalar_type); + current_buffer_events_[scalar_type] = + std::make_shared(cudaEventDisableTiming); + + if (enable_double_buffer_) { + shadow_buffer_[scalar_type] = + std::make_shared(initial_bucket_size_, scalar_type); + shadow_buffer_events_[scalar_type] = + std::make_shared(cudaEventDisableTiming); + } + } + + return current_buffer_.at(scalar_type); + } + + std::shared_ptr getEvent(at::ScalarType scalar_type) + { + assert(hasKey(current_buffer_events_, scalar_type)); + return current_buffer_events_.at(scalar_type); + } + + void clear() + { + current_buffer_.clear(); + shadow_buffer_.clear(); + current_buffer_events_.clear(); + shadow_buffer_events_.clear(); + } + +private: + int64_t initial_bucket_size_; + bool enable_double_buffer_; + std::unordered_map> current_buffer_; + std::unordered_map> shadow_buffer_; + std::unordered_map> current_buffer_events_; + std::unordered_map> shadow_buffer_events_; +}; + +class DSParam { +public: + DSParam(long id, + std::vector ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool partitioned, + int64_t offset, // for Z1 + bool persistent // for Z3 + ) + : id_(id), + shape_(std::move(ds_shape)), + ds_tensor_(ds_tensor), + ds_dtype_(ds_tensor.scalar_type()), + grad_buffer_(grad_buffer), + partitioned_(partitioned), + offset_(offset), + persistent_(persistent) + { + } + + long getId() const { return id_; } + std::vector getShape() const { return shape_; } + at::ScalarType getDtype() const { return ds_dtype_; } + at::Tensor getDSTensor() const + { + // If the reload event exists and is complete, return the reloaded tensor (if defined) + if (reload_done_event_) { + if (!reload_done_event_->query()) { + reload_done_event_->block(at::cuda::getCurrentCUDAStream()); + } + if (ds_reload_tensor_.defined()) { return ds_reload_tensor_; } + } + // Otherwise, if an offload event exists, wait for it to complete + if (offload_done_event_) { + if (!offload_done_event_->query()) { + offload_done_event_->block(at::cuda::getCurrentCUDAStream()); + } + } + return ds_tensor_; + } + at::Tensor getGradBuffer() const { return grad_buffer_; } + bool isPartitioned() const { return partitioned_; } + int64_t getOffset() const { return offset_; } + void setPersistent(bool persistent) { persistent_ = persistent; } + bool isPersistent() const { return persistent_; } + + void offload() + { + // If a reloaded tensor exists, offload its data back to ds_tensor_ + if (ds_reload_tensor_.defined()) { + auto offload_stream = getOffloadStream(); + auto comp_stream = at::cuda::getCurrentCUDAStream(); + comp_done_event_ = std::make_shared(cudaEventDisableTiming); + // Record completion and wait on the offload stream + comp_done_event_->record(comp_stream); + comp_done_event_->block(offload_stream); + offload_done_event_ = std::make_shared(cudaEventDisableTiming); + + { + at::cuda::CUDAStreamGuard guard(offload_stream); + ds_tensor_.copy_(ds_reload_tensor_, /*non_blocking=*/true); + ds_reload_tensor_.reset(); // Clear the reloaded tensor + offload_done_event_->record(offload_stream); + } + // Reset the reload event to indicate that no valid reload is present. + if (reload_done_event_) { reload_done_event_.reset(); } + } + } + + void reload() + { + // Reload only if the current ds_tensor_ is on CPU + if (ds_tensor_.device().is_cpu()) { + auto reload_stream = getReloadStream(); + auto comp_stream = at::cuda::getCurrentCUDAStream(); + comp_done_event_ = std::make_shared(cudaEventDisableTiming); + // Record and wait on the reload stream + comp_done_event_->record(comp_stream); + comp_done_event_->block(reload_stream); + reload_done_event_ = std::make_shared(cudaEventDisableTiming); + + { + at::cuda::CUDAStreamGuard guard(reload_stream); + ds_reload_tensor_ = + at::empty_like(ds_tensor_, ds_tensor_.options().device(torch::kCUDA)); + ds_reload_tensor_.copy_(ds_tensor_, /*non_blocking=*/true); + reload_done_event_->record(reload_stream); + } + // Reset offload_done_event if it exists to clear any stale offload state. + if (offload_done_event_) { offload_done_event_.reset(); } + } + } + +private: + at::cuda::CUDAStream getOffloadStream() + { + if (!offload_stream_) { offload_stream_.emplace(at::cuda::getStreamFromPool()); } + return *offload_stream_; + } + + at::cuda::CUDAStream getReloadStream() + { + if (!reload_stream_) { reload_stream_.emplace(at::cuda::getStreamFromPool()); } + return *reload_stream_; + } + + long id_; + std::vector shape_; + at::ScalarType ds_dtype_; + at::Tensor ds_tensor_; + at::Tensor ds_reload_tensor_; + at::Tensor grad_buffer_; + bool partitioned_; + int64_t offset_; // for Z1 + bool persistent_; // for Z3 + mutable bool is_reloaded = false; + + std::optional offload_stream_; + std::optional reload_stream_; + std::shared_ptr comp_done_event_; + std::shared_ptr offload_done_event_; + std::shared_ptr reload_done_event_; +}; + +class DSParamRegistry { +public: + DSParamRegistry() {} + ~DSParamRegistry() {} + + void registerParam(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool partitioned, + int64_t offset, // for Z1 + bool persistent // for Z3 + ) + { + grad_buffer.zero_(); + params_.emplace( + ds_id, + DSParam(ds_id, ds_shape, ds_tensor, grad_buffer, partitioned, offset, persistent)); + valid_[ds_id] = false; + } + + void registerGatheredParam(long ds_id, at::Tensor ds_tensor) + { + gathered_params_.emplace(ds_id, ds_tensor); + } + + void unregisterGatheredParam(long ds_id) + { + assert(hasKey(gathered_params_, ds_id)); + gathered_params_.erase(ds_id); + valid_[ds_id] = false; + } + + const std::unordered_map& getParams() const { return params_; } + + const DSParam& getParam(long ds_id) const { return params_.at(ds_id); } + const size_t getNumParams() const { return params_.size(); } + const at::Tensor& getGatheredParam(long ds_id) const + { + assert(hasKey(gathered_params_, ds_id)); + return gathered_params_.at(ds_id); + } + bool hasGatheredParam(long ds_id) const { return hasKey(gathered_params_, ds_id); } + void setPersistent(long ds_id, bool persistent) { params_.at(ds_id).setPersistent(persistent); } + void offload(long ds_id) { params_.at(ds_id).offload(); } + void reload(long ds_id) { params_.at(ds_id).reload(); } + + void setValid(long ds_id, bool valid) { valid_[ds_id] = valid; } + bool isValid(long ds_id) const + { + assert(hasKey(valid_, ds_id)); + return valid_.at(ds_id); + } + +private: + std::unordered_map params_; + std::unordered_map gathered_params_; + std::unordered_map valid_; +}; + +class CustomOpExecutor { +public: + CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + bool pre_div_reduce) + : process_group_(process_group), + param_registry_(std::move(param_registry)), + reduce_buckets_(std::move(reduce_buckets)), + ds_ids_(std::move(ds_ids)), + nccl_comm_(nccl_comm), + rs_stream_(rs_stream), + copy_stream_(copy_stream), + pre_div_reduce_(pre_div_reduce) + { + for (long ds_id : ds_ids_) { + has_acc_grad_[ds_id] = false; + + rs_comp_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + rs_copy_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + } + reduce_counter_ = ds_ids_.size(); + } + ~CustomOpExecutor() {} + + virtual void startForward() {} + + virtual void endForward() {} + + virtual void startBackward(bool update) { param_updated_ = update; } + + virtual void endBackward() + { + flushAllReduceBuckets(); + + // This synchronization ensures all of reduce calls are done before optimizer's step. + at::cuda::stream_synchronize(rs_stream_); + } + + virtual at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id) + { + int world_size = process_group_->getSize(); + const DSParam& param = param_registry_->getParam(ds_id); + const auto scalar_type = grad_tensor.scalar_type(); + std::shared_ptr reduce_bucket = reduce_buckets_->getBuffer(scalar_type); + + auto comp_stream = at::cuda::getCurrentCUDAStream(); + + if (reduce_bucket->shouldFlush(grad_tensor.numel())) { + int rank = process_group_->getRank(); + + flushReduceBucket(scalar_type); + + // reduce_bucket is swapped in flushReduceBucket if double buffering is enabled + reduce_bucket = reduce_buckets_->getBuffer(scalar_type); + } + + if (grad_tensor.numel() > reduce_bucket->getSize()) { + // extend buckets + at::cuda::stream_synchronize(rs_stream_); + reduce_bucket->reserve(grad_tensor.numel()); + } + + at::Tensor reduce_in_buffer = reduce_bucket->allocate(grad_tensor.numel()); + + // This ensures the order of reduce_scatter -> copy + // Without this block, copy may start while reduce_scatter is still running + reduce_buckets_->getEvent(scalar_type)->block(comp_stream); + auto copy_src = grad_tensor.contiguous().view({-1}).detach(); + // keep references to copy src + reduce_tasks_[scalar_type].emplace_back(ds_id, copy_src, reduce_in_buffer); + + // computation must be done before copy + rs_comp_done_events_[ds_id]->record(comp_stream); + rs_comp_done_events_[ds_id]->block(copy_stream_); + { + at::cuda::CUDAStreamGuard guard(copy_stream_); + reduce_in_buffer.copy_(copy_src, true); + rs_copy_done_events_[ds_id]->record(copy_stream_); + } + + return at::Tensor(); + } + + bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); } + +protected: + c10::intrusive_ptr process_group_; + std::shared_ptr param_registry_; + std::shared_ptr reduce_buckets_; + std::vector ds_ids_; + ncclComm_t nccl_comm_; + at::cuda::CUDAStream rs_stream_; + at::cuda::CUDAStream copy_stream_; + + std::unordered_map> rs_comp_done_events_; + std::unordered_map> rs_copy_done_events_; + + size_t reduce_counter_ = 0; + bool param_updated_ = false; + std::unordered_map> reduce_tasks_; + std::unordered_map has_acc_grad_; + bool pre_div_reduce_; + + virtual void flushReduceBucket(at::ScalarType scalar_type) = 0; + + void flushAllReduceBuckets() + { + for (const auto& it : reduce_tasks_) { flushReduceBucket(it.first); } + } + + // Common helper methods for flushReduceBucket implementations + void blockCopyEvents(at::ScalarType scalar_type) + { + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto copy_done_event = rs_copy_done_events_.at(t.getDSId()); + copy_done_event->block(rs_stream_); + } + } + + void applyPreDivision(at::ScalarType scalar_type) + { + if (pre_div_reduce_) { + at::cuda::CUDAStreamGuard guard(rs_stream_); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + t.getSendBuf().div_(process_group_->getSize()); + } + } + } + + ncclRedOp_t getReductionOp() const { return pre_div_reduce_ ? ncclSum : ncclAvg; } + + void performCleanup(at::ScalarType scalar_type) + { + reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_); + + // Prevent grad tensor from being released before the copy is done + auto comp_stream = at::cuda::getCurrentCUDAStream(); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto copy_done_event = rs_copy_done_events_.at(t.getDSId()); + copy_done_event->block(comp_stream); + } + reduce_tasks_[scalar_type].clear(); + } +}; + +template +std::shared_ptr getExecutor(long graph_id, + const std::unordered_map>& executors) +{ + assert(hasKey(executors, graph_id)); + if (auto executor = std::dynamic_pointer_cast(executors.at(graph_id))) { return executor; } + throw std::runtime_error("Invalid executor type"); +} + +extern std::shared_ptr param_registry; +extern std::unordered_map> executors; +extern std::shared_ptr reduce_buckets; + +at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id); +at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id); +void free_tensors(std::vector tensors); +void free_tensors_meta(std::vector tensors); + +void init(c10::intrusive_ptr pg, + pybind11::object& config, + int64_t initial_reduce_bucket_size); +void reset(); +void cleanup(); + +void start_forward(); +void end_forward(); +void start_backward(bool update); + +} // namespace dc diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h index 61d424846589..cb8b0b28484e 100644 --- a/csrc/includes/ds_kernel_utils.h +++ b/csrc/includes/ds_kernel_utils.h @@ -11,18 +11,30 @@ used throughout the codebase. #pragma once #include +#include + +// Note: BF16 support on AMD but we have to exclude here cuda_bf16.h (which turn to +// after hipifying), because this header is pulled into .cpp translation units +// that are compiled by a host-only compiler, which triggers build errors. Added forward declaration +// instead, see code block below +#if defined(BF16_AVAILABLE) && !defined(__HIP_PLATFORM_AMD__) +#include +#endif #define DS_HD_INLINE __host__ __device__ __forceinline__ #define DS_D_INLINE __device__ __forceinline__ -#ifdef __HIP_PLATFORM_HCC__ - +#ifdef __HIP_PLATFORM_AMD__ +#if BF16_AVAILABLE +struct __hip_bfloat16; +#endif // constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; #define HALF_PRECISION_AVAILABLE = 1 #include +#include -#else // !__HIP_PLATFORM_HCC__ +#else // !__HIP_PLATFORM_AMD__ // constexpr variant of warpSize for templating constexpr int hw_warp_size = 32; @@ -34,12 +46,12 @@ constexpr int hw_warp_size = 32; #if __CUDA_ARCH__ >= 800 #define ASYNC_COPY_AVAILABLE -#define BF16_AVAILABLE #endif // __CUDA_ARCH__ >= 800 #include +#include -#endif //__HIP_PLATFORM_HCC__ +#endif //__HIP_PLATFORM_AMD__ inline int next_pow2(const int val) { diff --git a/csrc/includes/feed_forward.h b/csrc/includes/feed_forward.h index 8cf9ee9ef594..d2056403d265 100644 --- a/csrc/includes/feed_forward.h +++ b/csrc/includes/feed_forward.h @@ -48,7 +48,9 @@ class FeedForward { weights, input_ptr, out, -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(config_.gemm_algos[0])); #else cublasGemmAlgo_t(config_.gemm_algos[0])); @@ -77,7 +79,8 @@ class FeedForward { input_ptr, out_grad, weights_grad, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(config_.gemm_algos[1])); #else cublasGemmAlgo_t(config_.gemm_algos[1])); @@ -94,7 +97,8 @@ class FeedForward { weights, out_grad, inp_grad_out, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(config_.gemm_algos[2])); #else cublasGemmAlgo_t(config_.gemm_algos[2])); diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h index cd9fbb5a4e17..de5b55cd3df1 100644 --- a/csrc/includes/gemm_test.h +++ b/csrc/includes/gemm_test.h @@ -6,9 +6,12 @@ #pragma once #include -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ #include #endif +#ifdef __HIP_PLATFORM_AMD__ +#include +#endif #include #include #include @@ -64,7 +67,9 @@ class GemmTest { B, A, C, -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -83,7 +88,8 @@ class GemmTest { A, C, B, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -102,7 +108,8 @@ class GemmTest { B, C, A, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -118,8 +125,11 @@ class GemmTest { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard; +#elif defined(__HIP_PLATFORM_AMD__) + for (int algo = (int)HIPBLAS_GEMM_DEFAULT; algo <= (int)HIPBLAS_GEMM_DEFAULT; #else for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; @@ -208,7 +218,8 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -242,7 +253,8 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -273,7 +285,8 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) static_cast(algo)); #else static_cast(algo)); @@ -289,11 +302,17 @@ class StridedGemmTest { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard; +#else +#ifdef __HIP_PLATFORM_AMD__ + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; #else for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; +#endif #endif algo++) { int warm_up = 5; diff --git a/csrc/includes/general_kernels.h b/csrc/includes/general_kernels.h index 28e2cbf2984f..bd621d3c4329 100644 --- a/csrc/includes/general_kernels.h +++ b/csrc/includes/general_kernels.h @@ -8,7 +8,7 @@ #include #include -#ifdef __HIP_PLATFORM_HCC__ +#ifdef __HIP_PLATFORM_AMD__ #include #else #include diff --git a/csrc/includes/memory_access_utils.h b/csrc/includes/memory_access_utils.h index 6789714d27c7..cb990b75bbe8 100644 --- a/csrc/includes/memory_access_utils.h +++ b/csrc/includes/memory_access_utils.h @@ -868,6 +868,35 @@ __device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(voi #endif } +template <> +__device__ __forceinline__ void store_global<2>(void* dst, const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + /////////// Store Shared /////////// template <> diff --git a/csrc/includes/quantization.h b/csrc/includes/quantization.h index 826797889ebb..5bdc96061a31 100644 --- a/csrc/includes/quantization.h +++ b/csrc/includes/quantization.h @@ -40,6 +40,62 @@ void launch_dequantize_kernel(T* dequant_data, int total_elems, cudaStream_t stream); +void launch_swizzled_quant(int8_t* q_data, + float* q_scales, + const __half* input_data, + int num_bits, + quantize::Type q_type, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node, + cudaStream_t stream); + +void launch_loco_swizzled_quant(int8_t* quantized_data, + float* quantized_scales, + const __half* uncompressed_data, + __half* error_feedback, + const float err_beta, + int num_bits, + quantize::Type quant_type, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node, + cudaStream_t stream); + +void launch_loco_dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int num_gpus, + int num_bits, + quantize::Type quant_type, + int out_groups, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + __half2* error_feedback, + const float err_beta, + cudaStream_t stream); + +void launch_dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int num_gpus, + int num_bits, + quantize::Type quant_type, + int out_groups, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + cudaStream_t stream); + template void launch_fake_quantize_kernel(T* vals, int total_count, @@ -64,3 +120,19 @@ void launch_sr_fake_quantize_kernel_asym(T* vals, int group_num, int num_bits, cudaStream_t stream); + +void launch_dequantize_int4_to_half_experimental(uint8_t* data_in, + half* data_out, + half* scale_buffer, + half* min_val_buffer, + int num_group, + int group_size, + cudaStream_t stream); + +void launch_dequantize_int8_to_half_experimental(uint8_t* data_in, + half* data_out, + half* scale_buffer, + half* min_val_buffer, + int num_group, + int group_size, + cudaStream_t stream); diff --git a/csrc/includes/quantization_utils.h b/csrc/includes/quantization_utils.h index 801bb8e2421e..94958fb455c6 100644 --- a/csrc/includes/quantization_utils.h +++ b/csrc/includes/quantization_utils.h @@ -24,6 +24,7 @@ constexpr int max_threads = 1024; Class to hold the quantization parameters for a given tensor. Holds the implementation of the quantization operation. */ + template class Params { public: @@ -102,9 +103,9 @@ class Params { if (max == min) { scale = 1.0; } else { - scale = (1 << numBits) / (max - min); + scale = ((1 << numBits)) / (max - min); } - offset = -(1 << (numBits - 1)) - (min * scale); + offset = (max + min) / 2; } DS_D_INLINE int8_t quantize(__half val) @@ -112,7 +113,7 @@ class Params { constexpr int32_t q_min = -(1 << (numBits - 1)); constexpr int32_t q_max = (1 << (numBits - 1)) - 1; - float val_f = conversion::to(val) * scale + offset; + float val_f = (conversion::to(val) - offset) * scale; int32_t data_i32 = conversion::to(val_f); data_i32 = min(max(data_i32, q_min), q_max); return (int8_t)data_i32; @@ -121,8 +122,8 @@ class Params { template DS_D_INLINE T dequantize(int8_t val) { - const float val_deq_f = conversion::to(val) * scale + offset; - return conversion::to<__half>(val_deq_f); + const float val_deq_f = ((conversion::to(val)) * scale) + offset; + return conversion::to(val_deq_f); } DS_D_INLINE void store(float* params, int group_index) diff --git a/csrc/includes/quantizer.h b/csrc/includes/quantizer.h index 5265c1fe612a..f4f63160d79b 100644 --- a/csrc/includes/quantizer.h +++ b/csrc/includes/quantizer.h @@ -5,7 +5,12 @@ #pragma once +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif + #include #include #include diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index 54427983b021..68ec106975b6 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -9,6 +9,10 @@ #include "ds_kernel_utils.h" #include "memory_access_utils.h" +#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__) +#include +#endif + namespace cg = cooperative_groups; namespace reduce { @@ -145,6 +149,13 @@ of reduce should be straightforward (can just wrap the sum reduction) and would be a good extension of the header. */ +DS_D_INLINE int _warp_rank() +{ + const int thread_rank = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + return thread_rank / hw_warp_size; +} + /* Float element reduce implementations */ template <> DS_D_INLINE float element(const float lhs, const float rhs) @@ -152,6 +163,12 @@ DS_D_INLINE float element(const float lhs, const float rhs) return lhs + rhs; } +template <> +DS_D_INLINE double element(const double lhs, const double rhs) +{ + return lhs + rhs; +} + template <> DS_D_INLINE float element(const float lhs, const float rhs) { @@ -182,6 +199,19 @@ DS_D_INLINE __half element(const __half lhs, const __half rhs) #endif } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 element(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} +#endif + template <> DS_D_INLINE __half element(const __half lhs, const __half rhs) { @@ -213,6 +243,21 @@ DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) #endif } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 element(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __nv_bfloat162 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} +#endif + template <> DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) { @@ -226,6 +271,60 @@ DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) #endif } +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + /* Reduction initialization primitives */ @@ -234,6 +333,11 @@ DS_D_INLINE float init() { return 0.0f; } +template <> +DS_D_INLINE double init() +{ + return (double)0.0f; +} template <> DS_D_INLINE float init() @@ -270,25 +374,122 @@ DS_D_INLINE __half init() return __half(neg_inf); } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + constexpr __hip_bfloat16_raw neg_inf = {0xFF80}; +#else + constexpr __nv_bfloat16_raw neg_inf = {0xFF80}; +#endif + return __nv_bfloat16(neg_inf); +} +#endif + template <> DS_D_INLINE __half2 init() { +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x0000, 0x0000}}; +#else constexpr __half2_raw zero = {0x0000, 0x0000}; return __half2(zero); +#endif } template <> DS_D_INLINE __half2 init() { +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x7C00, 0x7C00}}; +#else constexpr __half2_raw inf = {0x7C00, 0x7C00}; return __half2(inf); +#endif } template <> DS_D_INLINE __half2 init() { +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0xFC00, 0xFC00}}; +#else constexpr __half2_raw neg_inf = {0xFC00, 0xFC00}; return __half2(neg_inf); +#endif +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x7FFFFFFF; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x80000000; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0xFFFFFFFF; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x8000000000000000; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0xFFFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; } template @@ -333,8 +534,8 @@ here (fold is C++17 only and I don't think helps and recursion feels like huge overkill that harms readability) that would be wonderful. */ -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -342,8 +543,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -352,8 +553,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -363,8 +564,13 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -375,52 +581,71 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } +#if defined(__HIP_PLATFORM_AMD__) +template +DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile& warp_arg, T* data) +{ + constexpr int elems = sizeof...(Ops); + if constexpr (!(std::is_integral::value || std::is_floating_point::value)) { + float temp_data[elems]; +#pragma unroll + for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to(data[i]); } + _warp(warp_arg, temp_data); +#pragma unroll + for (int i = 0; i < elems; i++) { data[i] = conversion::to(temp_data[i]); } + } else { + _warp(warp_arg, data); + } +} +#endif // defined(__HIP_PLATFORM_AMD__) + /* Implementation for primary block reduction that serves both `block` and `partitioned_block`. -`local_warp_rank` refers to the warp's location within the partition, so -for an unpartitioned threadblock this will be equivalent to -`warp_arg.meta_group_rank()`. - -Similarly, the warp offset is the `local_warp_rank` of the warp with the -lowest rank in the partition. In the case of an 8 warp block with a -4 warp reduction, this would map to [0, 0, 0, 0, 4, 4, 4, 4]. - -Partition size is the number of warps per partition (equal to the thread -block in the default case). This enables us to only perform the warp reduction -when able to. +Total warps refers to the reduction width of the reduction, not +the number of warps in the block (which may exceed that +if the block is partitioned or if we do a conservative bound at +compile time). */ -template +template DS_D_INLINE void _block(cg::thread_block& tb, cg::thread_block_tile& warp_arg, - float* data, - int warp_offset) + T* data) { constexpr int elems = sizeof...(Ops); - // Separated for now in case this no longer is true - constexpr int bytes = sizeof(float); + constexpr int bytes = sizeof(T); // Unused when `partition_size == 1` or total_warps == 1 - __shared__ float reduce_buffer[max_warps * elems]; + __shared__ T reduce_buffer[max_warps * elems]; + +#ifdef __HIP_PLATFORM_AMD__ + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + const int running_warps = total_threads / hw_warp_size; +#else + const int running_warps = warp_arg.meta_group_size(); +#endif // Always perform warp-scope reduction - _warp(warp_arg, data); +#ifdef __HIP_PLATFORM_AMD__ + _warp_with_type_conversion(warp_arg, data); +#else + _warp(warp_arg, data); +#endif // If max_warps == 1 let's skip the runtime check - if (warp_arg.meta_group_size() > 1 && total_warps != 1) { + if (total_warps != 1) { if (warp_arg.thread_rank() == 0) { #pragma unroll for (int i = 0; i < elems; i++) { - mem_access::store_shared( - reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i); + mem_access::store_shared(reduce_buffer + elems * _warp_rank() + i, data + i); } } // Synchronization inside block-uniform conditional is safe tb.sync(); - if (warp_arg.meta_group_rank() == 0) { - if (warp_arg.thread_rank() < warp_arg.meta_group_size()) { + if (_warp_rank() == 0) { + if (warp_arg.thread_rank() < running_warps) { #pragma unroll for (int i = 0; i < elems; i++) { mem_access::load_shared( @@ -429,8 +654,11 @@ DS_D_INLINE void _block(cg::thread_block& tb, } else { init(data); } - - _warp(warp_arg, data); +#ifdef __HIP_PLATFORM_AMD__ + _warp_with_type_conversion(warp_arg, data); +#else + _warp(warp_arg, data); +#endif #pragma unroll for (int i = 0; i < elems; i++) { @@ -444,8 +672,7 @@ DS_D_INLINE void _block(cg::thread_block& tb, #pragma unroll for (int i = 0; i < elems; i++) { - mem_access::load_shared(data + i, - reduce_buffer + warp_arg.meta_group_rank() * elems + i); + mem_access::load_shared(data + i, reduce_buffer + _warp_rank() * elems + i); } } } @@ -460,7 +687,7 @@ us to obfuscate the details of the partitioned implementation. template DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val) { - _block(tb, warp, &val, 0); + _block(tb, warp, &val); } template @@ -470,7 +697,7 @@ DS_D_INLINE void block(cg::thread_block& tb, float& val2) { float data[2] = {val1, val2}; - _block(tb, warp, data, 0); + _block(tb, warp, data); val1 = data[0]; val2 = data[1]; } @@ -483,7 +710,7 @@ DS_D_INLINE void block(cg::thread_block& tb, float& val3) { float data[3] = {val1, val2, val3}; - _block(tb, warp, data, 0); + _block(tb, warp, data); val1 = data[0]; val2 = data[1]; val3 = data[2]; @@ -498,7 +725,7 @@ DS_D_INLINE void block(cg::thread_block& tb, float& val4) { float data[4] = {val1, val2, val3, val4}; - _block(tb, warp, data, 0); + _block(tb, warp, data); val1 = data[0]; val2 = data[1]; val3 = data[2]; @@ -515,11 +742,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float& val) { if (num_threads <= hw_warp_size) { - _warp(warp, &val); + _warp(warp, &val); } else { constexpr int num_warps = num_threads / hw_warp_size; - const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1); - _block(tb, warp, &val, warp_offset); + _block(tb, warp, &val); } } @@ -532,11 +758,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float data[2] = {val1, val2}; if (num_threads <= hw_warp_size) { - _warp(warp, data); + _warp(warp, data); } else { constexpr int num_warps = num_threads / hw_warp_size; - const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1); - _block(tb, warp, data, warp_offset); + _block(tb, warp, data); } val1 = data[0]; @@ -553,11 +778,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float data[3] = {val1, val2, val3}; if (num_threads <= hw_warp_size) { - _warp(warp, data); + _warp(warp, data); } else { constexpr int num_warps = num_threads / hw_warp_size; - const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1); - _block(tb, warp, data, warp_offset); + _block(tb, warp, data); } val1 = data[0]; @@ -576,11 +800,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float data[4] = {val1, val2, val3, val4}; if (num_threads <= hw_warp_size) { - _warp(warp, data); + _warp(warp, data); } else { constexpr int num_warps = num_threads / hw_warp_size; - const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1); - _block(tb, warp, data, warp_offset); + _block(tb, warp, data); } val1 = data[0]; @@ -589,4 +812,48 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, val4 = data[3]; } +/* +Arg-reduce is a specialization of the above. We only support this with a single reduction +parameter. This only works for max/min reductions. +*/ + +__align__(8) struct IdxReduceResult { + /* + NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits + and the val is the most significant. Changing the order of this declaration + will break the code. + */ + int idx; + float val; +}; + +template +DS_D_INLINE IdxReduceResult +idx_reduce(cg::thread_block& tb, cg::thread_block_tile& warp, float val, int idx) +{ + IdxReduceResult res = {idx, val}; + + // Clear out the nan. This shouldn't be an issue for our initial applications + if (isnan(val)) res.val = init(); + + // Can do float compares as integers. By packing the index into the lower bits + // we can just do a single int64 rather than a branch, compare, and select. + // One side benefit of this is that it is by nature a stable algorithm and + // will always bias ties to the higher index. + int64_t* res_as_int = reinterpret_cast(&res); + + // The way floating point compare works is normally to perform a sign comparison + // and if they match, then do a comparison of the rest of the bits as unsigned + // integers. Since we are bundling these, that means for negative values we need + // to reverse the sort order, which we can do with an XOR. + if (val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + _block(tb, warp, res_as_int); + + // Sign bit is preserved, so we can check if we need to invert the mantissa back + if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + return res; +} + } // namespace reduce diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index 712dd5b32e96..a205026ec7c1 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -12,8 +12,22 @@ #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) +#include -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +template +inline T readAs(const void* src) +{ + T res; + std::memcpy(&res, src, sizeof(T)); + return res; +} +template +inline void writeAs(void* dst, const T& val) +{ + std::memcpy(dst, &val, sizeof(T)); +} + +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) @@ -24,13 +38,58 @@ #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) #define SIMD_SQRT(x) _mm512_sqrt_ps(x) #define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_AND(x, y) _mm512_and_ps(x, y) +#define SIMD_ANDNOT(x, y) _mm512_andnot_ps(x, y) +#define SIMD_OR(x, y) _mm512_or_ps(x, y) +#define SIMD_XOR(x, y) _mm512_xor_ps(x, y) #define SIMD_WIDTH 16 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm512_storeu_ps(x, d)) +static __m512 load_16_bf16_as_f32(const void* data) +{ + __m256i a = readAs<__m256i>(data); // use memcpy to avoid aliasing + __m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32 + __m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by + // 16 bits (representing bf16->f32) + return readAs<__m512>(&c); // use memcpy to avoid aliasing +} + +static void store_16_f32_as_bf16_nearest(__m512 v, void* data) +{ + __m512i u32 = readAs<__m512i>(&v); + + // flow assuming non-nan: + + // uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + __m512i b = _mm512_srli_epi32(u32, 16); + __m512i lsb_mask = _mm512_set1_epi32(0x00000001); + __m512i c = _mm512_and_si512(b, lsb_mask); + __m512i bias_constant = _mm512_set1_epi32(0x00007fff); + __m512i rounding_bias = _mm512_add_epi32(c, bias_constant); + + // uint16_t res = static_cast((U32 + rounding_bias) >> 16); + __m512i d = _mm512_add_epi32(u32, rounding_bias); + __m512i e = _mm512_srli_epi32(d, 16); + __m256i non_nan_res = _mm512_cvtusepi32_epi16(e); + + // handle nan (exp is all 1s and mantissa != 0) + // if ((x & 0x7fffffffU) > 0x7f800000U) + __m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff); + __m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign); + __m512i nan_threshold = _mm512_set1_epi32(0x7f800000); + __mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT); + + // mix in results with nans as needed + __m256i nans = _mm256_set1_epi16(0x7fc0); + __m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans); + + writeAs(data, res); +} +#define SIMD_LOAD_BF16(x) load_16_bf16_as_f32(x) +#define SIMD_STORE_BF16(x, d) store_16_f32_as_bf16_nearest(d, x) + +#define SIMD_LOAD_FP16(x) _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) +#define SIMD_STORE_FP16(x, d) \ + _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) #define INTV __m256i #elif defined(__AVX256__) @@ -42,13 +101,17 @@ #define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) #define SIMD_SQRT(x) _mm256_sqrt_ps(x) #define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_AND(x, y) _mm256_and_ps(x, y) +#define SIMD_ANDNOT(x, y) _mm256_andnot_ps(x, y) +#define SIMD_OR(x, y) _mm256_or_ps(x, y) +#define SIMD_XOR(x, y) _mm256_xor_ps(x, y) #define SIMD_WIDTH 8 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm256_storeu_ps(x, d)) +#define SIMD_LOAD_BF16(x) static_assert(false && "AVX256 does not support BFloat16") +#define SIMD_STORE_BF16(x, d) static_assert(false && "AVX256 does not support BFloat16") +#define SIMD_LOAD_FP16(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) +#define SIMD_STORE_FP16(x, d) \ + _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) #define INTV __m128i #endif @@ -62,20 +125,66 @@ union AVX_Data { // float data_f[16]; }; -template -inline void simd_store(float* dst, AVX_Data* src, bool half_precision) +template +inline typename std::enable_if_t, void> simd_store(T* dst, + AVX_Data* src) { - size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); + size_t width = SIMD_WIDTH; #pragma unroll - for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); } + for (size_t i = 0; i < span; ++i) { SIMD_STORE_FP16((float*)(dst + width * i), src[i].data); } } -template -inline void simd_load(AVX_Data* dst, float* src, bool half_precision) + +template +inline typename std::enable_if_t, void> simd_store(T* dst, + AVX_Data* src) +{ +#ifdef __AVX512__ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { SIMD_STORE_BF16((float*)(dst + width * i), src[i].data); } +#else + throw std::runtime_error("AVX512 required for BFloat16"); +#endif +} + +template +inline typename std::enable_if_t, void> simd_store(T* dst, AVX_Data* src) +{ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { SIMD_STORE(dst + width * i, src[i].data); } +} + +template +inline typename std::enable_if_t, void> simd_load(AVX_Data* dst, + T* src) { - size_t width = (half_precision ? 1 : SIMD_WIDTH); + size_t width = SIMD_WIDTH; #pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); } + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_FP16((float*)(src + width * i)); } } + +template +inline typename std::enable_if_t, void> simd_load(AVX_Data* dst, + T* src) +{ +#ifdef __AVX512__ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_BF16((float*)(src + width * i)); } +#else + throw std::runtime_error("AVX512 required for BFloat16"); +#endif +} + +template +inline typename std::enable_if_t, void> simd_load(AVX_Data* dst, T* src) +{ + size_t width = SIMD_WIDTH; +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD(src + width * i); } +} + template inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) { @@ -136,5 +245,55 @@ inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) #pragma unroll for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } } +template +inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r[i].data); + } +} +template +inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r[i].data); } +} #endif diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h index cd126f4b0584..9767fcf589b8 100644 --- a/csrc/includes/strided_batch_gemm.h +++ b/csrc/includes/strided_batch_gemm.h @@ -77,7 +77,9 @@ class StridedBatchGemm { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[0])); #else cublasGemmAlgo_t(_config.gemm_algos[0])); @@ -105,7 +107,8 @@ class StridedBatchGemm { stride_b, stride_c, _config.batch_size, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[0])); #else cublasGemmAlgo_t(_config.gemm_algos[0])); @@ -149,7 +152,8 @@ class StridedBatchGemm { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[1])); #else cublasGemmAlgo_t(_config.gemm_algos[1])); @@ -178,7 +182,8 @@ class StridedBatchGemm { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo(_config.gemm_algos[2])); #else cublasGemmAlgo_t(_config.gemm_algos[2])); diff --git a/csrc/lamb/fused_lamb_cuda.cpp b/csrc/lamb/fused_lamb_cuda.cpp index 4b5175e924bf..c629b93517d2 100644 --- a/csrc/lamb/fused_lamb_cuda.cpp +++ b/csrc/lamb/fused_lamb_cuda.cpp @@ -26,7 +26,7 @@ void fused_lamb_cuda(at::Tensor& p, at::Tensor& u_l2_i, at::Tensor& lamb_coeff_val); -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ diff --git a/csrc/lamb/fused_lamb_cuda_kernel.cu b/csrc/lamb/fused_lamb_cuda_kernel.cu index ca94a8e5ec2c..d9bacae73457 100644 --- a/csrc/lamb/fused_lamb_cuda_kernel.cu +++ b/csrc/lamb/fused_lamb_cuda_kernel.cu @@ -17,7 +17,7 @@ #include // #include -#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305 +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION > 305 #include #else #include @@ -109,7 +109,7 @@ __device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b) cg::sync(cta); -#if (__CUDA_ARCH__ >= 300) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 502) +#if (__CUDA_ARCH__ >= 300) || (defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 502) if (tid < 32) { cg::coalesced_group active = cg::coalesced_threads(); diff --git a/csrc/lion/cpu_lion.cpp b/csrc/lion/cpu_lion.cpp new file mode 100644 index 000000000000..c5cf3e9e9235 --- /dev/null +++ b/csrc/lion/cpu_lion.cpp @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cpu_lion.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)"); + m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)"); + m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)"); +} diff --git a/csrc/lion/cpu_lion_impl.cpp b/csrc/lion/cpu_lion_impl.cpp new file mode 100644 index 000000000000..6a98162314f9 --- /dev/null +++ b/csrc/lion/cpu_lion_impl.cpp @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "cpu_lion.h" + +using namespace std::string_literals; +static std::unordered_map> s_optimizers; + +// C++ interface + +template +void Lion_Optimizer::Step_1(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size); +#endif + if (_param_size > rounded_size) { + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float alpha = _alpha; + float after_decay = 1 - alpha * _weight_decay; + + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = (float)grads[k]; + float param = (float)_params[k]; + float momentum = _exp_avg[k]; + float tmp = momentum * _betta1; + tmp = grad * betta1_minus1 + tmp; + // Rely on portable C++ methods to manipulate the sign bit of a floating-point + // number. + tmp = -std::copysignf(alpha, tmp); + if (_weight_decay > 0) { + param = param * after_decay + tmp; + } else { + param = param + tmp; + } + momentum = momentum * _betta2; + momentum = grad * betta2_minus1 + momentum; + _params[k] = param; + _exp_avg[k] = momentum; + } + } + } +} + +template +void Lion_Optimizer::Step_4(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_param_size - rounded_size)); +} + +int create_lion_optimizer(int optimizer_id, + float alpha, + float betta1, + float betta2, + float weight_decay, + bool should_log) +{ + auto opt = std::make_shared(alpha, betta1, betta2, weight_decay); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Lion Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f\n", + alpha, + betta1, + betta2, + weight_decay); + } + + return 0; +} + +template +void Lion_Optimizer::Step_8(ds_params_precision_t* _params, + ds_params_precision_t* grads, + ds_state_precision_t* _exp_avg, + size_t _param_size) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_param_size - rounded_size)); +} + +template +void step_invoker(std::shared_ptr opt, + void* _params, + void* grads, + void* _exp_avg, + size_t _param_size) +{ + opt->Step_8((ds_params_precision_t*)(_params), + (ds_params_precision_t*)(grads), + (ds_state_precision_t*)(_exp_avg), + _param_size); +} + +std::map, + std::function, void*, void*, void*, size_t)>> + invokers; + +// Fill map with template functions for each type +template +void create_invoker() +{ + invokers[std::tuple(c10::CppTypeToScalarType(), + c10::CppTypeToScalarType())] = + step_invoker; +} +struct InvokerInitializer { + InvokerInitializer() + { + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + create_invoker(); + } +} _invoker_initializer; + +void invoke(std::shared_ptr opt, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + size_t param_size) +{ + c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype()); + c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype()); + + auto it = invokers.find(std::tuple(params_type, state_type)); + if (it == invokers.end()) { + throw std::runtime_error("Lion optimizer with param type "s + c10::toString(params_type) + + " and state type "s + c10::toString(state_type) + + " is not supported on current hardware"s); + } + + it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg.data_ptr(), param_size); +} + +int ds_lion_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, weight_decay); + + invoke(opt, params_c, grads_c, exp_avg_c, params_c.numel()); + + return 0; +} + +int destroy_lion_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} diff --git a/csrc/lion/fused_lion_frontend.cpp b/csrc/lion/fused_lion_frontend.cpp new file mode 100644 index 000000000000..e523f97ca309 --- /dev/null +++ b/csrc/lion/fused_lion_frontend.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +void multi_tensor_lion_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const int step, + const float weight_decay); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_lion", + &multi_tensor_lion_cuda, + "Compute and apply gradient update to parameters for Lion optimizer"); +} diff --git a/csrc/lion/multi_tensor_apply.cuh b/csrc/lion/multi_tensor_apply.cuh new file mode 100644 index 000000000000..12f41cb49c6b --- /dev/null +++ b/csrc/lion/multi_tensor_apply.cuh @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +#include +#include "compat.h" + +#include + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void* addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; +}; + +template +__global__ void multi_tensor_apply_kernel(int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + ArgTypes... args) +{ + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply(int block_size, + int chunk_size, + const at::Tensor& noop_flag, + const std::vector>& tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/csrc/lion/multi_tensor_lion.cu b/csrc/lion/multi_tensor_lion.cu new file mode 100644 index 000000000000..f5fe6dfdd7d0 --- /dev/null +++ b/csrc/lion/multi_tensor_lion.cu @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +using MATH_T = float; + +template +struct LionFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<3>& tl, + const float beta1, + const float beta2, + const float lr, + const float decay) + { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T* m = (T*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + MATH_T after_decay = 1.0f - lr * decay; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + MATH_T c = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + MATH_T update = c > 0 ? (-lr) : lr; + r_p[ii] = r_p[ii] * after_decay + update; + r_m[ii] = beta2 * r_m[ii] + (1 - beta2) * r_g[ii]; + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + } + } + } + } +}; + +void multi_tensor_lion_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const int step, + const float weight_decay) +{ + using namespace at; + + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "lion", + multi_tensor_apply<3>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LionFunctor(), + beta1, + beta2, + lr, + weight_decay);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/csrc/quantization/fake_quantizer.cu b/csrc/quantization/fake_quantizer.cu index 0f53e5235921..4c08cd4cc3d2 100644 --- a/csrc/quantization/fake_quantizer.cu +++ b/csrc/quantization/fake_quantizer.cu @@ -11,7 +11,7 @@ namespace cg = cooperative_groups; __global__ void fake_quantize_kernel(__half* vals, int group_size, int num_bits) { -#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__) +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) cg::thread_block b = cg::this_thread_block(); // tb cg::thread_block_tile<32> g = @@ -197,7 +197,7 @@ __global__ void sr_fake_quantize_kernel(__half* vals, int num_bits, std::pair seed) { -#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__) +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); @@ -475,7 +475,7 @@ template void launch_sr_fake_quantize_kernel(__half* vals, __global__ void fake_quantize_kernel_asym(__half* vals, int group_size, int num_bits) { -#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__) +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); @@ -720,7 +720,7 @@ __global__ void sr_fake_quantize_kernel_asym(__half* vals, int num_bits, std::pair seed) { -#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__) +#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_AMD__) cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); diff --git a/csrc/quantization/pt_binding.cpp b/csrc/quantization/pt_binding.cpp index ccc0c15be1a6..b48eaacd0881 100644 --- a/csrc/quantization/pt_binding.cpp +++ b/csrc/quantization/pt_binding.cpp @@ -136,6 +136,237 @@ at::Tensor dequantize(at::Tensor& quantized_data, return output; } +at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in, + at::Tensor& scale_buffer, + at::Tensor& min_val_buffer, + int num_group, + int group_size) +{ + auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto output = torch::empty({num_group, group_size}, output_options); + + launch_dequantize_int4_to_half_experimental((uint8_t*)data_in.data_ptr(), + (half*)output.data_ptr(), + (half*)scale_buffer.data_ptr(), + (half*)min_val_buffer.data_ptr(), + num_group, + group_size, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in, + at::Tensor& scale_buffer, + at::Tensor& min_val_buffer, + int num_group, + int group_size) +{ + auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto output = torch::empty({num_group, group_size}, output_options); + + launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(), + (half*)output.data_ptr(), + (half*)scale_buffer.data_ptr(), + (half*)min_val_buffer.data_ptr(), + num_group, + group_size, + at::cuda::getCurrentCUDAStream()); + + return output; +} + +std::vector ds_loco_swizzle_quant(at::Tensor& input_vals, + at::Tensor& error_feedback, + float err_beta, + int groups, + int num_bits, + quantize::Type quant_type, + int pipeline_size, + int nodes, + int devices_per_node) +{ + auto scales_options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1; + auto scales = torch::empty({groups, scales_elems}, scales_options); + + auto output_options = at::TensorOptions() + .dtype(at::kChar) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + const int quantization_scalar = 8 / num_bits; + const int compressed_vals = at::numel(input_vals) / quantization_scalar; + + auto output = torch::empty({compressed_vals}, output_options); + const int elems_per_group = at::numel(input_vals) / groups; + + launch_loco_swizzled_quant(reinterpret_cast(output.data_ptr()), + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(input_vals.data_ptr()), + reinterpret_cast<__half*>(error_feedback.data_ptr()), + err_beta, + num_bits, + quant_type, + groups, + elems_per_group, + pipeline_size, + nodes, + devices_per_node, + at::cuda::getCurrentCUDAStream()); + + return {output, scales}; +} + +std::vector ds_swizzle_quant(at::Tensor& input_vals, + int groups, + int num_bits, + quantize::Type quant_type, + int pipeline_size, + int nodes, + int devices_per_node) +{ + auto scales_options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1; + auto scales = torch::empty({groups, scales_elems}, scales_options); + + auto output_options = at::TensorOptions() + .dtype(at::kChar) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + const int quantization_scalar = 8 / num_bits; + const int compressed_vals = at::numel(input_vals) / quantization_scalar; + + auto output = torch::empty({compressed_vals}, output_options); + const int elems_per_group = at::numel(input_vals) / groups; + + launch_swizzled_quant((int8_t*)output.data_ptr(), + (float*)scales.data_ptr(), + (__half*)input_vals.data_ptr(), + num_bits, + quant_type, + groups, + elems_per_group, + pipeline_size, + nodes, + devices_per_node, + at::cuda::getCurrentCUDAStream()); + + return {output, scales}; +} + +std::vector quantized_reduction(at::Tensor& input_vals, + at::Tensor& input_scales, + int in_groups, + int out_groups, + int num_bits, + quantize::Type quant_type, + int devices_per_node) +{ + auto scales_options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1; + auto scales = torch::empty({out_groups, scales_elems}, scales_options); + + auto output_options = at::TensorOptions() + .dtype(at::kChar) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + std::vector sz(input_vals.sizes().begin(), input_vals.sizes().end()); + sz[sz.size() - 1] = sz.back() / devices_per_node; // num of GPU per nodes + const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node; + auto output = torch::empty(sz, output_options); + + const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node); + const int elems_per_out_group = elems_per_in_tensor / out_groups; + + launch_dequant_reduce((int8_t*)output.data_ptr(), + (float*)scales.data_ptr(), + (const int8_t*)input_vals.data_ptr(), + (const float*)input_scales.data_ptr(), + devices_per_node, + num_bits, + quant_type, + out_groups, + elems_per_out_group, + elems_per_in_tensor, + in_groups / devices_per_node, + elems_per_in_group, + at::cuda::getCurrentCUDAStream()); + return {output, scales}; +} + +std::vector loco_quantized_reduction(at::Tensor& input_vals, + at::Tensor& input_scales, + at::Tensor& error_feedback, + float err_beta, + int in_groups, + int out_groups, + int num_bits, + quantize::Type quant_type, + int devices_per_node) +{ + auto scales_options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1; + + auto scales = torch::empty({out_groups, scales_elems}, scales_options); + + auto output_options = at::TensorOptions() + .dtype(at::kChar) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + std::vector sz(input_vals.sizes().begin(), input_vals.sizes().end()); + sz[sz.size() - 1] = sz.back() / devices_per_node; + + const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node; + + auto output = torch::empty(sz, output_options); + + const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node); + const int elems_per_out_group = elems_per_in_tensor / out_groups; + + launch_loco_dequant_reduce((int8_t*)output.data_ptr(), + (float*)scales.data_ptr(), + (const int8_t*)input_vals.data_ptr(), + (const float*)input_scales.data_ptr(), + devices_per_node, + num_bits, + quant_type, + out_groups, + elems_per_out_group, + elems_per_in_tensor, + in_groups / devices_per_node, + elems_per_in_group, + (__half2*)error_feedback.data_ptr(), + err_beta, + at::cuda::getCurrentCUDAStream()); + + return {output, scales}; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ds_quantize_fp32", &ds_quantize, "DeepSpeed Quantize with fp32 (CUDA)"); @@ -158,4 +389,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("quantize", &quantize_kernel); m.def("dequantize", &dequantize<__half>); m.def("dequantize_fp32", &dequantize); + m.def("dequantize_int4_to_half_experimental", + &dequantize_int4_to_half_experimental, + "Dequantize int4 to half (experimental)"); + m.def("dequantize_int8_to_half_experimental", + &dequantize_int8_to_half_experimental, + "Dequantize int8 to half (experimental)"); + m.def("swizzle_quant", &ds_swizzle_quant); + m.def("quantized_reduction", &quantized_reduction); + m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel"); + m.def("loco_quantized_reduction", + &loco_quantized_reduction, + "LoCo Quantization and Reduction Kernel"); } diff --git a/csrc/quantization/quant_reduce.cu b/csrc/quantization/quant_reduce.cu new file mode 100644 index 000000000000..4100c5174b80 --- /dev/null +++ b/csrc/quantization/quant_reduce.cu @@ -0,0 +1,557 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "dequantization_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" +#include "quantization_utils.h" +#include "reduction_utils.h" + +using rop = reduce::ROpType; + +/* +TODO(cmikeh2): Add implementation that better handles larger nodes. It would like make sense +to leverage some parallel reductions here to improve performance. +*/ + +template +__global__ void __launch_bounds__(1024) dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + int num_tensors) +{ + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // NOTE(cmikeh2): This probably could be hardcoded to a larger number, + // but that means even stronger restrictions on the number of elements per group + // A performance analysis here might be beneficial + constexpr int mem_granularity = (numBits == 8) ? 8 : 4; + constexpr int elems_per_load = mem_granularity / sizeof(int8_t); // div by 1 + constexpr int storage_values = 16 / sizeof(__half2); + + const int block_offset = tb.group_index().x * elems_per_out_group; + const int elem_offset = tb.thread_index().x * elems_per_load; + const int base_offset = block_offset + elem_offset; + const int stride = tb.group_dim().x * elems_per_load; + + __half2 local_buffer[totalChunks * storage_values]; + + quantize::GroupStats stats; + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + __half2* iteration_buffer = local_buffer + i * storage_values; + +#pragma unroll + for (int j = 0; j < storage_values; j++) { + iteration_buffer[j] = reduce::init(); + } + + const int iter_offset = i * stride + base_offset; + const int iter_scale_idx = iter_offset / elems_per_in_group; + bool do_loads = i * stride + elem_offset < elems_per_out_group; + + if (numTensors > 0) { +#pragma unroll + for (int j = 0; j < numTensors; j++) { + if (do_loads) { + int8_t load_buffer[elems_per_load]; + + mem_access::load_global( + load_buffer, input_data + j * elems_per_in_tensor + iter_offset); + + quantize::Params params( + input_scales + j * groups_per_in_tensor, iter_scale_idx); + + __half2 dequant_buffer[storage_values]; + dequantize::chunk(dequant_buffer, load_buffer, params); + +#pragma unroll + for (int k = 0; k < storage_values; k++) { + iteration_buffer[k] = + reduce::element(iteration_buffer[k], dequant_buffer[k]); + } + } + } + } else { +#pragma unroll 4 + for (int j = 0; j < num_tensors; j++) { + if (do_loads) { + int8_t load_buffer[elems_per_load]; + + mem_access::load_global( + load_buffer, input_data + j * elems_per_in_tensor + iter_offset); + + quantize::Params params( + input_scales + j * groups_per_in_tensor, iter_scale_idx); + + __half2 dequant_buffer[storage_values]; + dequantize::chunk(dequant_buffer, load_buffer, params); + +#pragma unroll + for (int k = 0; k < storage_values; k++) { + iteration_buffer[k] = + reduce::element(iteration_buffer[k], dequant_buffer[k]); + } + } + } + } + +#pragma unroll + for (int j = 0; j < storage_values; j++) { stats.update(iteration_buffer[j]); } + } + + auto params = stats.template get_params(tb, warp); + + if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); } + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + const int iter_offset = i * stride + base_offset; + if (i * stride + elem_offset < elems_per_out_group) { + int8_t local_output[elems_per_load]; + quantize::_chunk( + local_output, local_buffer + i * storage_values, params); + mem_access::store_global(reduced_data + iter_offset, local_output); + } + } +} + +template +int32_t pow2_round(int32_t raw_value) +{ + return (((raw_value - 1) >> Power) + 1) << Power; +} + +#define LAUNCH_DEQUANT_REDUCE(num_chunks) \ + dequant_reduce \ + <<>>(reduced_data, \ + reduced_scales, \ + input_data, \ + input_scales, \ + elems_per_out_group, \ + elems_per_in_tensor, \ + groups_per_in_tensor, \ + elems_per_in_group, \ + num_tensors); + +template +void launch_dequant_reduce_impl(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int out_groups, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + int num_tensors, + cudaStream_t stream) +{ + // This is a coincidence. This is derived by 8 halves per 16 bytes with 2-way packing for int4 + constexpr int elems_per_thread = numBits; + const int one_step_threads = + next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread)); + // TODO(cmikeh2): Tune this + const int threads = (one_step_threads < 1024) ? one_step_threads : 1024; + + dim3 block(threads); + dim3 grid(out_groups); + + const int elems_per_step = threads * elems_per_thread; + const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step; + + const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw; + + if (unroll == 1) { + // 0-4096 elems + LAUNCH_DEQUANT_REDUCE(1); + } else if (unroll == 2) { + // 4097-8192 etc... + LAUNCH_DEQUANT_REDUCE(2); + } else if (unroll == 3) { + LAUNCH_DEQUANT_REDUCE(3); + } else if (unroll == 4) { + LAUNCH_DEQUANT_REDUCE(4); + } else if (unroll == 6) { + LAUNCH_DEQUANT_REDUCE(6); + } else if (unroll == 8) { + LAUNCH_DEQUANT_REDUCE(8); + } else if (unroll == 10) { + LAUNCH_DEQUANT_REDUCE(10); + } else if (unroll == 12) { + // 48k limit + LAUNCH_DEQUANT_REDUCE(12); + } else { + assert(false); + } +} + +#define LAUNCH_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \ + launch_dequant_reduce_impl(reduced_data, \ + reduced_scales, \ + input_data, \ + input_scales, \ + out_groups, \ + elems_per_out_group, \ + elems_per_in_tensor, \ + groups_per_in_tensor, \ + elems_per_in_group, \ + num_gpus, \ + stream); + +void launch_dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int num_gpus, + int num_bits, + quantize::Type quant_type, + int out_groups, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + cudaStream_t stream) +{ + if (quant_type == quantize::Type::Symmetric) { + if (num_bits == 4) { + if (num_gpus == 8) { + LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric); + } else if (num_gpus == 16) { + LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric); + } else { + LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric); + } + } else if (num_bits == 8) { + if (num_gpus == 8) { + LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric); + } else if (num_gpus == 16) { + LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric); + } else { + LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric); + } + } + } else if (quant_type == quantize::Type::Asymmetric) { + if (num_bits == 4) { + if (num_gpus == 8) { + LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric); + } else if (num_gpus == 16) { + LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric); + } else { + LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric); + } + } else if (num_bits == 8) { + if (num_gpus == 8) { + LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric); + } else if (num_gpus == 16) { + LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric); + } else { + LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric); + } + } + } +} + +/* +Modified loco_dequant_reduce function that performs dequantization and reduction, +and incorporates error-feedback by updating the error_feedback tensor in-place. +*/ + +template +__global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + int num_tensors, + __half2* error_feedback, + const float err_beta) +{ + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + constexpr int mem_granularity = (numBits == 8) ? 8 : 4; + constexpr int elems_per_load = mem_granularity / sizeof(int8_t); + constexpr int storage_values = 16 / sizeof(__half2); + + const int block_offset = tb.group_index().x * elems_per_out_group; + const int elem_offset = tb.thread_index().x * elems_per_load; + const int base_offset = block_offset + elem_offset; + const int stride = tb.group_dim().x * elems_per_load; + + constexpr int scaling_factor = elems_per_load / storage_values; + const int block_offset_err = block_offset / scaling_factor; + const int elem_offset_err = tb.thread_index().x * storage_values; + const int base_offset_err = block_offset_err + elem_offset_err; + const int stride_err = tb.group_dim().x * storage_values; + + __half2 local_buffer[totalChunks * storage_values]; + __half2 err_buffer[totalChunks * storage_values]; + + quantize::GroupStats stats; + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + __half2* iteration_buffer = local_buffer + i * storage_values; + __half2* iter_err_buffer = err_buffer + i * storage_values; + +#pragma unroll + for (int j = 0; j < storage_values; j++) { + iteration_buffer[j] = reduce::init(); + } + + const int iter_offset = i * stride + base_offset; + const int iter_offset_err = i * stride_err + base_offset_err; + const int iter_scale_idx = iter_offset / elems_per_in_group; + bool do_loads = i * stride + elem_offset < elems_per_out_group; + + if (numTensors > 0) { +#pragma unroll + for (int j = 0; j < numTensors; j++) { + if (do_loads) { + int8_t load_buffer[elems_per_load]; + + mem_access::load_global( + load_buffer, input_data + j * elems_per_in_tensor + iter_offset); + + quantize::Params params( + input_scales + j * groups_per_in_tensor, iter_scale_idx); + + __half2 dequant_buffer[storage_values]; + dequantize::chunk(dequant_buffer, load_buffer, params); + +#pragma unroll + for (int k = 0; k < storage_values; k++) { + iteration_buffer[k] = + reduce::element(iteration_buffer[k], dequant_buffer[k]); + } + } + } + } else { +#pragma unroll 4 + for (int j = 0; j < num_tensors; j++) { + if (do_loads) { + int8_t load_buffer[elems_per_load]; + + mem_access::load_global( + load_buffer, input_data + j * elems_per_in_tensor + iter_offset); + + quantize::Params params( + input_scales + j * groups_per_in_tensor, iter_scale_idx); + + __half2 dequant_buffer[storage_values]; + dequantize::chunk(dequant_buffer, load_buffer, params); + +#pragma unroll + for (int k = 0; k < storage_values; k++) { + iteration_buffer[k] = + reduce::element(iteration_buffer[k], dequant_buffer[k]); + } + } + } + } + mem_access::load_global( + iter_err_buffer, error_feedback + iter_offset_err, do_loads); +#pragma unroll + for (int k = 0; k < storage_values; k++) { + iteration_buffer[k] = __hadd2(iteration_buffer[k], iter_err_buffer[k]); + stats.update(iteration_buffer[k]); + } + } + + auto params = stats.template get_params(tb, warp); + + // Initialize dequantization parameters based on params + auto de_params = params; + de_params.scale = 1.0f / params.scale; + if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; } + + if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); } + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + const int iter_offset = i * stride + base_offset; + const int iter_offset_err = i * stride_err + base_offset_err; + __half2* iteration_buffer = local_buffer + i * storage_values; + __half2* iter_err_buffer = err_buffer + i * storage_values; + + if (i * stride + elem_offset < elems_per_out_group) { + // ----------- Begin Error-Feedback Modification ----------- + int8_t local_output[elems_per_load]; + quantize::_chunk(local_output, iteration_buffer, params); + mem_access::store_global(reduced_data + iter_offset, local_output); + + // Dequantize the quantized output to compute the dequantized value + __half2 dequant_buffer[storage_values]; + dequantize::chunk(dequant_buffer, local_output, de_params); + +#pragma unroll + for (int k = 0; k < storage_values; k++) { + // __half2 to float2 + float2 iter_buf_f = __half22float2(iteration_buffer[k]); + float2 dequant_buf_f = __half22float2(dequant_buffer[k]); + + // Update within float precision + float2 new_error_f; + new_error_f.x = iter_buf_f.x - dequant_buf_f.x; + new_error_f.y = iter_buf_f.y - dequant_buf_f.y; + + float2 iter_err_buf_f = __half22float2(iter_err_buffer[k]); + + iter_err_buf_f.x = err_beta * iter_err_buf_f.x + (1.0f - err_beta) * new_error_f.x; + iter_err_buf_f.y = err_beta * iter_err_buf_f.y + (1.0f - err_beta) * new_error_f.y; + + // float2 back to __half2 + iter_err_buffer[k] = __float22half2_rn(iter_err_buf_f); + } + mem_access::store_global(error_feedback + iter_offset_err, + iter_err_buffer); + } + } +} + +#define LAUNCH_LOCO_DEQUANT_REDUCE(num_chunks) \ + loco_dequant_reduce \ + <<>>(reduced_data, \ + reduced_scales, \ + input_data, \ + input_scales, \ + elems_per_out_group, \ + elems_per_in_tensor, \ + groups_per_in_tensor, \ + elems_per_in_group, \ + num_tensors, \ + error_feedback, \ + err_beta); + +template +void launch_loco_dequant_reduce_impl(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int out_groups, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + int num_tensors, + __half2* error_feedback, + const float err_beta, + cudaStream_t stream) +{ + constexpr int elems_per_thread = numBits; + const int one_step_threads = + next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread)); + const int threads = (one_step_threads < 1024) ? one_step_threads : 1024; + + dim3 block(threads); + dim3 grid(out_groups); + + const int elems_per_step = threads * elems_per_thread; + const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step; + + const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw; + + if (unroll == 1) { + LAUNCH_LOCO_DEQUANT_REDUCE(1); + } else if (unroll == 2) { + LAUNCH_LOCO_DEQUANT_REDUCE(2); + } else if (unroll == 3) { + LAUNCH_LOCO_DEQUANT_REDUCE(3); + } else if (unroll == 4) { + LAUNCH_LOCO_DEQUANT_REDUCE(4); + } else if (unroll == 6) { + LAUNCH_LOCO_DEQUANT_REDUCE(6); + } else if (unroll == 8) { + LAUNCH_LOCO_DEQUANT_REDUCE(8); + } else if (unroll == 10) { + LAUNCH_LOCO_DEQUANT_REDUCE(10); + } else if (unroll == 12) { + LAUNCH_LOCO_DEQUANT_REDUCE(12); + } else { + assert(false); + } +} + +#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \ + launch_loco_dequant_reduce_impl(reduced_data, \ + reduced_scales, \ + input_data, \ + input_scales, \ + out_groups, \ + elems_per_out_group, \ + elems_per_in_tensor, \ + groups_per_in_tensor, \ + elems_per_in_group, \ + num_gpus, \ + error_feedback, \ + err_beta, \ + stream); + +void launch_loco_dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int num_gpus, + int num_bits, + quantize::Type quant_type, + int out_groups, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + __half2* error_feedback, + const float err_beta, + cudaStream_t stream) +{ + if (quant_type == quantize::Type::Symmetric) { + if (num_bits == 4) { + if (num_gpus == 8) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric); + } else if (num_gpus == 16) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric); + } else { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric); + } + } else if (num_bits == 8) { + if (num_gpus == 8) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric); + } else if (num_gpus == 16) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric); + } else { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric); + } + } + } else if (quant_type == quantize::Type::Asymmetric) { + if (num_bits == 4) { + if (num_gpus == 8) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric); + } else if (num_gpus == 16) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric); + } else { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric); + } + } else if (num_bits == 8) { + if (num_gpus == 8) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric); + } else if (num_gpus == 16) { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric); + } else { + LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric); + } + } + } +} diff --git a/csrc/quantization/quantize_intX.cu b/csrc/quantization/quantize_intX.cu new file mode 100644 index 000000000000..b26151ab5c8c --- /dev/null +++ b/csrc/quantization/quantize_intX.cu @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include "memory_access_utils.h" + +template +struct alignas(sizeof(T) * N) AlignedArray { + using Element = T; + static const int kElements = N; + + __device__ __host__ AlignedArray() {} + + __device__ __host__ AlignedArray(const T& rhs) + { +#pragma unroll + for (int idx = 0; idx < kElements; ++idx) { this->at(idx) = rhs; } + } + + __device__ __host__ T& operator[](int offset) + { + return reinterpret_cast(this->buffer[offset]); + } + + __device__ __host__ const T& operator[](int offset) const + { + return reinterpret_cast(this->buffer[offset]); + } + + __device__ __host__ T& at(int offset) { return reinterpret_cast(this->buffer[offset]); } + + __device__ __host__ const T& at(int offset) const + { + return reinterpret_cast(this->buffer[offset]); + } + + __device__ __host__ AlignedArray operator+(const AlignedArray& rhs) const + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < kElements; ++idx) { ret[idx] = this->at(idx) + rhs.at(idx); } + + return ret; + } + + __device__ __forceinline__ void clear() + { +#pragma unroll + for (int idx = 0; idx < kElements; ++idx) { this->at(idx) = Element(0); } + } + + Element buffer[N]; +}; + +template +struct reduce_max { + __device__ __forceinline__ T operator()(const T& lhs, const T& rhs) + { + return lhs > rhs ? lhs : rhs; + } +}; + +template +struct reduce_min { + __device__ __forceinline__ T operator()(const T& lhs, const T& rhs) + { + return lhs < rhs ? lhs : rhs; + } +}; + +template +struct subtract { + __device__ __forceinline__ AlignedArray operator()(const AlignedArray& lhs, + const T& rhs) + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] - rhs; } + + return ret; + } +}; + +template +struct plus { + __device__ __forceinline__ AlignedArray operator()(const AlignedArray& lhs, + const T& rhs) + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] + rhs; } + + return ret; + } +}; + +template +struct multiply { + __device__ __forceinline__ AlignedArray operator()(const AlignedArray& lhs, + const T& rhs) + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] * rhs; } + + return ret; + } +}; + +template +struct clamp { + __device__ __forceinline__ AlignedArray operator()(const AlignedArray& lhs, + const T& min_val, + const T& max_val) + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; ++idx) { + ret[idx] = reduce_max()(reduce_min()(lhs[idx], max_val), min_val); + } + + return ret; + } +}; + +template +struct round_int; + +template +struct round_int { + __device__ __forceinline__ AlignedArray operator()(const AlignedArray& lhs) + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; ++idx) { ret[idx] = hrint(lhs[idx]); } + + return ret; + } +}; + +template +struct divide { + __device__ __forceinline__ AlignedArray operator()(const AlignedArray& lhs, + const T& rhs) + { + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; ++idx) { ret[idx] = lhs[idx] / rhs; } + + return ret; + } +}; + +template +__device__ __forceinline__ T to_scalar(const AlignedArray& data) +{ + Reducer re; + T res = data[0]; + +#pragma unroll + for (int idx = 1; idx < N; ++idx) { res = re(res, data[idx]); } + + return res; +} + +template +__device__ __forceinline__ AlignedArray int4_to_half( + const AlignedArray& data) +{ + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N * 2; idx += 2) { + ret[idx] = half(int(data[idx / 2] >> 4)); + ret[idx + 1] = half(int(data[idx / 2] & 0xf)); + } + + return ret; +} + +__global__ void dequantize_int4_to_half(uint8_t* data_in, + half* data_out, + half* scale_buffer, + half* min_val_buffer, + int num_group, + int group_size) +{ + using AccessType = AlignedArray; + using AccessTypeOut = AlignedArray; + + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8; + idx += blockDim.x * gridDim.x) { + int id_group = idx / (group_size / 8); + AccessType value = reinterpret_cast(data_in)[idx]; + half scale = scale_buffer[id_group]; + half min_value = min_val_buffer[id_group]; + + AccessTypeOut output = int4_to_half(value); + output = divide()(output, scale); + output = plus()(output, min_value); + + reinterpret_cast(data_out)[idx] = output; + } +} + +void launch_dequantize_int4_to_half_experimental(uint8_t* data_in, + half* data_out, + half* scale_buffer, + half* min_val_buffer, + int num_group, + int group_size, + cudaStream_t stream) +{ + int num_warp = num_group / 4; + int num_block = num_warp / 8; // 256 trd / block + + dequantize_int4_to_half<<>>( + data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size); +} + +template +__device__ __forceinline__ AlignedArray int8_to_half(const AlignedArray& data) +{ + AlignedArray ret; + +#pragma unroll + for (int idx = 0; idx < N; idx += 1) { ret[idx] = half(int(data[idx])); } + + return ret; +} + +__global__ void dequantize_int8_to_half(uint8_t* data_in, + half* data_out, + half* scale_buffer, + half* min_val_buffer, + int num_group, + int group_size) +{ + using AccessType = AlignedArray; + using AccessTypeOut = AlignedArray; + + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8; + idx += blockDim.x * gridDim.x) { + int id_group = idx / (group_size / 8); + AccessType value = reinterpret_cast(data_in)[idx]; + half scale = scale_buffer[id_group]; + half min_value = min_val_buffer[id_group]; + + AccessTypeOut output = int8_to_half(value); + output = divide()(output, scale); + output = plus()(output, min_value); + + reinterpret_cast(data_out)[idx] = output; + } +} + +void launch_dequantize_int8_to_half_experimental(uint8_t* data_in, + half* data_out, + half* scale_buffer, + half* min_val_buffer, + int num_group, + int group_size, + cudaStream_t stream) +{ + int num_warp = num_group / 4; + int num_block = num_warp / 8; // 256 trd / block + + dequantize_int8_to_half<<>>( + data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size); +} diff --git a/csrc/quantization/swizzled_quantize.cu b/csrc/quantization/swizzled_quantize.cu new file mode 100644 index 000000000000..a4b6096c81af --- /dev/null +++ b/csrc/quantization/swizzled_quantize.cu @@ -0,0 +1,427 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "dequantization_utils.h" +#include "memory_access_utils.h" +#include "quantization_utils.h" +#include "reduction_utils.h" + +using rop = reduce::ROpType; + +namespace swiz_quant { +constexpr int max_threads = 512; +constexpr int min_threads = 32; + +constexpr int step_granularity = 2; +constexpr int h_per_step = step_granularity * quantize::h_per_load; +} // namespace swiz_quant + +template +__global__ void swizzled_quant_kernel(int8_t* quantized_data, + float* quantized_scales, + const __half* uncompressed_data, + int elems_per_group, + int nodes, + int devices_per_node) +{ + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Indexing offsets, same as normal quantization for in-case + const int block_rank = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + const int block_offset = block_rank * elems_per_group; + const int elem_offset = tb.thread_index().x * quantize::h_per_load; + const int base_offset = block_offset + elem_offset; + const int stride = tb.size() * quantize::h_per_load; + const __half* input_base = uncompressed_data + base_offset; + + // Local buffer + __half2 local_buffer[totalChunks * quantize::h2_per_load]; + + quantize::GroupStats stats; +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + __half2* iteration_buffer = local_buffer + i * quantize::h2_per_load; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, elem_offset + i * stride < elems_per_group); + +#pragma unroll + for (int j = 0; j < quantize::h2_per_load; j++) { stats.update(iteration_buffer[j]); } + } + + auto params = stats.template get_params(tb, warp); + + const int partition_id = blockIdx.z; + const int partition_offset = partition_id / devices_per_node; + const int partition_base = (partition_id % devices_per_node) * nodes; + const int pipelining_offset = blockIdx.y * (devices_per_node * nodes); + const int output_partition = (pipelining_offset + partition_base + partition_offset); + + constexpr int out_scalar_effect = 8 / numBits; + const int out_block_rank = output_partition * gridDim.x + blockIdx.x; + const int out_block_offset = out_block_rank * elems_per_group / out_scalar_effect; + const int out_base_offset = out_block_offset + elem_offset / out_scalar_effect; + int8_t* out_base = quantized_data + out_base_offset; + + const int out_stride = stride / out_scalar_effect; + constexpr int num_int8_out = quantize::h_per_load / out_scalar_effect; + + if (tb.thread_index().x == 0) { params.store(quantized_scales, out_block_rank); } + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + if (i * stride + elem_offset < elems_per_group) { + int8_t local_output[quantize::h_per_load / out_scalar_effect]; + quantize::_chunk( + local_output, local_buffer + i * quantize::h2_per_load, params); + mem_access::store_global(out_base + i * out_stride, local_output); + } + } +} + +#define LAUNCH_SWIZZLE_QUANT(total_chunks, threads) \ + swizzled_quant_kernel<<>>( \ + q_data, q_scales, input_data, elems_per_group, nodes, devices_per_node); + +/* +Swizzled quantization reorganizes the quantized groups in order to better facilitate +communication. As an example of the partitioning scheme we have the following example +of 2 node, 4 device swizzling: + + --- --- --- --- --- --- --- --- +| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + --- --- --- --- --- --- --- --- +becomes + --- --- --- --- --- --- --- --- +| 0 | 4 | 1 | 5 | 2 | 6 | 3 | 7 | + --- --- --- --- --- --- --- --- + +Multiple quantization groups may be mapped into a single partition. In order to better support +later pipelining, we may also perform an additional slicing. In two-way slicing, for instance, +the first halves of each partition are concatenated. +*/ + +template +void launch_swizzled_quant_impl(int8_t* q_data, + float* q_scales, + const __half* input_data, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node, + cudaStream_t stream) +{ + const int one_step_threads = + next_pow2((elems_per_group + swiz_quant::h_per_step - 1) / swiz_quant::h_per_step); + const int max_threads = (one_step_threads < swiz_quant::max_threads) ? one_step_threads + : swiz_quant::max_threads; + const int threads = (max_threads < swiz_quant::min_threads) ? swiz_quant::min_threads + : max_threads; + + dim3 block(threads); + const int groups_per_partition = groups / (nodes * devices_per_node); + assert(groups_per_partition % pipelining == 0); + const int contiguous_groups = groups_per_partition / pipelining; + const int partitions = nodes * devices_per_node; + dim3 grid(contiguous_groups, pipelining, partitions); + + const int elems_per_step = threads * swiz_quant::h_per_step; + const int external_unroll = ((elems_per_group + elems_per_step - 1) / elems_per_step); + const int total_unroll = external_unroll * swiz_quant::step_granularity; + + assert(total_unroll % 2 == 0); + + if (threads == 32) { + LAUNCH_SWIZZLE_QUANT(2, 32); + } else if (threads == 64) { + LAUNCH_SWIZZLE_QUANT(2, 64); + } else if (threads == 128) { + LAUNCH_SWIZZLE_QUANT(2, 128); + } else if (threads == 256) { + LAUNCH_SWIZZLE_QUANT(2, 256); + } else if (threads == 512) { + if (total_unroll == 2) { + LAUNCH_SWIZZLE_QUANT(2, 512); + } else if (total_unroll == 4) { + LAUNCH_SWIZZLE_QUANT(4, 512); + } else if (total_unroll == 6) { + LAUNCH_SWIZZLE_QUANT(6, 512); + } else if (total_unroll == 8) { + LAUNCH_SWIZZLE_QUANT(8, 512); + } else if (total_unroll == 10) { + LAUNCH_SWIZZLE_QUANT(10, 512); + } + } +} + +#define DISPATCH_SWIZZLE_QUANT(num_bits, qtype) \ + launch_swizzled_quant_impl(q_data, \ + q_scales, \ + input_data, \ + groups, \ + elems_per_group, \ + pipelining, \ + nodes, \ + devices_per_node, \ + stream); + +void launch_swizzled_quant(int8_t* q_data, + float* q_scales, + const __half* input_data, + int num_bits, + quantize::Type q_type, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node, + cudaStream_t stream) +{ + if (num_bits == 4) { + if (q_type == quantize::Type::Asymmetric) { + DISPATCH_SWIZZLE_QUANT(4, quantize::Type::Asymmetric); + } else if (q_type == quantize::Type::Symmetric) { + DISPATCH_SWIZZLE_QUANT(4, quantize::Type::Symmetric); + } + } else if (num_bits == 8) { + if (q_type == quantize::Type::Asymmetric) { + DISPATCH_SWIZZLE_QUANT(8, quantize::Type::Asymmetric); + } else if (q_type == quantize::Type::Symmetric) { + DISPATCH_SWIZZLE_QUANT(8, quantize::Type::Symmetric); + } + } +} + +template +__global__ void loco_swizzled_quant_kernel(int8_t* quantized_data, + float* quantized_scales, + const __half* uncompressed_data, + __half* error_feedback, + const float err_beta, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node) +{ + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Indexing offsets, same as normal quantization for in-case + const int block_rank_data = + blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + const int block_offset_data = block_rank_data * elems_per_group; + const int elem_offset = tb.thread_index().x * quantize::h_per_load; + const int base_offset_data = block_offset_data + elem_offset; + const int stride = tb.size() * quantize::h_per_load; + const __half* uncompressed_data_base = uncompressed_data + base_offset_data; + + const int partition_id = blockIdx.z; + const int partition_offset = partition_id / devices_per_node; + const int partition_base = (partition_id % devices_per_node) * nodes; + const int pipelining_offset = blockIdx.y * (devices_per_node * nodes); + const int output_partition = (pipelining_offset + partition_base + partition_offset); + const int block_rank_err = output_partition * gridDim.x + blockIdx.x; + + const int block_offset_err = block_rank_err * elems_per_group; + const int base_offset_err = block_offset_err + elem_offset; + __half* error_feedback_base = error_feedback + base_offset_err; + + __half2 local_buffer[totalChunks * quantize::h2_per_load]; + __half2 err_buffer[totalChunks * quantize::h2_per_load]; + + quantize::GroupStats stats; + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + __half2* iteration_buffer = local_buffer + i * quantize::h2_per_load; + __half2* iter_err_buffer = err_buffer + i * quantize::h2_per_load; + const int i_stride = i * stride; + bool do_loads = (elem_offset + i_stride) < elems_per_group; + + mem_access::load_global( + iteration_buffer, uncompressed_data_base + i_stride, do_loads); + + mem_access::load_global( + iter_err_buffer, error_feedback_base + i_stride, do_loads); + +#pragma unroll + for (int j = 0; j < quantize::h2_per_load; j++) { + iteration_buffer[j] = __hadd2(iteration_buffer[j], iter_err_buffer[j]); + stats.update(iteration_buffer[j]); + } + } + + auto params = stats.template get_params(tb, warp); + + // Initialize dequantization parameters based on params + auto de_params = params; + de_params.scale = 1.0f / params.scale; + if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; } + + if (threadIdx.x == 0) { params.store(quantized_scales, block_rank_err); } + + constexpr int out_scalar_effect = 8 / numBits; + const int out_block_offset = block_rank_err * elems_per_group / out_scalar_effect; + const int out_base_offset = out_block_offset + elem_offset / out_scalar_effect; + int8_t* out_base = quantized_data + out_base_offset; + + const int out_stride = stride / out_scalar_effect; + constexpr int num_int8_out = quantize::h_per_load / out_scalar_effect; + +#pragma unroll + for (int i = 0; i < totalChunks; i++) { + const int i_stride = i * stride; + __half2* iteration_buffer = local_buffer + i * quantize::h2_per_load; + __half2* iter_err_buffer = err_buffer + i * quantize::h2_per_load; + + if (i_stride + elem_offset < elems_per_group) { + int8_t local_output[quantize::h_per_load / out_scalar_effect]; + quantize::_chunk(local_output, iteration_buffer, params); + mem_access::store_global(out_base + i * out_stride, local_output); + + // Dequantize the quantized output to compute the dequantized value + __half2 dequant_buffer[quantize::h2_per_load]; + dequantize::chunk(dequant_buffer, local_output, de_params); + +// Compute new error: sum - dequant_buffer +#pragma unroll + for (int k = 0; k < quantize::h2_per_load; k++) { + // __half2 to float2 + float2 iter_buf_f = __half22float2(iteration_buffer[k]); + float2 dequant_buf_f = __half22float2(dequant_buffer[k]); + + // Update within float precision + float2 new_error_f; + new_error_f.x = iter_buf_f.x - dequant_buf_f.x; + new_error_f.y = iter_buf_f.y - dequant_buf_f.y; + + float2 iter_err_buf_f = __half22float2(iter_err_buffer[k]); + + iter_err_buf_f.x = err_beta * iter_err_buf_f.x + (1.0f - err_beta) * new_error_f.x; + iter_err_buf_f.y = err_beta * iter_err_buf_f.y + (1.0f - err_beta) * new_error_f.y; + + // float2 back to __half2 + iter_err_buffer[k] = __float22half2_rn(iter_err_buf_f); + } + __half2* error_feedback_base_h2 = reinterpret_cast<__half2*>(error_feedback_base); + mem_access::store_global(error_feedback_base_h2 + i_stride / 2, + iter_err_buffer); + } + } +} + +#define LAUNCH_LOCO_SWIZZLE_QUANT(total_chunks, threads) \ + loco_swizzled_quant_kernel \ + <<>>(output_data, \ + params, \ + input_data, \ + error_feedback, \ + err_beta, \ + groups, \ + elems_per_group, \ + pipelining, \ + nodes, \ + devices_per_node); + +template +void launch_loco_swizzled_quant_impl(int8_t* output_data, + float* params, + const __half* input_data, + __half* error_feedback, + const float err_beta, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node, + cudaStream_t stream) +{ + const int one_step_threads = + next_pow2((elems_per_group + swiz_quant::h_per_step - 1) / swiz_quant::h_per_step); + const int max_threads = (one_step_threads < swiz_quant::max_threads) ? one_step_threads + : swiz_quant::max_threads; + const int threads = (max_threads < swiz_quant::min_threads) ? swiz_quant::min_threads + : max_threads; + + dim3 block(threads); + const int groups_per_partition = groups / (nodes * devices_per_node); + assert(groups_per_partition % pipelining == 0); + const int contiguous_groups = groups_per_partition / pipelining; + const int partitions = nodes * devices_per_node; + dim3 grid(contiguous_groups, pipelining, partitions); + + const int elems_per_step = threads * swiz_quant::h_per_step; + const int external_unroll = ((elems_per_group + elems_per_step - 1) / elems_per_step); + const int total_unroll = external_unroll * swiz_quant::step_granularity; + + assert(total_unroll % 2 == 0); + + if (threads == 32) { + LAUNCH_LOCO_SWIZZLE_QUANT(2, 32); + } else if (threads == 64) { + LAUNCH_LOCO_SWIZZLE_QUANT(2, 64); + } else if (threads == 128) { + LAUNCH_LOCO_SWIZZLE_QUANT(2, 128); + } else if (threads == 256) { + LAUNCH_LOCO_SWIZZLE_QUANT(2, 256); + } else if (threads == 512) { + if (total_unroll == 2) { + LAUNCH_LOCO_SWIZZLE_QUANT(2, 512); + } else if (total_unroll == 4) { + LAUNCH_LOCO_SWIZZLE_QUANT(4, 512); + } else if (total_unroll == 6) { + LAUNCH_LOCO_SWIZZLE_QUANT(6, 512); + } else if (total_unroll == 8) { + LAUNCH_LOCO_SWIZZLE_QUANT(8, 512); + } else if (total_unroll == 10) { + LAUNCH_LOCO_SWIZZLE_QUANT(10, 512); + } + } +} + +#define DISPATCH_LOCO_SWIZZLE_QUANT(num_bits, qtype) \ + launch_loco_swizzled_quant_impl(output_data, \ + params, \ + input_data, \ + error_feedback, \ + err_beta, \ + groups, \ + elems_per_group, \ + pipelining, \ + nodes, \ + devices_per_node, \ + stream); + +void launch_loco_swizzled_quant(int8_t* output_data, + float* params, + const __half* input_data, + __half* error_feedback, + const float err_beta, + int num_bits, + quantize::Type q_type, + int groups, + int elems_per_group, + int pipelining, + int nodes, + int devices_per_node, + cudaStream_t stream) +{ + if (num_bits == 4) { + if (q_type == quantize::Type::Asymmetric) { + DISPATCH_LOCO_SWIZZLE_QUANT(4, quantize::Type::Asymmetric); + } else if (q_type == quantize::Type::Symmetric) { + DISPATCH_LOCO_SWIZZLE_QUANT(4, quantize::Type::Symmetric); + } + } else if (num_bits == 8) { + if (q_type == quantize::Type::Asymmetric) { + DISPATCH_LOCO_SWIZZLE_QUANT(8, quantize::Type::Asymmetric); + } else if (q_type == quantize::Type::Symmetric) { + DISPATCH_LOCO_SWIZZLE_QUANT(8, quantize::Type::Symmetric); + } + } +} diff --git a/csrc/random_ltd/token_sort.cu b/csrc/random_ltd/token_sort.cu index aae1b93e1c24..3c1dff49429f 100644 --- a/csrc/random_ltd/token_sort.cu +++ b/csrc/random_ltd/token_sort.cu @@ -15,8 +15,8 @@ constexpr int granularity = 16; constexpr int mem_vals = granularity / sizeof(int32_t); constexpr int max_buffer_size = (threads + 1) * mem_vals; -#ifdef __HIP_PLATFORM_HCC__ -constexpr int warp_size = 64; +#ifdef __HIP_PLATFORM_AMD__ +constexpr int warp_size = ROCM_WAVEFRONT_SIZE; #else constexpr int warp_size = 32; #endif diff --git a/csrc/spatial/includes/spatial_cuda_layers.h b/csrc/spatial/includes/spatial_cuda_layers.h index 16eaa4abde60..4f56f89f16cd 100644 --- a/csrc/spatial/includes/spatial_cuda_layers.h +++ b/csrc/spatial/includes/spatial_cuda_layers.h @@ -9,7 +9,12 @@ #define HALF_PRECISION_AVAILABLE = 1 #endif +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif + #include #include diff --git a/csrc/transformer/cublas_wrappers.cu b/csrc/transformer/cublas_wrappers.cu index b48b7086921a..d982e65b8a81 100644 --- a/csrc/transformer/cublas_wrappers.cu +++ b/csrc/transformer/cublas_wrappers.cu @@ -5,7 +5,9 @@ #include "cublas_wrappers.h" -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -33,7 +35,8 @@ int cublas_gemm_ex(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -67,20 +70,39 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (transa == CUBLAS_OP_N) ? m : k, (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (transb == CUBLAS_OP_N) ? k : n, (const void*)beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -96,7 +118,8 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -124,7 +147,8 @@ int cublas_gemm_ex(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -158,20 +182,39 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (transa == CUBLAS_OP_N) ? m : k, (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (transb == CUBLAS_OP_N) ? k : n, (const void*)beta, (void*)C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -187,7 +230,8 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -223,7 +267,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -263,24 +308,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -297,7 +361,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -333,7 +398,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -373,24 +439,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_16F, +#else CUDA_R_16F, +#endif m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp index 80cbd72d09a4..b637bb710c67 100644 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -44,7 +44,7 @@ unsigned get_workspace_size(unsigned maxBatchSize, } // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ @@ -145,7 +145,7 @@ BertTransformerLayer::~BertTransformerLayer() template void BertTransformerLayer::Initialize() { -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ if (std::is_same::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); #endif } diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index a4193da94702..a987eec5ef0b 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -162,7 +162,7 @@ void launch_fused_add2(float* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } @@ -179,7 +179,7 @@ void launch_fused_add2<__half>(__half* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index 38b57951093d..bbb8a7f00b1f 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -3,207 +3,131 @@ // DeepSpeed Team +#include "conversion_utils.h" +#ifdef __HIP_PLATFORM_AMD__ +#include "hip/hip_cooperative_groups.h" +#else +#include "cooperative_groups.h" +#endif +#include "ds_kernel_utils.h" #include "inference_cuda_layers.h" +#include "memory_access_utils.h" -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ #include #endif namespace cg = cooperative_groups; -namespace cg = cooperative_groups; - -__global__ void apply_rotary_pos_emb(float* mixed_query, - float* key_layer, - unsigned rotary_dim, - unsigned seq_len, - unsigned seq_offset, - unsigned num_heads, - unsigned head_size, - unsigned total_count, - int max_out_tokens) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int id = threadIdx.x; - int gid = id >> 5; - int lane = id & 0x1f; - - unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; - unsigned offset = head_id * head_size; - - unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; - unsigned seq_index = head_id % seq_len; - unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; - - if (head_id < total_count) { - while (lane < rotary_dim) { - float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = mixed_query[offset + lane]; - float k = key_layer[k_offset + lane]; - float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); - float q_rot = (q * rotary_sign); - float k_rot = (k * rotary_sign); - q_rot = g.shfl_xor(q_rot, 1); - k_rot = g.shfl_xor(k_rot, 1); - q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); - k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); - - mixed_query[offset + lane] = q; - key_layer[k_offset + lane] = k; - - lane += WARP_SIZE; - } - } -} - -__global__ void apply_rotary_pos_emb(__half* mixed_query, - __half* key_layer, - unsigned rotary_dim, - unsigned seq_len, - unsigned seq_offset, - unsigned num_heads, - unsigned head_size, - unsigned total_count, - int max_out_tokens) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int id = threadIdx.x; - int gid = id >> 5; - int lane = id & 0x1f; - - unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; - unsigned offset = head_id * head_size; - - unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; - unsigned seq_index = head_id % seq_len; - unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; - - if (head_id < total_count) { - while (lane < rotary_dim) { - float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[k_offset + lane]; - float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); - float q_rot = (q * rotary_sign); - float k_rot = (k * rotary_sign); - q_rot = g.shfl_xor(q_rot, 1); - k_rot = g.shfl_xor(k_rot, 1); - q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); - k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); - mixed_query[offset + lane] = (__half)q; - key_layer[k_offset + lane] = (__half)k; +namespace rot_half { +constexpr int threads = 256; +} // namespace rot_half - lane += WARP_SIZE; - } - } -} -__global__ void apply_rotary_pos_emb1(float* mixed_query, - float* key_layer, +template +__global__ void apply_rotary_pos_half(T* mixed_query, + T* key_layer, unsigned rotary_dim, unsigned seq_len, unsigned seq_offset, unsigned num_heads, unsigned head_size, unsigned total_count, + float rope_theta, int max_out_tokens) { - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int id = threadIdx.x; - int gid = id >> 5; - int lane = id & 0x1f; - - unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; - unsigned offset = head_id * head_size; - - unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; - unsigned seq_index = head_id % seq_len; - unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; - - if (head_id < total_count) { - while (lane < rotary_dim) { - float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = mixed_query[offset + lane]; - float k = key_layer[k_offset + lane]; - float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); - float q_rot = (q * rotary_sign); - float k_rot = (k * rotary_sign); - q_rot = g.shfl_xor(q_rot, 1); - k_rot = g.shfl_xor(k_rot, 1); - q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); - k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); - - mixed_query[offset + lane] = q; - key_layer[k_offset + lane] = k; - - lane += WARP_SIZE; + constexpr int T_per_thread = granularity / sizeof(T); + constexpr int heads_per_block = rot_half::threads / threadsPerHead; + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile head_group = cg::tiled_partition(tb); + + const int head_idx = blockIdx.x * heads_per_block + threadIdx.x / threadsPerHead; + const int cur_seq_idx = head_idx % seq_len; + const int offset = head_idx * head_size; + const int k_offset = (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size; + + const int seq_idx = cur_seq_idx + seq_offset; + const int half_dim = rotary_dim >> 1; + const int half_dim_threads = half_dim / T_per_thread; + + if (head_idx < total_count) { + const int base_neuron_idx = head_group.thread_rank() * T_per_thread; + + T q[T_per_thread], k[T_per_thread]; + mem_access::load_global(q, mixed_query + offset + base_neuron_idx); + mem_access::load_global(k, key_layer + k_offset + base_neuron_idx); + +#pragma unroll + for (int i = 0; i < T_per_thread; i++) { + const int neuron_idx = base_neuron_idx + i; + if (neuron_idx < rotary_dim) { + float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx; + + float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0); + float q_rot = conversion::to(q[i]) * rotary_sign; + float k_rot = conversion::to(k[i]) * rotary_sign; + + const int target_lane = (neuron_idx < half_dim) + ? head_group.thread_rank() + half_dim_threads + : head_group.thread_rank() - half_dim_threads; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); + + q[i] = conversion::to(conversion::to(q[i]) * cosf(inv_freq) + + q_rot_temp * sinf(inv_freq)); + k[i] = conversion::to(conversion::to(k[i]) * cosf(inv_freq) + + k_rot_temp * sinf(inv_freq)); + } } + + mem_access::store_global(mixed_query + offset + base_neuron_idx, q); + mem_access::store_global(key_layer + k_offset + base_neuron_idx, k); } } -__global__ void apply_rotary_pos_emb1(__half* mixed_query, - __half* key_layer, - unsigned rotary_dim, - unsigned seq_len, - unsigned seq_offset, - unsigned num_heads, - unsigned head_size, - unsigned total_count, - int max_out_tokens) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - int id = threadIdx.x; - int gid = id >> 5; - int lane = id & 0x1f; - - unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; - unsigned seq_index = head_id % seq_len; - unsigned offset = head_id * head_size; - unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; - - constexpr unsigned mask[32] = { - 0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000, - 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000, - 0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, - 0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80, - 0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000, - 0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000, - 0x40000000, 0x80000000}; - - unsigned seq_id = (head_id % seq_len) + seq_offset; - unsigned half_dim = rotary_dim >> 1; - if (head_id < total_count) { - while (lane < rotary_dim) { - float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim; - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[k_offset + lane]; - float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0); - float q_rot = (q * rotary_sign); - float k_rot = (k * rotary_sign); - auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim) - : __shfl_sync(mask[lane], q_rot, lane - half_dim); - auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim) - : __shfl_sync(mask[lane], k_rot, lane - half_dim); - q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); - k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq); - - mixed_query[offset + lane] = (__half)q; - key_layer[k_offset + lane] = (__half)k; - - lane += WARP_SIZE; - } +#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \ + apply_rotary_pos_half<<>>(mixed_query, \ + key_layer, \ + rotary_dim, \ + seq_len, \ + offset, \ + num_heads, \ + head_size, \ + total_count, \ + rope_theta, \ + max_out_tokens); + +#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else if (threads_per_head == 64) { \ + LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ + } else { \ + assert(false); \ } -} +#else +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#endif template void launch_apply_rotary_pos_emb(T* mixed_query, @@ -214,193 +138,62 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned offset, unsigned num_heads, unsigned batch, - bool rotate_half, - bool rotate_every_two, + float rope_theta, cudaStream_t stream, int max_out_tokens) { - int total_count = batch * num_heads * seq_len; - dim3 block_dims(1024); - dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size); - if (rotate_every_two) - apply_rotary_pos_emb<<>>(mixed_query, - key_layer, - rotary_dim, - seq_len, - offset, - num_heads, - head_size, - total_count, - max_out_tokens); - else if (rotate_half) - apply_rotary_pos_emb1<<>>(mixed_query, - key_layer, - rotary_dim, - seq_len, - offset, - num_heads, - head_size, - total_count, - max_out_tokens); -} - -template void launch_apply_rotary_pos_emb(float*, - float*, - unsigned, - unsigned, - unsigned, - unsigned, - unsigned, - unsigned, - bool, - bool, - cudaStream_t, - int); -template void launch_apply_rotary_pos_emb<__half>(__half*, - __half*, - unsigned, - unsigned, - unsigned, - unsigned, - unsigned, - unsigned, - bool, - bool, - cudaStream_t, - int); - -/* -__global__ void apply_rotary_pos_emb(float* mixed_query, -float* key_layer, -unsigned rotary_dim, -unsigned seq_len, -unsigned seq_offset, -unsigned num_heads, -unsigned head_size, -unsigned total_count) -{ -cg::thread_block b = cg::this_thread_block(); -cg::thread_block_tile g = cg::tiled_partition(b); + const int half_dim = rotary_dim >> 1; + + int alignment = sizeof(T); + if (half_dim % (16 / sizeof(T)) == 0) { + alignment = 16; + } else if (half_dim % (8 / sizeof(T)) == 0) { + alignment = 8; + } else if (half_dim % (4 / sizeof(T)) == 0) { + alignment = 4; + } else { + assert(false); + } + const int T_per_elem = alignment / sizeof(T); -int id = threadIdx.x; -int gid = id >> 5; -int lane = id & 0x1f; + int total_count = batch * num_heads * seq_len; -unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; -unsigned offset = head_id * head_size; + const int padded_head_size = next_pow2(head_size); -unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + assert(padded_head_size <= hw_warp_size * T_per_elem); -if (head_id < total_count) { -while (lane < rotary_dim) { -float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; -inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; -float q = mixed_query[offset + lane]; -float k = key_layer[offset + lane]; -float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); -float q_rot = (q * rotary_sign); -float k_rot = (k * rotary_sign); -q_rot = g.shfl_xor(q_rot, 1); -k_rot = g.shfl_xor(k_rot, 1); -q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); -k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); + const int threads_per_head = padded_head_size / T_per_elem; + const int heads_per_block = rot_half::threads / threads_per_head; -mixed_query[offset + lane] = q; -key_layer[offset + lane] = k; + dim3 block(rot_half::threads); + dim3 grid((total_count + heads_per_block - 1) / heads_per_block); -lane += WARP_SIZE; -} -} + if (alignment == 4) { + LAUNCH_FOR_ALIGNMENT(4); + } else if (alignment == 8) { + LAUNCH_FOR_ALIGNMENT(8); + } else if (alignment == 16) { + LAUNCH_FOR_ALIGNMENT(16); + } else { + assert(false); + } } -__global__ void apply_rotary_pos_emb(__half* mixed_query, -__half* key_layer, -unsigned rotary_dim, -unsigned seq_len, -unsigned seq_offset, -unsigned num_heads, -unsigned head_size, -unsigned total_count) -{ -#if __CUDA_ARCH__ >= 700 -cg::thread_block b = cg::this_thread_block(); -cg::thread_block_tile g = cg::tiled_partition(b); - -int id = threadIdx.x; -int gid = id >> 5; -int lane = id & 0x1f; - -unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; -unsigned offset = head_id * head_size; -constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, -0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, -0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000, -0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8, -0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80, -0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, -0x1000000, 0x2000000, 0x4000000, 0x8000000, -0x10000000, 0x20000000, 0x40000000, 0x80000000}; -unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; - -if (head_id < total_count) { -while (lane < rotary_dim) { -//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; -float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim; -inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; -float q = (float)mixed_query[offset + lane]; -float k = (float)key_layer[offset + lane]; -float rotary_sign = (lane > 11 ? -1.0 : 1.0); -float q_rot = (q * rotary_sign); -float k_rot = (k * rotary_sign); -auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane], -q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], -k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q * -cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq); - -mixed_query[offset + lane] = (__half)q; -key_layer[offset + lane] = (__half)k; +#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ + template void launch_apply_rotary_pos_emb(T*, \ + T*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + float, \ + cudaStream_t, \ + int); -lane += WARP_SIZE; -} -} +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(__nv_bfloat16); #endif -} - -template -void launch_apply_rotary_pos_emb(T* mixed_query, -T* key_layer, -unsigned head_size, -unsigned seq_len, -unsigned rotary_dim, -unsigned offset, -unsigned num_heads, -unsigned batch, -cudaStream_t stream) -{ -int total_count = batch * num_heads * seq_len; -dim3 block_dims(1024); -dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size); - -apply_rotary_pos_emb<<>>( -mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count); -} - -template void launch_apply_rotary_pos_emb(float*, -float*, -unsigned, -unsigned, -unsigned, -unsigned, -unsigned, -unsigned, -cudaStream_t); -template void launch_apply_rotary_pos_emb<__half>(__half*, -__half*, -unsigned, -unsigned, -unsigned, -unsigned, -unsigned, -unsigned, -cudaStream_t); -*/ +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(__half); diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 539d4de60929..7a8e7ca446b0 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -3,6 +3,7 @@ // DeepSpeed Team +#include "conversion_utils.h" #include "inference_cuda_layers.h" #define MAX_QUANTIZE_GROUPING 1024 @@ -10,7 +11,8 @@ #define loop_unroll 1 #define loop_unroll_bits 1 -__global__ void dequantize_kernel(float* output, +template +__global__ void dequantize_kernel(T* output, const int8_t* input, const float* qscale, int output_size, @@ -38,40 +40,7 @@ __global__ void dequantize_kernel(float* output, float scale_data = qscale[scale_index]; - output[q_index] = (scale_data * (float)q); - tid += blockDim.x; - } -} - -__global__ void dequantize_kernel(__half* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count) -{ - unsigned merge_hidden = hidden_dim >> merge_count; - unsigned quantization_stride = (merge_hidden * output_size) / groups; - - unsigned bid = blockIdx.x; - unsigned tid = threadIdx.x; - - while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; - - auto q = input[q_index]; - - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; - - float scale_data = qscale[scale_index]; - - output[q_index] = __float2half(scale_data * (float)q); + output[q_index] = conversion::to(scale_data * (float)q); tid += blockDim.x; } } @@ -94,22 +63,15 @@ void launch_dequantize(T* output, output, input, qscale, output_size, hidden_dim, groups, merge_count); } -template void launch_dequantize(float*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - unsigned, - cudaStream_t); -template void launch_dequantize<__half>(__half*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - unsigned, - cudaStream_t); +#define INSTANTIATE_DEQUANTIZE_MERGE(T) \ + template void launch_dequantize( \ + T*, const int8_t*, const float*, unsigned, unsigned, unsigned, unsigned, cudaStream_t); + +INSTANTIATE_DEQUANTIZE_MERGE(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_DEQUANTIZE_MERGE(__nv_bfloat16); +#endif +INSTANTIATE_DEQUANTIZE_MERGE(__half); __global__ void dequantize_kernel(float* output, const int8_t* input, @@ -120,7 +82,8 @@ __global__ void dequantize_kernel(float* output, { } -__global__ void dequantize_kernel(__half* output, +template +__global__ void dequantize_kernel(T* output, const int8_t* input, const float* qscale, unsigned hidden_dim, @@ -144,12 +107,12 @@ __global__ void dequantize_kernel(__half* output, int8_t* q_int8 = (int8_t*)&q; float2 q_f; - __half* q_h = (__half*)&q_f; + T* q_h = (T*)&q_f; - q_h[0] = __float2half(local_scale * (float)q_int8[0]); - q_h[1] = __float2half(local_scale * (float)q_int8[1]); - q_h[2] = __float2half(local_scale * (float)q_int8[2]); - q_h[3] = __float2half(local_scale * (float)q_int8[3]); + q_h[0] = conversion::to(local_scale * (float)q_int8[0]); + q_h[1] = conversion::to(local_scale * (float)q_int8[1]); + q_h[2] = conversion::to(local_scale * (float)q_int8[2]); + q_h[3] = conversion::to(local_scale * (float)q_int8[3]); output_cast[tid] = q_f; tid += blockDim.x; } @@ -167,29 +130,24 @@ void launch_dequantize(T* output, { unsigned threads = 1024; hidden_dim /= 4; - unsigned hid_cnt = threads / hidden_dim; unsigned thd_cnt = (hidden_dim - 1) / threads + 1; - hid_cnt = hid_cnt > 0 ? hid_cnt : 1; - unsigned blocks = (output_size + hid_cnt * groups - 1) / (hid_cnt * groups); + assert(output_size % groups == 0); + unsigned blocks = output_size / groups; + dim3 block_dims(threads); dim3 grid_dims(groups, blocks); dequantize_kernel<<>>( - output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt); + output, input, qscale, hidden_dim, hidden_dim, thd_cnt); } -template void launch_dequantize(float*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - cudaStream_t); -template void launch_dequantize<__half>(__half*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - cudaStream_t); +#define INSTANTIATE_DEQUANTIZE_NO_MERGE(T) \ + template void launch_dequantize( \ + T*, const int8_t*, const float*, unsigned, unsigned, unsigned, cudaStream_t); + +INSTANTIATE_DEQUANTIZE_NO_MERGE(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_DEQUANTIZE_NO_MERGE(__nv_bfloat16); +#endif +INSTANTIATE_DEQUANTIZE_NO_MERGE(__half); diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index d62b135f509b..97857bc3f70b 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -11,10 +11,21 @@ namespace cg = cooperative_groups; #define MAX_CAP 4 #define MAX_SEQ 2048 +// only used to avoid compilation error due to lack of definition. +#ifndef BF16_AVAILABLE +#if defined(__CUDA_BF16_H__) +static_assert(sizeof(__nv_bfloat162) == sizeof(__half2), + "CUDA's __nv_bfloat162 doesn't match __half2 size"); +#else +// Fallback to simple typedef only if CUDA doesn't provide it +using __nv_bfloat162 = __half2; +#endif +#endif + inline __device__ float gelu(const float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715; return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); } @@ -33,7 +44,8 @@ __global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int in T data[values_per_access]; T data_bias[values_per_access]; mem_access::load_global(data, input + offset); - mem_access::load_global(data_bias, bias + (offset % intermediate_size)); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll for (int i = 0; i < values_per_access; i++) { @@ -65,8 +77,14 @@ void launch_bias_gelu(T* input, input, bias, total_count, intermediate_size); } -template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); -template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); +#define INSTANTIATE_LAUNCH_BIAS_GELU(T) \ + template void launch_bias_gelu(T*, const T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_GELU(float) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_GELU(__nv_bfloat16) +#endif +INSTANTIATE_LAUNCH_BIAS_GELU(__half) /* In-place channels-last bias add @@ -83,7 +101,8 @@ __global__ void fused_bias_add(T* input, const T* bias, int total_count, int int T data[values_per_access]; T data_bias[values_per_access]; mem_access::load_global(data, input + offset); - mem_access::load_global(data_bias, bias + (offset % intermediate_size)); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll for (int i = 0; i < values_per_access; i++) { @@ -115,8 +134,14 @@ void launch_bias_add(T* input, input, bias, total_count, intermediate_size); } -template void launch_bias_add(float*, const float*, int, int, cudaStream_t); -template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); +#define INSTANTIATE_LAUNCH_BIAS_ADD(T) \ + template void launch_bias_add(T*, const T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_ADD(float) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_ADD(__nv_bfloat16) +#endif +INSTANTIATE_LAUNCH_BIAS_ADD(__half) __global__ void fused_bias_residual(float* residual, const float* hidden_state, @@ -163,16 +188,19 @@ __global__ void fused_bias_residual(float* residual, } } -__global__ void fused_bias_residual(__half* residual, - const __half* hidden_state, - const __half* attn, - const __half* bias, - const __half* attn_bias, +template +__global__ void fused_bias_residual(T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, const int total_count, const int intermediate_size, const float mp_scale, const bool preln) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float2* res_fl2_ptr = reinterpret_cast(residual); const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); const float2* attn_fl2_ptr = reinterpret_cast(attn); @@ -187,26 +215,26 @@ __global__ void fused_bias_residual(__half* residual, const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; - __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2); - const __half2* hs_half2 = reinterpret_cast(&hs_fl2); - const __half2* attn_half2 = reinterpret_cast(&attn_fl2); - const __half2* bias_half2 = reinterpret_cast(&bias_fl2); - const __half2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); - float2 res_low = __half22float2(res_half2[0]); - float2 res_high = __half22float2(res_half2[1]); + float2 res_low = conversion::to(res_half2[0]); + float2 res_high = conversion::to(res_half2[1]); - const float2 hs_low = __half22float2(hs_half2[0]); - const float2 hs_high = __half22float2(hs_half2[1]); + const float2 hs_low = conversion::to(hs_half2[0]); + const float2 hs_high = conversion::to(hs_half2[1]); - const float2 attn_low = __half22float2(attn_half2[0]); - const float2 attn_high = __half22float2(attn_half2[1]); + const float2 attn_low = conversion::to(attn_half2[0]); + const float2 attn_high = conversion::to(attn_half2[1]); - const float2 bias_low = __half22float2(bias_half2[0]); - const float2 bias_high = __half22float2(bias_half2[1]); + const float2 bias_low = conversion::to(bias_half2[0]); + const float2 bias_high = conversion::to(bias_half2[1]); - const float2 attn_bias_low = __half22float2(attn_bias_half2[0]); - const float2 attn_bias_high = __half22float2(attn_bias_half2[1]); + const float2 attn_bias_low = conversion::to(attn_bias_half2[0]); + const float2 attn_bias_high = conversion::to(attn_bias_half2[1]); if (preln) { // residual = (residual + attention + bias + attention_bias) * @@ -226,8 +254,8 @@ __global__ void fused_bias_residual(__half* residual, res_high.x = (res_high.x + hs_high.x + bias_high.x); res_high.y = (res_high.y + hs_high.y + bias_high.y); } - res_half2[0] = __float22half2_rn(res_low); - res_half2[1] = __float22half2_rn(res_high); + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); res_fl2_ptr[offset] = res_fl2; } @@ -260,10 +288,14 @@ void launch_bias_residual(T* residual, preln); } -template void launch_bias_residual< - float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t); -template void launch_bias_residual< - __half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t); +#define INSTANTIATE_LAUNCH_BIAS_RESIDUAL(T) \ + template void launch_bias_residual(T*, T*, T*, T*, T*, int, int, int, bool, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_RESIDUAL(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_RESIDUAL(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_BIAS_RESIDUAL(__half); __global__ void gptj_residual_add(float* residual, const float* hidden_state, @@ -305,15 +337,18 @@ __global__ void gptj_residual_add(float* residual, } } -__global__ void gptj_residual_add(__half* residual, - const __half* hidden_state, - const __half* attn, - const __half* bias, - const __half* attn_bias, +template +__global__ void gptj_residual_add(T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, const int total_count, const int intermediate_size, const float mp_scale) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float2* res_fl2_ptr = reinterpret_cast(residual); const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); const float2* attn_fl2_ptr = reinterpret_cast(attn); @@ -327,28 +362,28 @@ __global__ void gptj_residual_add(__half* residual, const float2 attn_fl2 = attn_fl2_ptr[offset]; const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; - __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2); - const __half2* hs_half2 = reinterpret_cast(&hs_fl2); - const __half2* attn_half2 = reinterpret_cast(&attn_fl2); - const __half2* bias_half2 = reinterpret_cast(&bias_fl2); + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); - float2 res_low = __half22float2(res_half2[0]); - float2 res_high = __half22float2(res_half2[1]); + float2 res_low = conversion::to(res_half2[0]); + float2 res_high = conversion::to(res_half2[1]); - const float2 hs_low = __half22float2(hs_half2[0]); - const float2 hs_high = __half22float2(hs_half2[1]); + const float2 hs_low = conversion::to(hs_half2[0]); + const float2 hs_high = conversion::to(hs_half2[1]); - const float2 attn_low = __half22float2(attn_half2[0]); - const float2 attn_high = __half22float2(attn_half2[1]); + const float2 attn_low = conversion::to(attn_half2[0]); + const float2 attn_high = conversion::to(attn_half2[1]); - const float2 bias_low = __half22float2(bias_half2[0]); - const float2 bias_high = __half22float2(bias_half2[1]); + const float2 bias_low = conversion::to(bias_half2[0]); + const float2 bias_high = conversion::to(bias_half2[1]); if (attn_bias) { const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; - const __half2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); - const float2 attn_bias_low = __half22float2(attn_bias_half2[0]); - const float2 attn_bias_high = __half22float2(attn_bias_half2[1]); + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + const float2 attn_bias_low = conversion::to(attn_bias_half2[0]); + const float2 attn_bias_high = conversion::to(attn_bias_half2[1]); // residual += attention_bias res_low.x += attn_bias_low.x; res_low.y += attn_bias_low.y; @@ -361,8 +396,8 @@ __global__ void gptj_residual_add(__half* residual, res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale; res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale; - res_half2[0] = __float22half2_rn(res_low); - res_half2[1] = __float22half2_rn(res_high); + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); res_fl2_ptr[offset] = res_fl2; } @@ -387,24 +422,15 @@ void launch_gptj_residual_add(T* residual, residual, hidden_state, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size); } -template void launch_gptj_residual_add(float*, - float*, - float*, - float*, - float*, - int, - int, - int, - cudaStream_t); -template void launch_gptj_residual_add<__half>(__half*, - __half*, - __half*, - __half*, - __half*, - int, - int, - int, - cudaStream_t); +#define INSTANTIATE_GPT_RES_ADD(T) \ + template void launch_gptj_residual_add(T*, T*, T*, T*, T*, int, int, int, cudaStream_t); + +INSTANTIATE_GPT_RES_ADD(float); +INSTANTIATE_GPT_RES_ADD(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_GPT_RES_ADD(__nv_bfloat16); +#endif + template __global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) { @@ -449,24 +475,20 @@ void launch_moe_res_matmul(T* residual, residual, coef, mlp_out, seq_len, hidden_dim); } -template void launch_moe_res_matmul(float* residual, - float* coef, - float* mlp_out, - int seq_len, - int hidden_dim, - cudaStream_t stream); -template void launch_moe_res_matmul(__half* residual, - __half* coef, - __half* mlp_out, - int seq_len, - int hidden_dim, - cudaStream_t stream); +#define INSTANTIATE_LAUNCH_MOE_RES_MATMUL(T) \ + template void launch_moe_res_matmul(T*, T*, T*, int, int, cudaStream_t); -__global__ void pad_data_kernel(__half* padded_output, - __half* output, - int head_size, - int padded_head_size) +INSTANTIATE_LAUNCH_MOE_RES_MATMUL(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_MOE_RES_MATMUL(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_MOE_RES_MATMUL(__half); + +template +__global__ void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head_size) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float4* padded_output_cast = reinterpret_cast(padded_output); float4* output_cast = reinterpret_cast(output); int bid = blockIdx.x * (blockDim.y) + threadIdx.y; @@ -474,8 +496,8 @@ __global__ void pad_data_kernel(__half* padded_output, padded_output_cast += (bid * padded_head_size); output_cast += (bid * head_size); float4 ZERO; - const __half2 zero_h = __float2half2_rn(0.f); - __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; if (idx < head_size) @@ -483,12 +505,14 @@ __global__ void pad_data_kernel(__half* padded_output, else padded_output_cast[idx] = ZERO; } + __global__ void pad_data_kernel(float* padded_output, float* output, int head_size, int padded_head_size) { } + template void pad_data(T* padded_output, T* output, @@ -502,26 +526,25 @@ void pad_data(T* padded_output, pad_data_kernel<<>>( padded_output, output, head_size / 8, padded_head_size / 8); } -template void pad_data(__half* padded_output, - __half* output, - int bsz, - int head_size, - int padded_head_size, - cudaStream_t stream); -template void pad_data(float* padded_output, - float* output, - int bsz, - int head_size, - int padded_head_size, - cudaStream_t stream); - -__global__ void pad_head_seq_kernel(__half* padded_output, - __half* output, + +#define INSTANTIATE_PAD_DATA(T) template void pad_data(T*, T*, int, int, int, cudaStream_t stream); + +INSTANTIATE_PAD_DATA(float); +INSTANTIATE_PAD_DATA(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_PAD_DATA(__nv_bfloat16); +#endif + +template +__global__ void pad_head_seq_kernel(T* padded_output, + T* output, int seq_len, int padded_seq_len, int head_size, int padded_head_size) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float4* padded_output_cast = reinterpret_cast(padded_output); float4* output_cast = reinterpret_cast(output); int bsz = blockIdx.x; @@ -530,8 +553,8 @@ __global__ void pad_head_seq_kernel(__half* padded_output, padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; output_cast += (bsz * seq_len + bid) * head_size; float4 ZERO; - const __half2 zero_h = __float2half2_rn(0.f); - __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; @@ -540,6 +563,7 @@ __global__ void pad_head_seq_kernel(__half* padded_output, else padded_output_cast[idx] = ZERO; } + __global__ void pad_head_seq_kernel(float* padded_output, float* output, int seq_len, @@ -548,6 +572,7 @@ __global__ void pad_head_seq_kernel(float* padded_output, int padded_head_size) { } + template void pad_head_seq(T* padded_output, T* output, @@ -563,22 +588,15 @@ void pad_head_seq(T* padded_output, pad_head_seq_kernel<<>>( padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); } -template void pad_head_seq(__half* padded_output, - __half* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - cudaStream_t stream); -template void pad_head_seq(float* padded_output, - float* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - cudaStream_t stream); + +#define INSTANTIATE_PAD_HEAD_SEQ(T) \ + template void pad_head_seq(T*, T*, int, int, int, int, int, cudaStream_t); + +INSTANTIATE_PAD_HEAD_SEQ(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_PAD_HEAD_SEQ(__nv_bfloat16); +#endif +INSTANTIATE_PAD_HEAD_SEQ(float); // TODO(cmikeh2): evaluate different GeLU performance __device__ __forceinline__ float old_gelu(float val) @@ -594,12 +612,15 @@ constexpr int steps = 2; constexpr int granularity = 16; } // namespace fused_geglu -template -__global__ void fused_bias_geglu(T* output, - const T* activation, - const T* bias, - int base_channels, - int total_elems) +__device__ __forceinline__ float silu(float val) { return val / (1.0f + expf(-val)); } + +template +__global__ void fused_gate_activation(T* output, + const T* activation, + const T* bias, + int base_channels, + int output_stride, + int total_elems) { constexpr int T_per_access = fused_geglu::granularity / sizeof(T); constexpr int T_per_step = T_per_access * fused_geglu::threads; @@ -624,9 +645,10 @@ __global__ void fused_bias_geglu(T* output, activation + seq_offset + channel_id); mem_access::load_global( activation_buffer_2, activation + seq_offset + channel_id + base_channels); - mem_access::load_global(bias_buffer_1, bias + channel_id); - mem_access::load_global(bias_buffer_2, - bias + channel_id + base_channels); + mem_access::load_global( + bias_buffer_1, bias + channel_id, bias != nullptr); + mem_access::load_global( + bias_buffer_2, bias + channel_id + base_channels, bias != nullptr); // Since the GeLU is going to happen at float, might as well // convert @@ -634,23 +656,26 @@ __global__ void fused_bias_geglu(T* output, for (int v = 0; v < T_per_access; v++) { T hidden_state = activation_buffer_1[v] + bias_buffer_1[v]; T pre_gate = activation_buffer_2[v] + bias_buffer_2[v]; - float gate_f = old_gelu(conversion::to(pre_gate)); + float pre_gate_f = conversion::to(pre_gate); + float gate_f = (useGelu) ? old_gelu(pre_gate_f) : silu(pre_gate_f); T gate = conversion::to(gate_f); activation_buffer_1[v] = hidden_state * gate; } - mem_access::store_global(output + iter_id, - activation_buffer_1); + mem_access::store_global( + output + seq_id * output_stride + channel_id, activation_buffer_1); } } } template -void launch_fused_bias_geglu(T* output, +void launch_gated_activation(T* output, const T* activation, const T* bias, int rows, + int output_stride, int elems_per_row, + bool use_gelu, cudaStream_t stream) { /* @@ -671,14 +696,21 @@ void launch_fused_bias_geglu(T* output, dim3 block(fused_geglu::threads); dim3 grid((total_elems + T_per_block - 1) / T_per_block); - fused_bias_geglu<<>>( - output, activation, bias, base_channels, total_elems); + if (use_gelu) { + fused_gate_activation<<>>( + output, activation, bias, base_channels, output_stride, total_elems); + } else { + fused_gate_activation<<>>( + output, activation, bias, base_channels, output_stride, total_elems); + } } -template void launch_fused_bias_geglu(__half*, - const __half*, - const __half*, - int, - int, - cudaStream_t); -template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t); +#define INSTANTIATE_LAUNCH_GATED_ACTIVATION(T) \ + template void launch_gated_activation( \ + T*, const T*, const T*, int, int, int, bool, cudaStream_t); + +INSTANTIATE_LAUNCH_GATED_ACTIVATION(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_GATED_ACTIVATION(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_GATED_ACTIVATION(float); diff --git a/csrc/transformer/inference/csrc/layer_norm.cu b/csrc/transformer/inference/csrc/layer_norm.cu index ce3e471b38f9..e5e7e89c9d20 100644 --- a/csrc/transformer/inference/csrc/layer_norm.cu +++ b/csrc/transformer/inference/csrc/layer_norm.cu @@ -46,7 +46,7 @@ __global__ void fused_ln(T* output, (tb.thread_index().y * elems_per_row); const int thread_offset = tb.thread_index().x * T_per_load; const int base_offset = block_offset + thread_offset; - const int stride = tb.size() * T_per_load; + const int stride = blockDim.x * T_per_load; float sum = reduce::init(); @@ -57,8 +57,6 @@ __global__ void fused_ln(T* output, #pragma unRoll for (int i = 0; i < unRoll; i++) { T* iteration_buffer = local_buffer + i * T_per_load; - T residual_buffer[T_per_load]; - T bias_buffer[T_per_load]; mem_access::load_global( iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); @@ -91,8 +89,8 @@ __global__ void fused_ln(T* output, const float variance = mean_diff / elems_per_row; const float denom = __frsqrt_rn(variance + epsilon); - const T mean_compute = conversion::to(mean); - const T denom_compute = conversion::to(denom); + // const T mean_compute = conversion::to(mean); + // const T denom_compute = conversion::to(denom); T* block_output = output + block_offset; @@ -109,8 +107,11 @@ __global__ void fused_ln(T* output, #pragma unRoll for (int j = 0; j < T_per_load; j++) { - iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute; - iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j]; + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); } if (do_loads) { @@ -189,16 +190,14 @@ void launch_fused_ln(T* output, } } -template void launch_fused_ln(__half*, - const __half*, - const __half*, - const __half*, - float, - int, - int, - cudaStream_t); -template void -launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, cudaStream_t); +#define INSTANTIATE_FUSED_LN(T) \ + template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_FUSED_LN(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_FUSED_LN(__nv_bfloat16); +#endif +INSTANTIATE_FUSED_LN(float); /* Fused resiual + bias + layer norm implementation. Assumes elems_per_row % 8 @@ -274,7 +273,7 @@ __global__ void fused_residual_ln(T* output, float vals_up_cast = conversion::to(iteration_buffer[j]); float res_up_cast = conversion::to(residual_buffer[j]); float bias_up_cast = conversion::to(bias_buffer[j]); - vals_up_cast += res_up_cast + bias_up_cast; + vals_up_cast = vals_up_cast + bias_up_cast + res_up_cast; sum = reduce::element(sum, vals_up_cast); iteration_buffer[j] = conversion::to(vals_up_cast); } @@ -305,9 +304,6 @@ __global__ void fused_residual_ln(T* output, const float variance = mean_diff / elems_per_row; const float denom = __frsqrt_rn(variance + epsilon); - const T mean_compute = conversion::to(mean); - const T denom_compute = conversion::to(denom); - T* block_output = output + block_offset; #pragma unRoll @@ -323,8 +319,13 @@ __global__ void fused_residual_ln(T* output, #pragma unRoll for (int j = 0; j < T_per_load; j++) { - iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute; - iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j]; + // iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute; + // iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j]; + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); } if (do_loads) { @@ -481,50 +482,22 @@ void launch_fused_residual_ln_store_pre_ln_res(T* norm_output, } } -// No-store specializations -template void launch_fused_residual_ln(__half*, - const __half*, - const __half*, - const __half*, - const __half*, - const __half*, - float, - int, - int, - cudaStream_t); - -template void launch_fused_residual_ln(float*, - const float*, - const float*, - const float*, - const float*, - const float*, - float, - int, - int, - cudaStream_t); - -// Store specializations -template void launch_fused_residual_ln_store_pre_ln_res(__half*, - __half*, - const __half*, - const __half*, - const __half*, - const __half*, - const __half*, - float, - int, - int, - cudaStream_t); - -template void launch_fused_residual_ln_store_pre_ln_res(float*, - float*, - const float*, - const float*, - const float*, - const float*, - const float*, - float, - int, - int, - cudaStream_t); +#define INSTANTIATE_RES_LN(T) \ + template void launch_fused_residual_ln( \ + T*, const T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +#define INSTANTIATE_PRE_LN_RES(T) \ + template void launch_fused_residual_ln_store_pre_ln_res( \ + T*, T*, const T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_RES_LN(__half); +INSTANTIATE_RES_LN(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_RES_LN(__nv_bfloat16); +#endif + +INSTANTIATE_PRE_LN_RES(__half); +INSTANTIATE_PRE_LN_RES(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_PRE_LN_RES(__nv_bfloat16); +#endif diff --git a/csrc/transformer/inference/csrc/pointwise_ops.cu b/csrc/transformer/inference/csrc/pointwise_ops.cu new file mode 100644 index 000000000000..0301ff777042 --- /dev/null +++ b/csrc/transformer/inference/csrc/pointwise_ops.cu @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace pwise { +constexpr int granularity = 16; +constexpr int unroll = 4; +constexpr int threads = 256; +} // namespace pwise + +template +__global__ void vector_add_kernel(T* out, const T* a, const T* b, float gamma, int num_elems) +{ + constexpr int T_per_access = pwise::granularity / sizeof(T); + + const int block_offset = blockIdx.x * pwise::threads * pwise::unroll * T_per_access; + const int thread_offset = threadIdx.x * T_per_access; + const int total_offset = block_offset + thread_offset; + constexpr int stride = pwise::threads * T_per_access; + +#pragma unroll + for (int i = 0; i < pwise::unroll; i++) { + T temp_buf_a[T_per_access], temp_buf_b[T_per_access]; + + const int iter_idx = total_offset + i * stride; + + mem_access::load_global(temp_buf_a, a + iter_idx, iter_idx < num_elems); + mem_access::load_global(temp_buf_b, b + iter_idx, iter_idx < num_elems); + +#pragma unroll + for (int j = 0; j < T_per_access; j++) { + float up_cast_a = conversion::to(temp_buf_a[j]); + float up_cast_b = conversion::to(temp_buf_b[j]); + temp_buf_a[j] = conversion::to((gamma * up_cast_a) + up_cast_b); + } + + if (iter_idx < num_elems) { + mem_access::store_global(out + iter_idx, temp_buf_a); + } + } +} + +template +void launch_vector_add(T* out, + const T* a, + const T* b, + float gamma, + int num_elems, + cudaStream_t stream) +{ + constexpr int T_per_access = pwise::granularity / sizeof(T); + constexpr int T_per_block = pwise::threads * T_per_access * pwise::unroll; + + dim3 block(pwise::threads); + dim3 grid((num_elems + T_per_block - 1) / T_per_block); + + vector_add_kernel<<>>(out, a, b, gamma, num_elems); +} + +#define INSTANTIATE_VECTOR_ADD(T) \ + template void launch_vector_add( \ + T * out, const T* a, const T* b, float gamma, int num_elems, cudaStream_t stream); + +INSTANTIATE_VECTOR_ADD(float) +INSTANTIATE_VECTOR_ADD(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_VECTOR_ADD(__nv_bfloat16) +#endif diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 3de59e11377a..19dbe73726f7 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -16,7 +16,9 @@ std::array gemm_algos = std::array({99, 99, 99}); // NOTE: This activation function type enum should be always in sync // with the python counterpart, otherwise the casting from python binding // will be incorrect. -enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2 }; +enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2, GATED_GELU = 3, GATED_SILU = 4 }; + +enum class NormType { UNKNOWN = 0, LayerNorm = 1, GroupNorm = 2, RMSNorm = 3 }; enum class TransformerType : uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 }; @@ -141,7 +143,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) /* // Reallocate memory if we received a new prompt if (!workspace || input.size(1) != 1) { - allocate_workspace(W.size(1), InferenceContext::Instance().GetMaxTokenLenght(), + allocate_workspace(W.size(1), InferenceContext::Instance().GetMaxTokenLength(), Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace(); } */ @@ -161,7 +163,9 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) (T*)W.data_ptr(), (T*)Q.data_ptr(), (T*)O.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -214,7 +218,8 @@ void attention_unfused(at::Tensor& prev_key_cont, seq_len * k, seq_len * soft_len, bsz * heads, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -251,7 +256,8 @@ void attention_unfused(at::Tensor& prev_key_cont, seq_len * soft_len, seq_len * k, bsz * heads, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -382,11 +388,12 @@ void attention_unfused(T* prev_key_cont, workspace, CUBLAS_OP_T, CUBLAS_OP_N, - InferenceContext::Instance().GetMaxTokenLenght() * k, + InferenceContext::Instance().GetMaxTokenLength() * k, seq_len * k, seq_len * soft_len, bsz * heads, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -415,11 +422,12 @@ void attention_unfused(T* prev_key_cont, (T*)output, CUBLAS_OP_N, CUBLAS_OP_N, - InferenceContext::Instance().GetMaxTokenLenght() * k, + InferenceContext::Instance().GetMaxTokenLength() * k, seq_len * soft_len, seq_len * k, bsz * heads, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -435,6 +443,7 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bool rotate_half, bool rotate_every_two, int heads, + int num_kv, float norm_factor, bool triangular, bool local_attention, @@ -442,18 +451,22 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bool no_masking, unsigned layer_id, unsigned num_layers, - at::Tensor& alibi) + at::Tensor& alibi, + float rope_theta, + bool is_prompt, + std::optional token_idx, + std::optional position_ids) { unsigned bsz = query_key_value.size(0); unsigned seq_len = query_key_value.size(1); - unsigned hidden_dim = query_key_value.size(2) / 3; + int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads)); + unsigned hidden_dim = heads * k; - bool is_prompt = (seq_len > 1); + is_prompt = (seq_len > 1); if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len); unsigned soft_len = InferenceContext::Instance().current_tokens(); - int k = hidden_dim / heads; auto options = at::TensorOptions() .dtype(query_key_value.options().dtype()) .layout(at::kStrided) @@ -462,15 +475,15 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); size_t buf_size = bsz * seq_len * hidden_dim; - auto output = torch::from_blob(workspace + 3 * buf_size, {bsz, seq_len, hidden_dim}, options); + auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); - auto query_cont = workspace + 4 * buf_size; + auto query_cont = workspace + 5 * buf_size; size_t offset = - 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) + - layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim; + 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLength()) + + layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; unsigned all_tokens = soft_len; auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); - size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim; + size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; T* temp_buf = (T*)output.data_ptr() + at::numel(output); launch_bias_add_transform_0213((T*)query_cont, @@ -484,12 +497,14 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, soft_len, hidden_dim, heads, + (num_kv > 0 ? num_kv : heads), rotary_dim, rotate_half, rotate_every_two, InferenceContext::Instance().GetCurrentStream(), 3, - InferenceContext::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetMaxTokenLength(), + rope_theta); if (rotary_dim > 0 && rotate_half) launch_apply_rotary_pos_emb(query_cont, kv_cache, @@ -499,10 +514,9 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, (is_prompt ? 0 : soft_len - 1), heads, bsz, - rotate_half, - rotate_every_two, + rope_theta, InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetMaxTokenLength()); attention_unfused(workspace + offset, (T*)query_cont, @@ -531,22 +545,23 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, 1); if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); - auto prev_key = torch::from_blob(workspace + offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(), - k * InferenceContext::Instance().GetMaxTokenLenght(), - k, - 1}, - options); - - auto prev_value = - torch::from_blob(workspace + offset + value_offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(), - k * InferenceContext::Instance().GetMaxTokenLenght(), - k, - 1}, - options); + auto prev_key = torch::from_blob( + workspace + offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k, + 1}, + options); + + auto prev_value = torch::from_blob( + workspace + offset + value_offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k, + 1}, + options); return {output, prev_key, prev_value}; } @@ -567,12 +582,29 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) return input_cont; } -at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias) +#define DISPATCH_GATED_ACT(T_TYPE, C_TYPE) \ + if (activation.options().dtype() == torch::T_TYPE) { \ + launch_gated_activation((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)activation.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + rows, \ + out_channels, \ + channels, \ + activation_type == ActivationFuncType::GATED_GELU, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_gated_activation(at::Tensor& activation, at::Tensor& bias, int actFun) { /* Used in FF of Stable diffusion */ + const ActivationFuncType activation_type = static_cast(actFun); + + assert(activation_type == ActivationFuncType::GATED_GELU || + activation_type == ActivationFuncType::GATED_SILU); + const int batch_size = activation.size(0); const int seq_len = activation.size(1); const int channels = activation.size(2); @@ -583,21 +615,11 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias) auto output = at::empty({batch_size, seq_len, out_channels}, activation.options()); - if (activation.options().dtype() == torch::kFloat32) { - launch_fused_bias_geglu((float*)output.data_ptr(), - (const float*)activation.data_ptr(), - (const float*)bias.data_ptr(), - rows, - channels, - InferenceContext::Instance().GetCurrentStream()); - } else { - launch_fused_bias_geglu((__half*)output.data_ptr(), - (const __half*)activation.data_ptr(), - (const __half*)bias.data_ptr(), - rows, - channels, - InferenceContext::Instance().GetCurrentStream()); - } + DISPATCH_GATED_ACT(kFloat, float); + DISPATCH_GATED_ACT(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_GATED_ACT(kBFloat16, __nv_bfloat16); +#endif return output; } @@ -651,35 +673,99 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& return input_cont; } +#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, float epsilon) { const int rows = input.size(0) * input.size(1); const int elems_per_row = input.size(2); auto output = at::empty_like(input); - if (input.options().dtype() == torch::kFloat16) { - launch_fused_ln((__half*)output.data_ptr(), - (const __half*)input.data_ptr(), - (const __half*)gamma.data_ptr(), - (const __half*)beta.data_ptr(), - epsilon, - rows, - elems_per_row, - InferenceContext::Instance().GetCurrentStream()); - } else { - launch_fused_ln((float*)output.data_ptr(), - (const float*)input.data_ptr(), - (const float*)gamma.data_ptr(), - (const float*)beta.data_ptr(), - epsilon, - rows, - elems_per_row, - InferenceContext::Instance().GetCurrentStream()); + DISPATCH_LAYER_NORM(kFloat, float); + DISPATCH_LAYER_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16); +#endif + + return output; +} + +#define DISPATCH_RMS_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_rms_norm((C_TYPE*)output.data_ptr(), \ + (C_TYPE*)nullptr, \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)nullptr, \ + (const C_TYPE*)gamma.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ } +at::Tensor ds_rms_norm(at::Tensor& input, at::Tensor& gamma, float epsilon) +{ + // Get number of dims of tensor + int num_dims = input.dim(); + const int rows = (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = (num_dims == 2) ? input.size(1) : input.size(2); + + auto output = at::empty_like(input); + + DISPATCH_RMS_NORM(kFloat, float); + DISPATCH_RMS_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_RMS_NORM(kBFloat16, __nv_bfloat16); +#endif + return output; } +#define DISPATCH_PRE_RMS_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_rms_norm((C_TYPE*)output.data_ptr(), \ + (C_TYPE*)res_out.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +std::vector ds_pre_rms_norm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + float epsilon) +{ + // Get number of dims of tensor + int num_dims = input.dim(); + const int rows = (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = (num_dims == 2) ? input.size(1) : input.size(2); + + auto output = at::empty_like(input); + auto res_out = at::empty_like(residual); + + DISPATCH_PRE_RMS_NORM(kFloat, float); + DISPATCH_PRE_RMS_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_PRE_RMS_NORM(kBFloat16, __nv_bfloat16); +#endif + + return {output, res_out}; +} + template void ds_layer_norm_internal(T* workspace, at::Tensor& input, @@ -698,6 +784,20 @@ void ds_layer_norm_internal(T* workspace, InferenceContext::Instance().GetCurrentStream()); } +#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_residual_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + /* Currently only used in unit testing */ at::Tensor ds_layer_norm_residual(at::Tensor& input, at::Tensor& bias, @@ -710,33 +810,31 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input, const int elems_per_row = input.size(2); auto output = at::empty_like(input); - if (input.options().dtype() == torch::kFloat16) { - launch_fused_residual_ln((__half*)output.data_ptr(), - (const __half*)input.data_ptr(), - (const __half*)residual.data_ptr(), - (const __half*)bias.data_ptr(), - (const __half*)gamma.data_ptr(), - (const __half*)beta.data_ptr(), - epsilon, - rows, - elems_per_row, - InferenceContext::Instance().GetCurrentStream()); - } else { - launch_fused_residual_ln((float*)output.data_ptr(), - (const float*)input.data_ptr(), - (const float*)residual.data_ptr(), - (const float*)bias.data_ptr(), - (const float*)gamma.data_ptr(), - (const float*)beta.data_ptr(), - epsilon, - rows, - elems_per_row, - InferenceContext::Instance().GetCurrentStream()); - } + DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif return output; } +#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_residual_ln_store_pre_ln_res( \ + (C_TYPE*)norm_output.data_ptr(), \ + (C_TYPE*)res_output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + /* Currently only used in unit testing */ std::vector ds_layer_norm_residual_store_pre_ln_res(at::Tensor& input, at::Tensor& bias, @@ -750,31 +848,11 @@ std::vector ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu auto norm_output = at::empty_like(input); auto res_output = at::empty_like(input); - if (input.options().dtype() == torch::kFloat16) { - launch_fused_residual_ln_store_pre_ln_res((__half*)norm_output.data_ptr(), - (__half*)res_output.data_ptr(), - (const __half*)input.data_ptr(), - (const __half*)residual.data_ptr(), - (const __half*)bias.data_ptr(), - (const __half*)gamma.data_ptr(), - (const __half*)beta.data_ptr(), - epsilon, - rows, - elems_per_row, - InferenceContext::Instance().GetCurrentStream()); - } else { - launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(), - (float*)res_output.data_ptr(), - (const float*)input.data_ptr(), - (const float*)residual.data_ptr(), - (const float*)bias.data_ptr(), - (const float*)gamma.data_ptr(), - (const float*)beta.data_ptr(), - epsilon, - rows, - elems_per_row, - InferenceContext::Instance().GetCurrentStream()); - } + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif return {norm_output, res_output}; } @@ -818,7 +896,8 @@ void quantized_gemm(void* output, weight16, (T*)input, (T*)output, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -863,7 +942,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, (T*)weight.data_ptr(), workspace, (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -878,6 +958,74 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, return torch::from_blob(workspace, input.sizes(), input.options()); } +template +std::vector ds_rms_qkv(at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& gamma, + const float epsilon, + bool q_int8, + bool transposed_mode) +{ + const int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* rms_norm_ptr = workspace + (3 * bsz * input.size(2)); + int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto rms_norm = at::from_blob(rms_norm_ptr, input.sizes(), options); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + + launch_rms_norm((T*)rms_norm.data_ptr(), + (T*)nullptr, + (const T*)input.data_ptr(), + (const T*)nullptr, + (const T*)gamma.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); + + if (q_int8) { + quantized_gemm((T*)output.data_ptr(), + (T*)rms_norm.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)rms_norm.data_ptr(), + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + + return {output, rms_norm}; +} + template std::vector ds_qkv_gemm(at::Tensor& input, at::Tensor& weight, @@ -887,10 +1035,6 @@ std::vector ds_qkv_gemm(at::Tensor& input, at::Tensor& beta, const float epsilon, bool add_bias, - unsigned num_layers, - bool external_cache, - unsigned mp_size, - unsigned rank, bool q_int8, bool transposed_mode) { @@ -958,47 +1102,14 @@ void quantized_gemm(at::Tensor& output, (T*)weight16.data_ptr(), (T*)input.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif } -template -at::Tensor ds_qkv_gemm_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool add_bias) -{ - int bsz = input.size(0) * input.size(1); - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - auto inp_norm = ds_layer_norm(input_cont, gamma, beta, epsilon); - - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - - return output; -} - template at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, @@ -1006,7 +1117,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, bool add_bias, bool do_flash_attn, int num_heads, - bool transposed_mode) + bool transposed_mode, + float rope_theta) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -1017,8 +1129,9 @@ at::Tensor ds_linear_layer(at::Tensor& input, int head_size = input_cont.size(2) / num_heads; int bsz = input.size(0) * input.size(1); + int out_size = transposed_mode ? weight.size(0) : weight.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1036,7 +1149,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, (T*)weight.data_ptr(), (T*)input_cont.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1074,11 +1188,13 @@ at::Tensor ds_linear_layer(at::Tensor& input, (num_heads * padded_head_size), num_heads, -1, + -1, false, false, InferenceContext::Instance().GetCurrentStream(), 3, - input.size(1)); + input.size(1), + rope_theta); return at::from_blob(final_output, {3, input.size(0), num_heads, input.size(1), padded_head_size}, options); @@ -1099,11 +1215,13 @@ at::Tensor ds_linear_layer(at::Tensor& input, input_cont.size(2), num_heads, -1, + -1, false, false, InferenceContext::Instance().GetCurrentStream(), 3, - input.size(1)); + input.size(1), + rope_theta); return at::from_blob( final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, @@ -1207,31 +1325,6 @@ std::vector padd_add_transform(at::Tensor& query, {query.size(0), heads, key_value_length, padded_head_size}, query.options())}; } -template -at::Tensor ds_linear_layer_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& q_scale, - int groups) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - int bsz = input_cont.size(0) * input_cont.size(1); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - quantized_gemm(output, input_cont, weight, q_scale, groups, 0); - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - return output; -} template at::Tensor ds_vector_matmul(at::Tensor& input, @@ -1246,7 +1339,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input, .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? weight.size(0) : weight.size(1); + int out_size = (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1); int bsz = input.size(0) * input.size(1); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); @@ -1275,7 +1368,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input, (T*)weight.data_ptr(), (T*)input.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1361,7 +1455,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, (T*)weight.data_ptr(), inp_norm, intermediate, -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1405,7 +1500,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, (T*)weight1.data_ptr(), intermediate, (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1469,39 +1565,152 @@ std::vector ds_mlp_gemm(at::Tensor& input, } template -std::vector ds_mlp_gemm_int8(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool preLayerNorm) +std::vector ds_rms_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& weight_interm, + at::Tensor& weight_out, + at::Tensor& gamma, + const float epsilon, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + int activation_type, + bool transposed_mode) { - auto input_cont = input.contiguous(); + const int bsz = input.size(0) * input.size(1); + const size_t input_neurons = input.size(2); + const size_t mlp_1_out_neurons = transposed_mode ? weight_interm.size(0) + : weight_interm.size(1); + const size_t mlp_2_in_neurons = transposed_mode ? weight_out.size(1) : weight_out.size(0); + auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + T* output_ptr = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input); + T* inp_norm_ptr = output_ptr + torch::numel(input); + T* intermediate_ptr = inp_norm_ptr + torch::numel(input); - int bsz = input_cont.size(0) * input_cont.size(1); - auto inp_norm = at::empty_like(input_cont); + auto output = at::from_blob(output_ptr, input.sizes(), options); + auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options); + auto intermediate_gemm = + at::from_blob(intermediate_ptr, + {input.size(0), input.size(1), static_cast(mlp_1_out_neurons)}, + options); - auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); + auto act_func_type = static_cast(activation_type); + + // RMS Norm, we'll update the residual in-place + launch_rms_norm((T*)inp_norm.data_ptr(), + (T*)residual.data_ptr(), + (const T*)input.data_ptr(), + (const T*)residual.data_ptr(), + (const T*)gamma.data_ptr(), + epsilon, + bsz, + input_neurons, + InferenceContext::Instance().GetCurrentStream()); + + if (q_int8) { + quantized_gemm(intermediate_ptr, + (T*)inp_norm.data_ptr(), + weight_interm, + q_scale, + q_scale.size(0), + bsz, + input_neurons); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + mlp_1_out_neurons, + bsz, + input_neurons, + &alpha, + &gemm_beta, + (T*)weight_interm.data_ptr(), + (T*)inp_norm.data_ptr(), + intermediate_ptr, +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } + + if (act_func_type == ActivationFuncType::GELU) { + launch_bias_gelu(intermediate_ptr, + (T*)nullptr, + mlp_1_out_neurons, + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::ReLU) { + launch_bias_relu(intermediate_ptr, + (T*)nullptr, + mlp_1_out_neurons, + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::GATED_GELU) { + launch_gated_activation(intermediate_ptr, + (const T*)intermediate_ptr, + (const T*)nullptr, + bsz, + mlp_1_out_neurons, + mlp_1_out_neurons, + true, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::GATED_SILU) { + launch_gated_activation(intermediate_ptr, + (const T*)intermediate_ptr, + (const T*)nullptr, + bsz, + mlp_1_out_neurons, + mlp_1_out_neurons, + false, + InferenceContext::Instance().GetCurrentStream()); + } + + if (q_int8) { + quantized_gemm(output.data_ptr(), + intermediate_ptr, + weight_out, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), + CUBLAS_OP_N, + input_neurons, + bsz, + mlp_2_in_neurons, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + intermediate_ptr, + (T*)output.data_ptr(), +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_gemm_algo_standard, +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP, +#endif + mlp_1_out_neurons); + } - return {output, residual_add}; + return {output, residual}; } template @@ -1511,10 +1720,7 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, at::Tensor& bias, at::Tensor& weight_out, at::Tensor& weight_out_scale, - const float epsilon, - bool preLayerNorm, bool q_int8, - bool async_op, bool transposed_mode) { auto options = at::TensorOptions() @@ -1558,7 +1764,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, (T*)weight.data_ptr(), (T*)input.data_ptr(), (T*)intermediate.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1592,7 +1799,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, (T*)weight_out.data_ptr(), (T*)intermediate.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_gemm_algo_standard); #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); @@ -1641,13 +1849,36 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, return residual; } +#define DISPATCH_VECTOR_ADD(T_TYPE, C_TYPE) \ + if (a.scalar_type() == at::k##T_TYPE) { \ + launch_vector_add((C_TYPE*)(a.data_ptr()), \ + (const C_TYPE*)(a.data_ptr()), \ + (const C_TYPE*)(b.data_ptr()), \ + gamma, \ + total_elems, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor& _vector_add(at::Tensor& a, at::Tensor& b, float gamma) +{ + const int total_elems = a.numel(); + + DISPATCH_VECTOR_ADD(Float, float) + DISPATCH_VECTOR_ADD(Half, __half) +#ifdef BF16_AVAILABLE + DISPATCH_VECTOR_ADD(BFloat16, __nv_bfloat16) +#endif + + return a; +} + std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, at::Tensor& key_layer, unsigned rotary_dim, unsigned offset, unsigned num_heads, bool rotate_half, - bool rotate_every_two) + float rope_theta) { auto query_cont = mixed_query.contiguous(); auto key_cont = key_layer.contiguous(); @@ -1665,10 +1896,9 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, offset, num_heads, bsz, - rotate_half, - rotate_every_two, + rope_theta, InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetMaxTokenLength()); else launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), (__half*)key_cont.data_ptr(), @@ -1678,63 +1908,34 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, offset, num_heads, bsz, - rotate_half, - rotate_every_two, + rope_theta, InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetMaxTokenLength()); return {query_cont, key_cont}; } -template -at::Tensor fused_gemm_gelu_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool preLayerNorm) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - int bsz = input_cont.size(0) * input_cont.size(1); - - quantized_gemm(output, input_cont, weight, q_scale, groups, 0); - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - - return output; -} +#define DISPATCH_MOE_RESIDUAL(T_TYPE, C_TYPE) \ + if (moe_res.scalar_type() == torch::T_TYPE) { \ + launch_moe_res_matmul((C_TYPE*)moe_res.data_ptr(), \ + (C_TYPE*)coef.data_ptr(), \ + (C_TYPE*)output.data_ptr(), \ + M, \ + N, \ + InferenceContext::Instance().GetCurrentStream()); \ + } at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output) { int M = moe_res.size(0) * moe_res.size(1); int N = moe_res.size(2); InferenceContext::Instance().SynchComm(); - if (moe_res.scalar_type() == at::kFloat) { - launch_moe_res_matmul((float*)moe_res.data_ptr(), - (float*)coef.data_ptr(), - (float*)output.data_ptr(), - M, - N, - at::cuda::getCurrentCUDAStream()); - } else { - launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(), - (__half*)coef.data_ptr(), - (__half*)output.data_ptr(), - M, - N, - at::cuda::getCurrentCUDAStream()); - } + + DISPATCH_MOE_RESIDUAL(kFloat, float) + DISPATCH_MOE_RESIDUAL(kHalf, __half) +#ifdef BF16_AVAILABLE + DISPATCH_MOE_RESIDUAL(kBFloat16, __nv_bfloat16) +#endif + return output; } @@ -1742,85 +1943,102 @@ void ds_release_workspace() { InferenceContext::Instance().release_workspace(); bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); } +template +at::Tensor ds_dequantize(at::Tensor& weight, at::Tensor& qscale, int groups) +{ + auto options = at::TensorOptions() + .dtype(torch::kFloat16) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + InferenceContext::Instance().GetCurrentStream()); + + return weight16; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); - m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)"); - m.def( - "softmax_context_fp32", &ds_softmax_context, "DeepSpeed attention with fp32 (CUDA)"); - m.def("softmax_context_fp16", - &ds_softmax_context<__half>, - "DeepSpeed attention with fp16 (CUDA)"); m.def("softmax_context_int8", &ds_softmax_context1<__half>, "DeepSpeed attention with int8 (CUDA)"); - m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); - m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)"); - m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)"); - m.def("bias_add_fp32", &ds_bias_add, "DeepSpeed Bias Add with fp32 (CUDA)"); - m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)"); - m.def("bias_relu_fp32", &ds_bias_relu, "DeepSpeed ReLU with fp32 (CUDA)"); - m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)"); - m.def("bias_residual_fp32", - &ds_bias_residual, - "DeepSpeed residual-bias add with fp32 (CUDA)"); - m.def("bias_residual_fp16", - &ds_bias_residual<__half>, - "DeepSpeed residual-bias add with fp16 (CUDA)"); + + // The following functions handle type dispatching internally + m.def("gated_activation", &ds_gated_activation, "DeepSpeed Bias GEGLU (CUDA)"); m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)"); m.def( "_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)"); m.def("layer_norm_residual_store_pre_ln_res", &ds_layer_norm_residual_store_pre_ln_res, "DeepSpeed layer norm + store pre Layernorm residual (CUDA)"); - m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); - m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)"); - m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); - m.def("mlp_gemm_fp32", &ds_mlp_gemm, "DeepSpeed mlp with fp32 (CUDA)"); - m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); - m.def("vector_matmul_fp32", &ds_vector_matmul, "DeepSpeed vector-MM with fp32 (CUDA)"); - m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); - m.def("vector_matmul_int8", - &ds_vector_matmul_int8<__half>, - "DeepSpeed vector-MM with int8 (CUDA)"); - m.def("linear_layer_fp32", &ds_linear_layer, "DeepSpeed linear_layer with fp32 (CUDA)"); - m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)"); - m.def("linear_layer_int8", - &ds_linear_layer_int8<__half>, - "DeepSpeed linear_layer with int8 (CUDA)"); - m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu, "DeepSpeed mlp with fp32 (CUDA)"); - m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("residual_add_bias_fp32", - &residual_add_bias, - "DeepSpeed residual add with fp32 (CUDA)"); - m.def("residual_add_bias_fp16", - &residual_add_bias<__half>, - "DeepSpeed residual add with fp16 (CUDA)"); + m.def("rms_norm", &ds_rms_norm, "DeepSpeed rms norm (CUDA)"); + m.def("pre_rms_norm", &ds_pre_rms_norm, "DeepSpeed pre rms norm (CUDA)"); + m.def("_vector_add", &_vector_add, "DeepSpeed vector add (CUDA)"); m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("einsum_sec_sm_ecm_fp32", - &einsum_sec_sm_ecm, - "DeepSpeed vector-MM with fp32 (CUDA)"); - - m.def("einsum_sec_sm_ecm_fp16", - &einsum_sec_sm_ecm<__half>, - "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); - m.def("add_padding_fp32", &add_padding, "DeepSpeed residual add with fp32 (CUDA)"); - m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)"); - m.def("pad_transform_fp32", - &padd_add_transform, - "DeepSpeed residual add with fp32 (CUDA)"); - m.def("pad_transform_fp16", - &padd_add_transform<__half>, - "DeepSpeed residual add with fp16 (CUDA)"); - m.def("allocate_workspace_fp32", - &allocate_workspace, - "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)"); - m.def("allocate_workspace_fp16", - &allocate_workspace<__half>, - "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)"); m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace"); m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace"); + + // The following functions are templated and need to be explicitly instantiated and bound + // to different python methods +#define DEF_OPS(_name, _dtype) \ + m.def("softmax_" #_name, &ds_softmax<_dtype>, "DeepSpeed SoftMax with " #_name " (CUDA)"); \ + m.def("softmax_context_" #_name, \ + &ds_softmax_context<_dtype>, \ + "DeepSpeed attention with " #_name " (CUDA)"); \ + m.def("bias_gelu_" #_name, &ds_bias_gelu<_dtype>, "DeepSpeed Gelu with " #_name " (CUDA)"); \ + m.def("bias_add_" #_name, &ds_bias_add<_dtype>, "DeepSpeed Bias Add with " #_name " (CUDA)"); \ + m.def("bias_relu_" #_name, &ds_bias_relu<_dtype>, "DeepSpeed ReLU with " #_name " (CUDA)"); \ + m.def("bias_residual_" #_name, \ + &ds_bias_residual<_dtype>, \ + "DeepSpeed residual-bias add with " #_name " (CUDA)"); \ + m.def("qkv_gemm_" #_name, &ds_qkv_gemm<_dtype>, "DeepSpeed qkv gemm with " #_name " (CUDA)"); \ + m.def("rms_qkv_gemm_" #_name, \ + &ds_rms_qkv<_dtype>, \ + "DeepSpeed rms qkv gemm with " #_name " (CUDA)"); \ + m.def("mlp_gemm_" #_name, &ds_mlp_gemm<_dtype>, "DeepSpeed mlp with " #_name " (CUDA)"); \ + m.def("rms_mlp_gemm_" #_name, \ + &ds_rms_mlp_gemm<_dtype>, \ + "DeepSpeed rms mlp gemm with " #_name " (CUDA)"); \ + m.def("vector_matmul_" #_name, \ + &ds_vector_matmul<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (CUDA)"); \ + m.def("linear_layer_" #_name, \ + &ds_linear_layer<_dtype>, \ + "DeepSpeed linear_layer with " #_name " (CUDA)"); \ + m.def("fused_gemm_gelu_" #_name, \ + &fused_gemm_gelu<_dtype>, \ + "DeepSpeed mlp with " #_name " (CUDA)"); \ + m.def("residual_add_bias_" #_name, \ + &residual_add_bias<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("einsum_sec_sm_ecm_" #_name, \ + &einsum_sec_sm_ecm<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (CUDA)"); \ + m.def("add_padding_" #_name, \ + &add_padding<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("pad_transform_" #_name, \ + &padd_add_transform<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("allocate_workspace_" #_name, \ + &allocate_workspace<_dtype>, \ + "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \ + m.def("dequantize_" #_name, \ + &ds_dequantize<_dtype>, \ + "DeepSpeed dequantize with " #_name " (CUDA)"); + + DEF_OPS(fp32, float); + DEF_OPS(fp16, __half); +#ifdef BF16_AVAILABLE + DEF_OPS(bf16, __nv_bfloat16); +#endif } diff --git a/csrc/transformer/inference/csrc/relu.cu b/csrc/transformer/inference/csrc/relu.cu index bf6eac269469..40926b776cf2 100644 --- a/csrc/transformer/inference/csrc/relu.cu +++ b/csrc/transformer/inference/csrc/relu.cu @@ -28,7 +28,8 @@ __global__ void fused_bias_relu(T* input, const T* bias, int total_count, int in T data[values_per_access]; T data_bias[values_per_access]; mem_access::load_global(data, input + offset); - mem_access::load_global(data_bias, bias + (offset % intermediate_size)); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll for (int i = 0; i < values_per_access; i++) { @@ -60,5 +61,11 @@ void launch_bias_relu(T* input, input, bias, total_count, intermediate_size); } -template void launch_bias_relu(float*, const float*, int, int, cudaStream_t); -template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t); +#define INSTANTIATE_LAUNCH_BIAS_RELU(T) \ + template void launch_bias_relu(T*, const T*, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_BIAS_RELU(float) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_RELU(__nv_bfloat16) +#endif +INSTANTIATE_LAUNCH_BIAS_RELU(__half) diff --git a/csrc/transformer/inference/csrc/rms_norm.cu b/csrc/transformer/inference/csrc/rms_norm.cu new file mode 100644 index 000000000000..5f72a4193752 --- /dev/null +++ b/csrc/transformer/inference/csrc/rms_norm.cu @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "inference_cuda_layers.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace rms { +constexpr int granularity = 16; +} // namespace rms + +template +__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + + mem_access::load_global(iteration_buffer, + input_base + (i * stride), + thread_offset + (i * stride) < elems_per_row); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + float up_cast = conversion::to(iteration_buffer[j]); + float sq_val = up_cast * up_cast; + var_sum = reduce::element(var_sum, sq_val); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +template +__global__ void pre_rms_norm(T* output, + T* res_out, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + T* res_output = res_out + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + T residual_buffer[T_per_load]; + + const int iter_offset = i * stride + thread_offset; + const bool do_loads = (iter_offset < elems_per_row); + + mem_access::load_global( + iteration_buffer, input_base + (i * stride), do_loads); + mem_access::load_global( + residual_buffer, residual_base + (i * stride), do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] += residual_buffer[j]; + float vals_up_cast = conversion::to(iteration_buffer[j]); + + var_sum = reduce::element(var_sum, vals_up_cast * vals_up_cast); + } + + if (do_loads) { + mem_access::store_global(res_output + i * stride, iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + rms_norm \ + <<>>(norm_output, vals, gamma, epsilon, elems_per_row); + +#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + pre_rms_norm<<>>( \ + norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row); + +#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + if (pre_norm) { \ + LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } else { \ + LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = rms::granularity / sizeof(T); + constexpr int maxThreads = 256; + constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threads_per_group, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threads_per_group * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + bool pre_norm = (residual == nullptr) ? false : true; + + if (is_subblock_schedule) { + // <=128 + if (threads_per_group == 1) { + LAUNCH_ALL_RMS_NORM(1, 1, maxThreads); + } else if (threads_per_group == 2) { + LAUNCH_ALL_RMS_NORM(1, 2, maxThreads); + } else if (threads_per_group == 4) { + LAUNCH_ALL_RMS_NORM(1, 4, maxThreads); + } else if (threads_per_group == 8) { + LAUNCH_ALL_RMS_NORM(1, 8, maxThreads); + } else if (threads_per_group == 16) { + LAUNCH_ALL_RMS_NORM(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_LAUNCH_RMS_NORM(T) \ + template void launch_rms_norm(T * norm_output, \ + T * res_output, \ + const T* vals, \ + const T* residual, \ + const T* gamma, \ + float epsilon, \ + int rows, \ + int elems_per_row, \ + cudaStream_t stream); + +INSTANTIATE_LAUNCH_RMS_NORM(float) +INSTANTIATE_LAUNCH_RMS_NORM(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16) +#endif diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index 80eff139c3e9..bb06cc149ef4 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -4,9 +4,10 @@ // DeepSpeed Team #include +#include "conversion_utils.h" #include "inference_cuda_layers.h" -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ #include #endif #include @@ -30,10 +31,10 @@ void CheckCudaErrorAux(const char* file, unsigned line) namespace cg = cooperative_groups; -template -__global__ void attn_softmax_v2(__half* vals, - __half* mask, - __half* alibi, +template +__global__ void attn_softmax_v2(T* vals, + T* mask, + T* alibi, float layer_scale, bool triangular, bool recompute, @@ -53,7 +54,7 @@ __global__ void attn_softmax_v2(__half* vals, float2 low_data[MAX_REG_SIZE]; float2 high_data[MAX_REG_SIZE]; - const __half zero_h = __float2half(0.f); + const T zero_h = conversion::to(0.f); int wid = threadIdx.x >> 5; int lane = threadIdx.x & 0x1f; @@ -101,73 +102,87 @@ __global__ void attn_softmax_v2(__half* vals, ((data_id + reduceWidth * 3) > window_stride); if (mask && alibi) { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset])) + - (__half2float(mask[data_id + mask_offset])) - : minus_infinity; + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + + (conversion::to(mask[data_id + mask_offset])) + : minus_infinity; low_data[i].y = - low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + - (__half2float(mask[data_id + mask_offset + reduceWidth])) - : minus_infinity; + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) + - (__half2float(mask[data_id + mask_offset + reduceWidth * 2])) + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) : minus_infinity; high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) + - (__half2float(mask[data_id + mask_offset + reduceWidth * 3])) + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) : minus_infinity; } else if (mask) { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + - (__half2float(mask[data_id + mask_offset])) - : minus_infinity; - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale + - (__half2float(mask[data_id + mask_offset + reduceWidth])) + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(mask[data_id + mask_offset])) : minus_infinity; + low_data[i].y = + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; high_data[i].x = - high_x_check ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + - (__half2float(mask[data_id + mask_offset + reduceWidth * 2])) - : minus_infinity; + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; high_data[i].y = - high_y_check ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + - (__half2float(mask[data_id + mask_offset + reduceWidth * 3])) - : minus_infinity; + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; } else if (alibi) { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset])) - : minus_infinity; + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + : minus_infinity; low_data[i].y = - low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth])) - : minus_infinity; + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + : minus_infinity; high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) : minus_infinity; high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) : minus_infinity; } else { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + low_data[i].x = low_x_check ? conversion::to(vals[data_id]) * layer_scale : minus_infinity; - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale - : minus_infinity; - high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale - : minus_infinity; - high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale - : minus_infinity; + low_data[i].y = + low_y_check ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + : minus_infinity; + high_data[i].x = + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + : minus_infinity; + high_data[i].y = + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + : minus_infinity; } // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); @@ -225,13 +240,13 @@ __global__ void attn_softmax_v2(__half* vals, for (int i = 0; i < iterations; i++) { int data_id = i * (reduceWidth << 2) + (seq_lane); if (data_id < sequence_length) { - vals[data_id] = __float2half(low_data[i].x / sum); + vals[data_id] = conversion::to(low_data[i].x / sum); if ((data_id + reduceWidth) < sequence_length) - vals[data_id + reduceWidth] = __float2half(low_data[i].y / sum); + vals[data_id + reduceWidth] = conversion::to(low_data[i].y / sum); if ((data_id + reduceWidth * 2) < sequence_length) - vals[data_id + reduceWidth * 2] = __float2half(high_data[i].x / sum); + vals[data_id + reduceWidth * 2] = conversion::to(high_data[i].x / sum); if ((data_id + reduceWidth * 3) < sequence_length) - vals[data_id + reduceWidth * 3] = __float2half(high_data[i].y / sum); + vals[data_id + reduceWidth * 3] = conversion::to(high_data[i].y / sum); } } } @@ -389,23 +404,23 @@ __global__ void attn_softmax_v2(float* vals, } } -#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ - attn_softmax_v2<<>>(vals, \ - mask, \ - alibi, \ - layer_scale, \ - triangular, \ - recompute, \ - local_attention, \ - window_size, \ - total_count, \ - heads, \ - sequence_length, \ - num_seq, \ - head_offset, \ - mask_stride, \ - mp_size, \ - reduce_width); +#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ + attn_softmax_v2<<>>(vals, \ + mask, \ + alibi, \ + layer_scale, \ + triangular, \ + recompute, \ + local_attention, \ + window_size, \ + total_count, \ + heads, \ + sequence_length, \ + num_seq, \ + head_offset, \ + mask_stride, \ + mp_size, \ + reduce_width); template void launch_attn_softmax_v2(T* vals, @@ -472,35 +487,76 @@ void launch_attn_softmax_v2(T* vals, throw std::runtime_error("Unsupport Seq_Length!"); } -template void launch_attn_softmax_v2(float* vals, - float* mask, - float* alibi, - float layer_scale, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - int head_offset, - int mask_stride, - int mp_size, - cudaStream_t stream); -template void launch_attn_softmax_v2(__half* vals, - __half* mask, - __half* alibi, - float layer_scale, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - int head_offset, - int mask_stride, - int mp_size, - cudaStream_t stream); +#define INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(T) \ + template void launch_attn_softmax_v2(T* vals, \ + T* mask, \ + T* alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int batch_size, \ + int heads, \ + int num_seq, \ + int sequence_length, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + cudaStream_t stream); + +INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(__half); + +#define DEF_ATTN_SOFTMAX_V2_HALF(_iter) \ + template __global__ void attn_softmax_v2<__half, _iter>(__half * vals, \ + __half * mask, \ + __half * alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int total_count, \ + int heads, \ + int sequence_length, \ + int num_seq, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + int reduceWidth) + +#define DEF_ATTN_SOFTMAX_V2_BF16(_iter) \ + template __global__ void attn_softmax_v2<__nv_bfloat16, _iter>(__nv_bfloat16 * vals, \ + __nv_bfloat16 * mask, \ + __nv_bfloat16 * alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int total_count, \ + int heads, \ + int sequence_length, \ + int num_seq, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + int reduceWidth) + +#define FOREACH_ITERATIONS(cb) \ + cb(1); \ + cb(2); \ + cb(4); \ + cb(8); \ + cb(16); \ + cb(32); \ + cb(64) + +FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_HALF); +#ifdef BF16_AVAILABLE +FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_BF16); +#endif diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 8ef0d3289bb5..e7624363021e 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -3,12 +3,24 @@ // DeepSpeed Team -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ #include #endif +#include "conversion_utils.h" #include "inference_cuda_layers.h" namespace cg = cooperative_groups; +// only used to avoid compilation error due to lack of definition. +#ifndef BF16_AVAILABLE +#if defined(__CUDA_BF16_H__) +static_assert(sizeof(__nv_bfloat162) == sizeof(__half2), + "CUDA's __nv_bfloat162 doesn't match __half2 size"); +#else +// Fallback to simple typedef only if CUDA doesn't provide it +using __nv_bfloat162 = __half2; +#endif +#endif + // Bias add __global__ void bias_add_transform_0213(float* output, @@ -20,11 +32,14 @@ __global__ void bias_add_transform_0213(float* output, int seq_length, unsigned seq_offset, int heads, + int head_stride, + int num_kv, int rotary_dim, bool rotate_half, bool rotate_every_two, int head_ext, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; @@ -43,10 +58,10 @@ __global__ void bias_add_transform_0213(float* output, float4* output_vec = reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); - vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); - vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); - vals_vec += (cnt * d1_stride); - vals_vec += (d2 * d2_stride); + vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length); + vals_vec += d1 * (d1_stride + num_kv * 2 * d2_stride); + vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); + vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride); output_vec += (d1 * d2_stride); output_vec += (d0 * d0_out_stride); @@ -62,7 +77,7 @@ __global__ void bias_add_transform_0213(float* output, #pragma unroll for (int o = 0; o < 2; o++) { float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2); - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id; q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq)); q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq)); } @@ -75,22 +90,28 @@ __global__ void bias_add_transform_0213(float* output, #define ATTN_H 3 #define MAX_SEQ_LINE 10 -__global__ void bias_add_transform_0213(__half* output, // q - __half* k_cache, - __half* v_cache, - const __half* vals, // qkv - const __half* bias, +template +__global__ void bias_add_transform_0213(T* output, // q + T* k_cache, + T* v_cache, + const T* vals, // qkv + const T* bias, int hidden_dim, int seq_length, unsigned seq_offset, int all_tokens, int heads, + int head_stride, + int num_kv, int rotary_dim, bool rotate_half, bool rotate_every_two, int head_ext, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; unsigned half_dim = (rotary_dim << 3) >> 1; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; @@ -108,17 +129,17 @@ __global__ void bias_add_transform_0213(__half* output, // q float4 vals_arr; float4 output_arr; - __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); - __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + T2* vals_half = reinterpret_cast(&vals_arr); + T2* output_half = reinterpret_cast(&output_arr); const float4* vals_vec = reinterpret_cast(vals); float4* output_vec = reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); - vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); - vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); - vals_vec += (cnt * d1_stride); - vals_vec += (d2 * d2_stride); + vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length); + vals_vec += (d1 * (d1_stride + num_kv * 2 * d2_stride)); + vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); + vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride); output_vec += (d1 * d2_stride); output_vec += (d0 * d0_out_stride); @@ -129,17 +150,19 @@ __global__ void bias_add_transform_0213(__half* output, // q int lane = d3 & 0x1f; if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { float4 q = vals_vec[d3]; - __half2* q_h = reinterpret_cast<__half2*>(&q); + T2* q_h = reinterpret_cast(&q); if (rotate_every_two) { #pragma unroll for (int o = 0; o < 4; o++) { float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id; float q_data[2]; - q_data[0] = (float)q_h[o].x; - q_data[1] = (float)q_h[o].y; - q_h[o].x = (__half)(-1.0 * q_data[1] * sinf(inv_freq) + q_data[0] * cosf(inv_freq)); - q_h[o].y = (__half)(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq)); + q_data[0] = conversion::to(q_h[o].x); + q_data[1] = conversion::to(q_h[o].y); + q_h[o].x = conversion::to(-1.0 * q_data[1] * sinf(inv_freq) + + q_data[0] * cosf(inv_freq)); + q_h[o].y = + conversion::to(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq)); } } output_vec[d3] = q; @@ -160,12 +183,14 @@ void launch_bias_add_transform_0213(float* output, int all_tokens, int hidden_dim, int heads, + int num_kv, int rotary_dim, bool rotate_half, bool rotate_every_two, cudaStream_t stream, int trans_count, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { hidden_dim >>= 2; int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; @@ -182,48 +207,36 @@ void launch_bias_add_transform_0213(float* output, seq_length, seq_offset, heads, + num_kv > 0 ? (heads / num_kv) : 1, + num_kv > 0 ? num_kv : heads, rotary_dim >> 2, rotate_half, rotate_every_two, head_ext, - max_out_tokens); + max_out_tokens, + rope_theta); } + template -void launch_bias_add_transform_0213(T* outputs, - T* vals, - T* vals1, - const T* vals2, +void launch_bias_add_transform_0213(T* output, + T* k_cache, + T* v_cache, + const T* vals, const T* bias, int batch_size, int seq_length, unsigned seq_offset, - int seq_length1, + int all_tokens, int hidden_dim, int heads, + int num_kv, int rotary_dim, bool rotate_half, bool rotate_every_two, cudaStream_t stream, int trans_count, - int max_out_tokens); -template <> -void launch_bias_add_transform_0213<__half>(__half* output, - __half* k_cache, - __half* v_cache, - const __half* vals, - const __half* bias, - int batch_size, - int seq_length, - unsigned seq_offset, - int all_tokens, - int hidden_dim, - int heads, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - cudaStream_t stream, - int trans_count, - int max_out_tokens) + int max_out_tokens, + float rope_theta) { hidden_dim >>= 3; int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; @@ -239,13 +252,42 @@ void launch_bias_add_transform_0213<__half>(__half* output, seq_offset, all_tokens, heads, + num_kv > 0 ? (heads / num_kv) : 1, + num_kv > 0 ? num_kv : heads, rotary_dim >> 3, rotate_half, rotate_every_two, head_ext, - max_out_tokens); + max_out_tokens, + rope_theta); } +#define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \ + template void launch_bias_add_transform_0213(T*, \ + T*, \ + T*, \ + const T*, \ + const T*, \ + int, \ + int, \ + unsigned, \ + int, \ + int, \ + int, \ + int, \ + int, \ + bool, \ + bool, \ + cudaStream_t, \ + int, \ + int, \ + float) + +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__half); + // Bias add __global__ void pad_add_transform_0213(float* output, @@ -258,17 +300,20 @@ __global__ void pad_add_transform_0213(float* output, { } -__global__ void pad_add_transform_0213(__half* output, - const __half* vals, +template +__global__ void pad_add_transform_0213(T* output, + const T* vals, int hidden_dim, int seq_length, int padded_seq_len, int heads, int padded_head_size) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float4 ZERO; - const __half2 zero_h = __float2half2_rn(0.f); - __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; @@ -301,17 +346,6 @@ __global__ void pad_add_transform_0213(__half* output, output_vec[d3] = ZERO; } -template -void launch_pad_add_transform_0213(T* output, - const T* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - cudaStream_t stream); - // [B S C*H] - > C * [B A S N] template <> void launch_pad_add_transform_0213(float* output, @@ -325,16 +359,17 @@ void launch_pad_add_transform_0213(float* output, cudaStream_t stream) { } -template <> -void launch_pad_add_transform_0213<__half>(__half* output, - const __half* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - cudaStream_t stream) + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) { hidden_dim >>= 3; dim3 block_dim((padded_head_size >> 3), heads, 2); @@ -343,6 +378,15 @@ void launch_pad_add_transform_0213<__half>(__half* output, output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3); } +#define INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(T) \ + template void launch_pad_add_transform_0213( \ + T*, const T*, int, int, int, int, int, int, cudaStream_t); + +INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(__nv_bfloat16); +#endif + // Bias add template __global__ void bias_add_transform_0213(T* output, @@ -394,15 +438,17 @@ __global__ void bias_add_transform_0213(float* output, d2 * d2_out_stride + d3] = outputs; } -template <> -__global__ void bias_add_transform_0213<__half>(__half* output, - const __half* vals, - const __half* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; @@ -418,9 +464,9 @@ __global__ void bias_add_transform_0213<__half>(__half* output, float4 vals_arr; float4 bias_arr; float4 output_arr; - __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + T2* vals_half = reinterpret_cast(&vals_arr); + T2* bias_half = reinterpret_cast(&bias_arr); + T2* output_half = reinterpret_cast(&output_arr); const float4* vals_vec = reinterpret_cast(vals); const float4* bias_vec = reinterpret_cast(bias); @@ -449,13 +495,16 @@ __global__ void bias_add_transform_0213<__half>(__half* output, output_vec[d3] = output_arr; } -__global__ void bias_add_transform_0213_v2(__half* output, - const __half* vals, - const __half* bias, +template +__global__ void bias_add_transform_0213_v2(T* output, + const T* vals, + const T* bias, int hidden_dim, int seq_length, int heads) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; __shared__ float4 in_data[3072]; int d0_stride = hidden_dim * seq_length; @@ -477,9 +526,9 @@ __global__ void bias_add_transform_0213_v2(__half* output, float4 vals_arr[1]; float4 bias_arr[1]; float4 output_arr[1]; - __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(output_arr); + T2* vals_half = reinterpret_cast(vals_arr); + T2* bias_half = reinterpret_cast(bias_arr); + T2* output_half = reinterpret_cast(output_arr); const float4* vals_vec = reinterpret_cast(vals); const float4* bias_vec = reinterpret_cast(bias); @@ -560,13 +609,13 @@ __global__ void transform4d_0213(float* out, } } -template <> -__global__ void transform4d_0213<__half>(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) { int d0_stride = hidden_dim * (seq_length / head_ext); int d1_stride = hidden_dim; @@ -594,11 +643,8 @@ __global__ void transform4d_0213<__half>(__half* out, out_vec[d3] = in_vec[d3]; } -__global__ void transform4d_0213_v2(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim) +template +__global__ void transform4d_0213_v2(T* out, const T* in, int heads, int seq_length, int hidden_dim) { __shared__ float4 in_data[3072]; @@ -660,20 +706,28 @@ void launch_transform4d_0213(float* out, <<>>(out, in, heads, seq_length, hidden_dim, 1); } -template <> -void launch_transform4d_0213<__half>(__half* out, - const __half* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - cudaStream_t stream, - int trans_count) +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) { hidden_dim >>= 3; int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); dim3 block_dims(hidden_dim / heads, (heads / head_ext)); - transform4d_0213<__half> - <<>>(out, in, heads, seq_length, hidden_dim, head_ext); + transform4d_0213<<>>( + out, in, heads, seq_length, hidden_dim, head_ext); } + +#define INSTANTIATE_2B_LAUNCH_TRANSFORM4D(T) \ + template void launch_transform4d_0213(T*, const T*, int, int, int, int, cudaStream_t, int); + +INSTANTIATE_2B_LAUNCH_TRANSFORM4D(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_2B_LAUNCH_TRANSFORM4D(__nv_bfloat16) +#endif diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index f7bbcad91e2a..378fd4e5e990 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -60,12 +60,17 @@ class InferenceContext { { _workSpaceSize = 0; _workspace = 0; - if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { - auto message = std::string("Fail to create cublas handle."); + + cublasStatus_t stat = cublasCreate(&_cublasHandle); + if (stat != CUBLAS_STATUS_SUCCESS) { + // It would be nice to use cublasGetStatusName and + // cublasGetStatusString, but they were only added in CUDA 11.4.2. + auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") + + std::to_string(stat); std::cerr << message << std::endl; throw std::runtime_error(message); } -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); #endif cudaEventCreate(&_comp1_event); @@ -136,7 +141,7 @@ class InferenceContext { if (_max_seq_len < min_out_tokens) { printf( - "Allocatable workspace available (%d tokens) is less than minimum requested " + "Allocatable workspace available (%ld tokens) is less than minimum requested " "workspace (%d tokens)\n", _max_seq_len, min_out_tokens); @@ -175,7 +180,7 @@ class InferenceContext { _workSpaceSize = workSpaceSize; _attention_unfused_workspace_offset = workSpaceSize - temp_size; } - inline size_t GetMaxTokenLenght() const { return _max_seq_len; } + inline size_t GetMaxTokenLength() const { return _max_seq_len; } cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; } diff --git a/csrc/transformer/inference/includes/inference_cublas_wrappers.h b/csrc/transformer/inference/includes/inference_cublas_wrappers.h index e899ec266d83..40c3e443941d 100644 --- a/csrc/transformer/inference/includes/inference_cublas_wrappers.h +++ b/csrc/transformer/inference/includes/inference_cublas_wrappers.h @@ -8,14 +8,19 @@ #include #include #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include -#ifndef __HIP_PLATFORM_HCC__ +#ifndef __HIP_PLATFORM_AMD__ #include #endif #include -#ifdef __HIP_PLATFORM_HCC__ +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -27,7 +32,8 @@ int cublas_gemm_ex(rocblas_handle handle, const float* A, const float* B, float* C, - rocblas_gemm_algo algo) + rocblas_gemm_algo algo, + int b_stride = -1) #else int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, @@ -40,10 +46,13 @@ int cublas_gemm_ex(cublasHandle_t handle, const float* A, const float* B, float* C, - cublasGemmAlgo_t algo) + cublasGemmAlgo_t algo, + int b_stride = -1) #endif { -#ifdef __HIP_PLATFORM_HCC__ + const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride; +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -56,7 +65,7 @@ int cublas_gemm_ex(cublasHandle_t handle, (transa == rocblas_operation_none) ? m : k, (const void*)B, rocblas_datatype_f32_r, - (transb == rocblas_operation_none) ? k : n, + ldb, (const void*)beta, C, rocblas_datatype_f32_r, @@ -77,20 +86,39 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (transa == CUBLAS_OP_N) ? m : k, (const void*)B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, - (transb == CUBLAS_OP_N) ? k : n, +#endif + ldb, (const void*)beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -106,7 +134,9 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_HCC__ +template +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -115,10 +145,11 @@ int cublas_gemm_ex(rocblas_handle handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, - rocblas_gemm_algo algo) + const T* A, + const T* B, + T* C, + rocblas_gemm_algo algo, + int b_stride = -1) #else int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, @@ -128,13 +159,18 @@ int cublas_gemm_ex(cublasHandle_t handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, - cublasGemmAlgo_t algo) + const T* A, + const T* B, + T* C, + cublasGemmAlgo_t algo, + int b_stride = -1) #endif { -#ifdef __HIP_PLATFORM_HCC__ + const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride; +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + constexpr auto rocblas_dtype_16 = std::is_same::value ? rocblas_datatype_f16_r + : rocblas_datatype_bf16_r; rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -143,23 +179,28 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, - rocblas_datatype_f16_r, + rocblas_dtype_16, (transa == rocblas_operation_none) ? m : k, (const void*)B, - rocblas_datatype_f16_r, - (transb == rocblas_operation_none) ? k : n, + rocblas_dtype_16, + ldb, (const void*)beta, (void*)C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, (void*)C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, rocblas_datatype_f32_r, algo, 0, 0); #else +#ifdef __HIP_PLATFORM_AMD__ + constexpr auto cublas_dtype_16 = std::is_same::value ? HIPBLAS_R_16F : HIPBLAS_R_16B; +#else + constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; +#endif cublasStatus_t status = cublasGemmEx(handle, transa, transb, @@ -168,20 +209,27 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, - CUDA_R_16F, + cublas_dtype_16, (transa == CUBLAS_OP_N) ? m : k, (const void*)B, - CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, + cublas_dtype_16, + ldb, (const void*)beta, (void*)C, - CUDA_R_16F, + cublas_dtype_16, m, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -197,7 +245,8 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, @@ -233,7 +282,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -273,24 +323,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, +#ifdef __HIP_PLATFORM_AMD__ + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { @@ -307,16 +376,18 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, return 0; } -#ifdef __HIP_PLATFORM_HCC__ +template +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, + const T* A, + const T* B, + T* C, rocblas_operation op_A, rocblas_operation op_B, int stride_A, @@ -331,9 +402,9 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, + const T* A, + const T* B, + T* C, cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, @@ -343,7 +414,10 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, cublasGemmAlgo_t algo) #endif { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + constexpr auto rocblas_dtype_16 = std::is_same::value ? rocblas_datatype_f16_r + : rocblas_datatype_bf16_r; rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -353,20 +427,20 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, - rocblas_datatype_f16_r, + rocblas_dtype_16, (op_A == rocblas_operation_none) ? m : k, stride_A, B, - rocblas_datatype_f16_r, + rocblas_dtype_16, (op_B == rocblas_operation_none) ? k : n, stride_B, beta, C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, stride_C, C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, stride_C, batch, @@ -375,6 +449,11 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, 0, 0); #else +#ifdef __HIP_PLATFORM_AMD__ + constexpr auto cublas_dtype_16 = std::is_same::value ? HIPBLAS_R_16F : HIPBLAS_R_16B; +#else + constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; +#endif cublasStatus_t status = cublasGemmStridedBatchedEx(handle, op_A, op_B, @@ -383,24 +462,31 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, - CUDA_R_16F, + cublas_dtype_16, (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, - CUDA_R_16F, + cublas_dtype_16, (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, - CUDA_R_16F, + cublas_dtype_16, m, stride_C, batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else CUDA_R_32F, +#endif algo); #endif -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) if (status != rocblas_status_success) { #else if (status != CUBLAS_STATUS_SUCCESS) { diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 8e5c1ae4f44b..dcc020483687 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -8,6 +8,9 @@ #include "ds_kernel_utils.h" #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include #include @@ -49,11 +52,13 @@ void launch_bias_gelu(T* input, cudaStream_t stream); template -void launch_fused_bias_geglu(T* output, +void launch_gated_activation(T* output, const T* activation, const T* bias, int rows, + int output_stride, int elems_per_row, + bool use_gelu, cudaStream_t stream); // Fused bias add with relu activation @@ -114,6 +119,17 @@ void launch_fused_residual_ln_store_pre_ln_res(T* norm_output, int elems_per_row, cudaStream_t stream); +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + template void launch_dequantize(T* output, const int8_t* input, @@ -152,8 +168,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned offset, unsigned num_heads, unsigned batch, - bool rotate_half, - bool rotate_every_two, + float rope_theta, cudaStream_t stream, int max_out_tokens); @@ -187,12 +202,14 @@ void launch_bias_add_transform_0213(T* outputs, int seq_length1, int hidden_dim, int heads, + int num_kv, int rotary_dim, bool rotate_half, bool rotate_every_two, cudaStream_t stream, int trans_count, - int max_out_tokens); + int max_out_tokens, + float rope_theta); template void pad_data(T* padded_output, T* output, @@ -221,3 +238,11 @@ void launch_pad_add_transform_0213(T* output, int heads, int padded_head_size, cudaStream_t stream); + +template +void launch_vector_add(T* out, + const T* a, + const T* b, + float gamma, + int num_elems, + cudaStream_t stream); diff --git a/csrc/utils/flatten_unflatten.cpp b/csrc/utils/flatten_unflatten.cpp deleted file mode 100644 index ab95ee191464..000000000000 --- a/csrc/utils/flatten_unflatten.cpp +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -/* -Copyright NVIDIA/apex -This file is adapted from fused adam in NVIDIA/apex, commit a109f85 -*/ - -#include -#include -// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h - -at::Tensor flatten(std::vector tensors) -{ - return torch::utils::flatten_dense_tensors(tensors); -} - -std::vector unflatten(at::Tensor flat, std::vector tensors) -{ - return torch::utils::unflatten_dense_tensors(flat, tensors); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("flatten", &flatten, "Flatten dense tensors"); - m.def("unflatten", &unflatten, "Unflatten dense tensors"); -} diff --git a/csrc/utils/py_ds_utils.cpp b/csrc/utils/py_ds_utils.cpp new file mode 100644 index 000000000000..df5c9f361c61 --- /dev/null +++ b/csrc/utils/py_ds_utils.cpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* + Collection of system utilities. +*/ + +#include +#include "tensor_cast.h" +using namespace pybind11::literals; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("cast_to_byte_tensor", + py::overload_cast(&cast_to_byte_tensor), + "Cast a 1-dimensional tensor of any type to byte tensor.", + "src_tensor"_a); + + m.def("cast_to_byte_tensor", + py::overload_cast&>(&cast_to_byte_tensor), + "Cast a multi-dimensional tensor of any type to byte tensor.", + "src_tensor"_a); +} diff --git a/csrc/utils/tensor_cast.cpp b/csrc/utils/tensor_cast.cpp new file mode 100644 index 000000000000..8352bc72bb02 --- /dev/null +++ b/csrc/utils/tensor_cast.cpp @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "tensor_cast.h" + +at::Tensor cast_to_byte_tensor(at::Tensor& src_tensor) +{ + if (src_tensor.nbytes() <= 1) return src_tensor; + + auto options = torch::TensorOptions() + .dtype(torch::kUInt8) + .layout(src_tensor.layout()) + .device(src_tensor.device()); + return at::from_blob( + src_tensor.data_ptr(), static_cast(src_tensor.nbytes()), options); +} + +std::vector cast_to_byte_tensor(std::vector& tensor_list) +{ + std::vector byte_tensors; + for (auto src_tensor : tensor_list) { byte_tensors.push_back(cast_to_byte_tensor(src_tensor)); } + + return byte_tensors; +} diff --git a/csrc/utils/tensor_cast.h b/csrc/utils/tensor_cast.h new file mode 100644 index 000000000000..86155567122d --- /dev/null +++ b/csrc/utils/tensor_cast.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* + Utilities for type casting torch tensors without data movement. +*/ + +#include +#include + +using namespace std; +at::Tensor cast_to_byte_tensor(at::Tensor& src_tensor); + +std::vector cast_to_byte_tensor(std::vector& tensor_list); diff --git a/csrc/xpu/adagrad/cpu_adagrad.cpp b/csrc/xpu/adagrad/cpu_adagrad.cpp new file mode 100644 index 000000000000..dc727f8fa216 --- /dev/null +++ b/csrc/xpu/adagrad/cpu_adagrad.cpp @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cpu_adagrad.h" +#include +#include +#include +#include +#include +#include + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adagrad_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) { + float step_size = -1 * _alpha; + ds_half_precision_t* grads_cast_h; + ds_half_precision_t* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast(grads); + params_cast_h = reinterpret_cast(_params); + } + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = grads[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0) { grad = param * _weight_decay + grad; } + + variance += grad * grad; + + grad = sqrt(variance); + grad += _eps; + grad = momentum / grad; + param = grad * step_size + param; + if (half_precision) + params_cast_h[k] = (ds_half_precision_t)param; + else + _params[k] = param; + // STORE UPDATE TERM TO GRAD'S MEMORY + grads[k] = grad * step_size; + _exp_avg_sq[k] = variance; + } + } + } +} + +void Adagrad_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adagrad_optimizer(int optimizer_id, + float alpha = 1e-2, + float eps = 1e-8, + float weight_decay = 0, + bool should_log = false) +{ + auto opt = std::make_shared(alpha, eps, weight_decay); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay); + } + + return 0; +} + +void Adagrad_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adagrad_step(int optimizer_id, + size_t step, + float lr, + float epsilon, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step); + opt->update_state(lr, epsilon, weight_decay); + opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel()); + + return 0; +} + +int ds_adagrad_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float epsilon, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ + assert(false); + return 0; +} + +int destroy_adagrad_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); + m.def("adagrad_update_copy", + &ds_adagrad_step_plus_copy, + "DeepSpeed CPU Adagrad update and param copy (C++)"); + m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); + m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); +} diff --git a/csrc/xpu/adam/fused_adam_frontend.cpp b/csrc/xpu/adam/fused_adam_frontend.cpp new file mode 100755 index 000000000000..13b390248608 --- /dev/null +++ b/csrc/xpu/adam/fused_adam_frontend.cpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +void multi_tensor_adam_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_adam", + &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); +} diff --git a/csrc/xpu/adam/multi_tensor_adam.dp.cpp b/csrc/xpu/adam/multi_tensor_adam.dp.cpp new file mode 100644 index 000000000000..0720a020247a --- /dev/null +++ b/csrc/xpu/adam/multi_tensor_adam.dp.cpp @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include + +#include + +#include +#include "multi_tensor_apply.dp.hpp" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +typedef enum : int { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +struct AdamFunctor { + __inline__ __attribute__((always_inline)) void operator()(int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<4>& tl, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float epsilon, + const float lr, + adamMode_t mode, + const float decay) + { + auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); + int tensor_loc = tl.block_to_tensor[item_ct1.get_group(2)]; + + int chunk_idx = tl.block_to_chunk[item_ct1.get_group(2)]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T* m = (T*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T* v = (T*)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += item_ct1.get_local_range(2) * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + item_ct1.get_local_id(2) + ii * item_ct1.get_local_range(2); + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sycl::sqrt(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sycl::sqrt(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + item_ct1.get_local_id(2) + ii * item_ct1.get_local_range(2); + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) +{ + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) +} diff --git a/csrc/xpu/adam/multi_tensor_apply.dp.hpp b/csrc/xpu/adam/multi_tensor_apply.dp.hpp new file mode 100644 index 000000000000..0511ccf3a179 --- /dev/null +++ b/csrc/xpu/adam/multi_tensor_apply.dp.hpp @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +#include "compat.h" + +#include +#include +#include + +namespace at { +namespace cuda { +sycl::queue* getCurrentCUDAStream() +{ + c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream(); + auto& queue = stream.queue(); + return &queue; +} + +sycl::queue* getStreamFromPool(bool) +{ + // not implemented + return nullptr; +} +} // namespace cuda +} // namespace at +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void* addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; +}; + +template +class multi_tensor_apply_kernel { +public: + multi_tensor_apply_kernel(int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + ArgTypes... args) + : chunk_size(chunk_size), noop_flag(noop_flag), tl(tl), callable(callable), args(args...) + { + } + + // This should be identical to original __global__ function + static void inline __global__function(int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + ArgTypes... args) + { + callable(chunk_size, noop_flag, tl, args...); + } + + // If global function template contains parameter pack, + // we only deal with parameter pack at the end of template parameter list + template + static void inline __tuple_expand_driver(int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + Tuple args, + std::index_sequence) + { + __global__function(chunk_size, noop_flag, tl, callable, std::get(args)...); + } + + // + // Because __global__ function can't really use any reference types, we can sure that args + // are all good behaviors + // + void operator()(sycl::nd_item<3>) const + { + __tuple_expand_driver(chunk_size, + noop_flag, + tl, + callable, + args, + std::make_index_sequence()); + } + +private: + int chunk_size; + volatile int* noop_flag; + T tl; + U callable; + std::tuple args; +}; + +// to make sure multi_tensor_apply_kernel can be used in sycl::buffer +namespace sycl { +template +struct is_device_copyable> : std::true_type {}; +} // namespace sycl + +template +void multi_tensor_apply(int block_size, + int chunk_size, + const at::Tensor& noop_flag, + const std::vector>& tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kXPU, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + /* const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); */ + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + /* multi_tensor_apply_kernel, T, ArgTypes...> + * fn(chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); */ + if constexpr (sizeof(multi_tensor_apply_kernel( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...)) < + 2048) { + ((sycl::queue*)(stream)) + ->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, loc_block_info) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + multi_tensor_apply_kernel( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...)); + } else { + auto capture = multi_tensor_apply_kernel( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + sycl::buffer params(const_cast(&capture), + sycl::range<1>(1)); + stream->submit([&](sycl::handler& cgh) { + auto device_params = + params.template get_access(cgh); + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, loc_block_info) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item) { device_params[0](item); }); + }); + } + 0; + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/csrc/xpu/common/custom_cuda_kernel.dp.cpp b/csrc/xpu/common/custom_cuda_kernel.dp.cpp new file mode 100644 index 000000000000..cfd004ef1357 --- /dev/null +++ b/csrc/xpu/common/custom_cuda_kernel.dp.cpp @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +inline void has_capability_or_fail(const sycl::device& dev, + const std::initializer_list& props) +{ + for (const auto& it : props) { + if (dev.has(it)) continue; + switch (it) { + case sycl::aspect::fp64: + throw std::runtime_error("'double' is not supported in '" + + dev.get_info() + "' device"); + break; + case sycl::aspect::fp16: + throw std::runtime_error("'half' is not supported in '" + + dev.get_info() + "' device"); + break; + default: +#define __SYCL_ASPECT(ASPECT, ID) \ + case sycl::aspect::ASPECT: return #ASPECT; +#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) +#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) + auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string { + switch (AspectNum) { +#include +#include + default: return "unknown aspect"; + } + }; +#undef __SYCL_ASPECT_DEPRECATED_ALIAS +#undef __SYCL_ASPECT_DEPRECATED +#undef __SYCL_ASPECT + throw std::runtime_error("'" + getAspectNameStr(it) + "' is not supported in '" + + dev.get_info() + "' device"); + } + break; + } +} + +void param_update_kernel(const float* input, sycl::half* output, int size) +{ + auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); + int id = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + + if (id < size) { output[id] = (sycl::half)input[id]; } +} + +void launch_param_update(const float* input, sycl::half* output, int size, sycl::queue* stream) +{ + int threads = 1024; + + sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1); + sycl::range<3> block_dim(1, 1, threads); + + { + has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { param_update_kernel(input, output, size); }); + } +} + +void param_update_kernel_half(const float* input, sycl::half* output, int size) +{ + auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); + int id = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + sycl::half2* output_cast = reinterpret_cast(output); + if (id < size) { + float input_f = input[id]; + sycl::half2* input_h = reinterpret_cast(&input_f); + output_cast[id] = *input_h; + } +} + +void launch_param_update_half(const float* input, sycl::half* output, int size, sycl::queue* stream) +{ + int threads = 1024; + size /= 2; + sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1); + sycl::range<3> block_dim(1, 1, threads); + + { + has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { param_update_kernel_half(input, output, size); }); + } +} diff --git a/csrc/xpu/includes/compat.h b/csrc/xpu/includes/compat.h new file mode 100755 index 000000000000..6d54446d472e --- /dev/null +++ b/csrc/xpu/includes/compat.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/csrc/xpu/includes/cpu_adagrad.h b/csrc/xpu/includes/cpu_adagrad.h new file mode 100644 index 000000000000..660f860917f6 --- /dev/null +++ b/csrc/xpu/includes/cpu_adagrad.h @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#define NOMINMAX // Windows idiosyncrasy + // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c + +#include +#include +#include "simd.h" + +typedef unsigned short ds_half_precision_t; + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + ds_half_precision_t* dev_param = nullptr, \ + bool half_precision = false); + +class Adagrad_Optimizer { +public: + Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) + : _alpha(alpha), _eps(eps), _weight_decay(weight_decay) + { + } + ~Adagrad_Optimizer() {} +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg_sq, + size_t param_size, + ds_half_precision_t* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step) + { + _step++; + if (_step != step) { _step = step; } + } + inline void update_state(float lr, float epsilon, float weight_decay) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + } + +private: + float _alpha; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay4; + if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + i, half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, grads + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + i, half_precision); + + if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } + + simd_fma(variance_4, grad_4, grad_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_add(grad_4, grad_4, eps_4); + simd_div(grad_4, momentum_4, grad_4); + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + i, param_4, half_precision); + simd_store(_exp_avg_sq + i, variance_4, false); + } + } + *rounded_size = new_rounded_size; +} +#endif diff --git a/csrc/xpu/includes/cpu_adam.h b/csrc/xpu/includes/cpu_adam.h new file mode 100644 index 000000000000..7c11fa863219 --- /dev/null +++ b/csrc/xpu/includes/cpu_adam.h @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#define NOMINMAX // Windows idiosyncrasy + // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c + +#include +#include +#include +#include "simd.h" + +#include +typedef unsigned short ds_half_precision_t; + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + ds_half_precision_t* dev_param = nullptr, \ + bool half_precision = false); + +class Adam_Optimizer { +public: + Adam_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) + { + } + ~Adam_Optimizer() {} + +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t param_size, + ds_half_precision_t* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) + { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + if (step == _step + 1) { // first optimizer step increase + _step++; + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } else if (step == + _step) { // no need to update step; beta1_t and beta2_t already updated + return; + } else { // support step increase not equal to 1 + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adam_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + int rshft = half_precision ? 1 : 0; + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + (i >> rshft), half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, _exp_avg + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + (i >> rshft), half_precision); + + if (_weight_decay > 0 && !_adamw_mode) { + simd_fma(grad_4, param_4, weight_decay4, grad_4); + } + + simd_mul(momentum_4, momentum_4, betta1_4); + simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); + simd_mul(variance_4, variance_4, betta2_4); + simd_mul(grad_4, grad_4, grad_4); + simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); + simd_div(grad_4, momentum_4, grad_4); + + if (_weight_decay > 0 && _adamw_mode) { + simd_fma(param_4, param_4, weight_decay4, param_4); + } + + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + (i >> rshft), param_4, half_precision); + simd_store(_exp_avg + i, momentum_4, false); + simd_store(_exp_avg_sq + i, variance_4, false); + } + } + *rounded_size = new_rounded_size; +} +#endif + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false); + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq); + +int ds_adam_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params); + +int destroy_adam_optimizer(int optimizer_id); diff --git a/csrc/xpu/includes/simd.h b/csrc/xpu/includes/simd.h new file mode 100644 index 000000000000..097e2d8585cc --- /dev/null +++ b/csrc/xpu/includes/simd.h @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define TILE (128 * 1024 * 1024) +#if defined(__AVX512__) or defined(__AVX256__) + +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_AND(x, y) _mm512_and_ps(x, y) +#define SIMD_ANDNOT(x, y) _mm512_andnot_ps(x, y) +#define SIMD_OR(x, y) _mm512_or_ps(x, y) +#define SIMD_XOR(x, y) _mm512_xor_ps(x, y) +#define SIMD_WIDTH 16 + +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x)) +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm512_storeu_ps(x, d)) + +#define INTV __m256i +#elif defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_AND(x, y) _mm256_and_ps(x, y) +#define SIMD_ANDNOT(x, y) _mm256_andnot_ps(x, y) +#define SIMD_OR(x, y) _mm256_or_ps(x, y) +#define SIMD_XOR(x, y) _mm256_xor_ps(x, y) +#define SIMD_WIDTH 8 + +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm256_storeu_ps(x, d)) + +#define INTV __m128i +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) + __m256 data; +#endif + // float data_f[16]; +}; + +template +inline void simd_store(float* dst, AVX_Data* src, bool half_precision) +{ + size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); +#pragma unroll + for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); } +} +template +inline void simd_load(AVX_Data* dst, float* src, bool half_precision) +{ + size_t width = (half_precision ? 1 : SIMD_WIDTH); +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); + } +} +template +inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r[i].data); + } +} +template +inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r[i].data); } +} + +#endif diff --git a/csrc/xpu/includes/type_shim.h b/csrc/xpu/includes/type_shim.h new file mode 100644 index 000000000000..1897afd1fea2 --- /dev/null +++ b/csrc/xpu/includes/type_shim.h @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ +#include +/* #include */ +#include + +// Forward/backward compatibility hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +template +__inline__ __attribute__((always_inline)) T +reduce_block_into_lanes(T* x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); + int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2); + int blockSize = item_ct1.get_local_range(2) * + item_ct1.get_local_range(1); // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + /* + DPCT1118:1: SYCL group functions and algorithms must be encountered in converged control + flow. You may need to adjust the code. + */ + /* + DPCT1065:6: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if + there is no access to global memory. + */ + item_ct1.barrier(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + /* + DPCT1118:2: SYCL group functions and algorithms must be encountered in converged control + flow. You may need to adjust the code. + */ + /* + DPCT1065:7: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if + there is no access to global memory. + */ + item_ct1.barrier(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + /* + DPCT1118:3: SYCL group functions and algorithms must be encountered in converged control + flow. You may need to adjust the code. + */ + /* + DPCT1065:8: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if + there is no access to global memory. + */ + item_ct1.barrier(); + } + + return final; +} diff --git a/csrc/xpu/packbits/packing.cpp b/csrc/xpu/packbits/packing.cpp new file mode 100644 index 000000000000..d07fa4575e52 --- /dev/null +++ b/csrc/xpu/packbits/packing.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include + +using namespace sycl; + +void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1) +{ + // get the sign bit of each float and pack them into byte + int i = item_ct1; + for (int j = 0; j < 8; ++j) { + int k = i * 8 + j; + int bit = k < input_size && (!sycl::signbit(input[k])); + output[i] |= bit << (7 - j); + } +} + +void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1) +{ + // use the bit value to set float, bit 0 -> float -1, bit 1 -> float 1 + int i = item_ct1; + output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2; +} + +sycl::queue get_current_queue(at::Device device) +{ + c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream(device.index()); + return stream.queue(); +} + +/* +pack float tensor into uint8 tensor. Every eight float elements get packed into one uint8 +if float x >= 0, will be packed as a '1' bit, or will be packed as '0' +Arguments: + tensor: A bool tensor that get packed. + input_size: numel of input tensor + rank: device id in order to get corresponding stream +*/ +at::Tensor packbits(at::Tensor tensor, int input_size, int rank) +{ + at::Device device = "xpu:" + std::to_string(rank); + sycl::queue q = get_current_queue(device); + + int packed_size = (input_size + 7) / 8; + auto unit8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU); + at::Tensor packed = torch::zeros({packed_size}, unit8_options); + + float* input = (float*)tensor.data_ptr(); + uint8_t* output = (uint8_t*)packed.data_ptr(); + + auto event = q.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) { + packbitskernel(input, output, input_size, item_ct1); + }); + }); + + return packed; +} + +/* +unpack uint8 tensor into float tensor. Every uint8 element get unpacked into eight float +a '1' bit will be converted to a float(1), a '0' bit will be converted to a float(-1). +Arguments: + tensor: A uint8 tensor that get unpacked. + input_size: numel of input tensor + rank: device id in order to get corresponding stream +*/ +at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank) +{ + at::Device device = "xpu:" + std::to_string(rank); + sycl::queue q = get_current_queue(device); + + auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU); + at::Tensor unpacked = torch::empty({input_size * 8}, float_options); + + uint8_t* input = (uint8_t*)tensor.data_ptr(); + float* output = (float*)unpacked.data_ptr(); + + auto event = q.submit([&](sycl::handler& cgh) { + cgh.parallel_for<>(range(input_size * 8), + [=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); }); + }); + + return unpacked; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)"); + m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)"); +} diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 12f26b1927af..0d53a172e64e 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -12,11 +12,23 @@ from torch.optim.lr_scheduler import _LRScheduler from packaging import version as pkg_version +# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed +if not (hasattr(torch.version, 'hip') and torch.version.hip is not None): + try: + import triton # noqa: F401 # type: ignore + HAS_TRITON = True + except ImportError: + HAS_TRITON = False +else: + HAS_TRITON = False + from . import ops from . import module_inject +from .accelerator import get_accelerator +from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable -from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER +from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER from .runtime.hybrid_engine import DeepSpeedHybridEngine from .runtime.pipe.engine import PipelineEngine from .inference.engine import InferenceEngine @@ -25,17 +37,19 @@ from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError from .runtime.activation_checkpointing import checkpointing from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig -from .module_inject import replace_transformer_layer, revert_transformer_layer +from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode from .utils import log_dist, OnDevice, logger from .comm.comm import init_distributed -from .runtime import zero -from .runtime import DeepSpeedOptimizer, ZeROOptimizer +from .runtime import zero, domino +from .runtime.compiler import is_compile_supported from .pipe import PipelineModule from .git_version_info import version, git_hash, git_branch +from .runtime.tensor_parallel.init_utils import (load_ds_config, merge_tp_model_init_into_config, + record_tp_model_init_args) def _parse_version(version_str): @@ -50,6 +64,18 @@ def _parse_version(version_str): __git_hash__ = git_hash __git_branch__ = git_branch +# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init +dist = None + + +def set_optimizer_flags(config_class, model): + if config_class.optimizer_name == MUON_OPTIMIZER: + for name, p in model.named_parameters(): + if p.ndim >= 2 and not any(keyword in name.lower() for keyword in ("embed", "lm_head")): + setattr(p, "use_muon", True) + else: + setattr(p, "use_muon", False) + def initialize(args=None, model: torch.nn.Module = None, @@ -57,10 +83,12 @@ def initialize(args=None, model_parameters: Optional[torch.nn.Module] = None, training_data: Optional[torch.utils.data.Dataset] = None, lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, + distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, mpu=None, dist_init_required: Optional[bool] = None, collate_fn=None, config=None, + mesh_param=None, config_params=None): """Initialize the DeepSpeed Engine. @@ -81,6 +109,8 @@ def initialize(args=None, lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object. The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods + distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training + mpu: Optional: A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}() @@ -119,6 +149,14 @@ def initialize(args=None, assert model is not None, "deepspeed.initialize requires a model" + global dist + from deepspeed import comm as dist + dist_backend = get_accelerator().communication_backend_name() + dist.init_distributed(dist_backend=dist_backend, + distributed_port=distributed_port, + dist_init_required=dist_init_required) + + ##TODO: combine reuse mpu as mesh device and vice versa # Set config using config_params for backwards compat if config is None and config_params is not None: config = config_params @@ -127,8 +165,8 @@ def initialize(args=None, if hasattr(args, "deepscale_config") and args.deepscale_config is not None: logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") if hasattr(args, "deepspeed_config"): - assert (args.deepspeed_config is - None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + assert (args.deepspeed_config + is None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config args.deepscale_config = None @@ -136,10 +174,30 @@ def initialize(args=None, if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None: assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call" config = args.deepspeed_config - assert config != None, "DeepSpeed requires --deepspeed_config to specify configuration file" + assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file" + if not isinstance(config, dict): + config = load_ds_config(config) + + mesh_device = None + if mesh_param: + logger.info(f"mesh_param to Initialize mesh device: {mesh_param}") + mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel")) + #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device + else: + if "sequence_parallel_size" in config and "data_parallel_size" in config: + logger.info(f"config to Initialize mesh device: {config}") + mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \ + ("data_parallel", "sequence_parallel")) + + merge_tp_model_init_into_config(config, mpu, mesh_param, dist) + + autotp_size = config.get("tensor_parallel", {}).get("autotp_size", 0) + if autotp_size and autotp_size > 0: + set_autotp_mode(training=True) if not isinstance(model, PipelineModule): - config_class = DeepSpeedConfig(config, mpu) + config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device) + set_optimizer_flags(config_class, model) if config_class.hybrid_engine.enabled: engine = DeepSpeedHybridEngine(args=args, model=model, @@ -163,11 +221,13 @@ def initialize(args=None, dist_init_required=dist_init_required, collate_fn=collate_fn, config=config, + mesh_device=mesh_device, config_class=config_class) else: assert mpu is None, "mpu must be None with pipeline parallelism" mpu = model.mpu() config_class = DeepSpeedConfig(config, mpu) + set_optimizer_flags(config_class, model) engine = PipelineEngine(args=args, model=model, optimizer=optimizer, @@ -180,7 +240,15 @@ def initialize(args=None, config=config, config_class=config_class) - return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler] + # Restore zero.Init context if necessary + zero.partition_parameters.restore_init_context() + + return_items = [ + engine, + engine.optimizer, + engine.training_dataloader, + engine.lr_scheduler, + ] return tuple(return_items) @@ -216,12 +284,6 @@ def _add_core_arguments(parser): type=str, help='Deprecated DeepSpeed json configuration file.') - group.add_argument('--deepspeed_mpi', - default=False, - action='store_true', - help="Run via MPI, this will attempt to discover the necessary variables to initialize torch " - "distributed from the MPI environment") - return parser @@ -274,7 +336,7 @@ def init_inference(model, config=None, **kwargs): .. code-block:: python generator.model = deepspeed.init_inference(generator.model, - mp_size=world_size, + tensor_parallel={"tp_size": world_size}, dtype=torch.half, replace_with_kernel_inject=True) string = generator("DeepSpeed is") @@ -324,3 +386,69 @@ def init_inference(model, config=None, **kwargs): engine = InferenceEngine(model, config=ds_inference_config) return engine + + +def tp_model_init(model, tp_size, dtype, config=None, **kwargs): + """ + Record tensor-parallel initialization arguments for training. + + Note (compatibility and initialization behavior): + AutoTP sharding is applied during ``deepspeed.initialize(...)``. This + function exists for backward compatibility and only records TP arguments so + they can be validated and merged with the DeepSpeed config at initialization. + When you use both (i.e., calling ``set_autotp_mode(training=True)`` and + ``deepspeed.tp_model_init`` while also passing the config to + ``deepspeed.initialize``), DeepSpeed merges the settings at initialization. + Conflicting settings raise an error. The table below summarizes the behavior + across input combinations. + + Inputs: + - TPI: tp_model_init was called? (Y/N) + - TPG: tp_model_init provided tp_group? (Y/N) + - CFG: tensor_parallel in DeepSpeed config? (Y/N) + - MPU: mpu passed to deepspeed.initialize()? (Y/N) + + | TPI | TPG | CFG | MPU | Outcome | Notes | + |-----|-----|-----|-----|----------------------------------------|-------| + | N | N | N | N | Error | No TP intent; nothing to initialize | + | N | N | N | Y | No AutoTP | mpu may be used for other MP, but TP not enabled | + | N | N | Y | N | Init AutoTP from config | Use config; need TP group via config-driven init | + | N | N | Y | Y | Init AutoTP from config | mpu used to build TP group | + | Y | N | N | N | Error | No TP group source | + | Y | N | N | Y | Init AutoTP from tp_model_init | Use recorded args + mpu for TP group | + | Y | N | Y | N | Init AutoTP from config | Fill missing from TPI; error on mismatches; need TP group source | + | Y | N | Y | Y | Init AutoTP from config | Fill missing from TPI; error on mismatches | + | Y | Y | N | N | Init AutoTP from tp_model_init | Use recorded tp_group; config absent | + | Y | Y | N | Y | Error | tp_group + mpu conflict | + | Y | Y | Y | N | Init AutoTP from config | Error on mismatches; use tp_group from TPI; reject mpu | + | Y | Y | Y | Y | Error | tp_group + mpu conflict | + + Field-level merge rules when both tp_model_init and config exist: + - Canonical source: config + - Allowed: fill missing config fields from tp_model_init + - Error on mismatch: autotp_size, dtype, tp_group size or identity + + Extra checks: + - If tp_group is provided, reject mpu. + - If tp_group is not provided, require mpu (or another TP group source). + - If tensor_parallel is absent and only tp_model_init was called, require + a TP group source (direct tp_group or mpu). + + Args: + model (torch.nn.Module): The model to be initialized. + tp_size (int): The tensor parallelism size. + dtype (torch.dtype): The data type to be used for the model. + + Returns: + torch.nn.Module: The original model (no sharding applied here). + """ + if hasattr(model, 'ds_autotp_parsed'): + logger.warning("ds_autotp_parsed' attribute already exists in the model; tp_model_init is now record-only.") + + tp_group = kwargs.get("tp_group") + record_tp_model_init_args(tp_size=tp_size, dtype=dtype, tp_group=tp_group, dist_module=dist) + + # Keep AutoTP training mode active for backward compatibility. + set_autotp_mode(training=True) + + return model diff --git a/deepspeed/autotuning/README.md b/deepspeed/autotuning/README.md index 2cb73b01318a..1a9adfede948 100755 --- a/deepspeed/autotuning/README.md +++ b/deepspeed/autotuning/README.md @@ -94,7 +94,7 @@ Note that ZeRO stages, micro-batch sizes, and other ZeRO configurations to tune The DeepSpeed Autotuner tunes ZeRO stages, micro-batch size per GPU, and ZeRO configurations. Other DeepSpeed configurations are used as defined by the user in the DeepSpeed configuration file. Users can overwrite any of the tuning parameters. ### Configuring ZeRO Stage -By default, the DeepSpeed Autotuner tunes ZeRO stages. If `"zero_optimization"` is not defined or set to `"all"`, the Autotuner explores ZeRO stages in the order of `[0, 1, 2, 3]`. Users can overwrite this behavior if they already know what ZeRO stage(s) to use. For example, the below section in the DeepSpeed configuration file limits the Autotuner to only exploring ZeRO stage 2 and 3. +By default, the DeepSpeed Autotuner does not tune ZeRO stages. If `"zero_optimization"` is not defined, DeepSpeed ZeRO is disabled. If `"zero_optimization"` is set to `"all"`, the Autotuner explores ZeRO stages in the order of `[0, 1, 2, 3]`. Users can overwrite this behavior if they already know what ZeRO stage(s) to use. For example, the below section in the DeepSpeed configuration file limits the Autotuner to only exploring ZeRO stage 2 and 3. ```json { @@ -214,7 +214,7 @@ If `"stage"` is not defined or set as `"all"`, then the overwriting applies to a Currently, the DeepSpeed Autotuner does not tune offloading behaviors but instead uses the values defined in the offload section of the DeepSpeed configuration file. See [Parameter offloading](https://www.deepspeed.ai/docs/config-json/#parameter-offloading) and [Optimizer offloading](https://www.deepspeed.ai/docs/config-json/#optimizer-offloading) for details. -If using NVME for offloading, users can run a benchmark offline to select the optimal `aio` setup in DeepSpeed. Refer to [profiling NVMe and configuring aio param section](https://github.com/microsoft/DeepSpeed/issues/998). +If using NVME for offloading, users can run a benchmark offline to select the optimal `aio` setup in DeepSpeed. Refer to [profiling NVMe and configuring aio param section](https://github.com/deepspeedai/DeepSpeed/issues/998). ## Autotuning Output @@ -336,13 +336,13 @@ The Autotuner stops exploring the space when any of the following conditions mee ## Using Autotuning with Hugging Face -Hugging Face users can set some configurations values to ["auto"](https://huggingface.co/transformers/main_classes/deepspeed.html?highlight=gradient_accumulation_steps#shared-configuration). +Hugging Face users can set some configurations values to ["auto"](https://huggingface.co/docs/transformers/deepspeed#deepspeed-and-trainer-parameters). `"auto"` means the value will be set to the default in Hugging Face or be overwritten using the supplied values from the command line arguments. In DeepSpeed Autotuning, if the user-provided DeepSpeed configuration file has "auto" keywords, they are treated as the value "auto". ## GPT2-large Example -This section shows an example of using DeepSpeed autotuning. For more examples, refer to [autotuning](https://github.com/microsoft/DeepSpeedExamples/tree/master/autotuning) in the DeepSpeedExamples repo. +This section shows an example of using DeepSpeed autotuning. For more examples, refer to [autotuning](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/autotuning) in the DeepSpeedExamples repo. Example training script: @@ -412,4 +412,4 @@ Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulati | ---------- | -------------------- | ------------------------ | ------------------------------ | | GPT2-large | 27.874 (mbs = 1) | 56.797 (z = 1, mbs = 2), | 69.061 (z = 1, mbs = 3) | -As we can see the DeepSpeed Autotuner can select a better than hand-tuned configuration with a reasonable number of experiments. Examples in [Autotuning Hugging Face Examples](https://github.com/microsoft/DeepSpeedExamples/tree/master/autotuning/hf#autotuning-hugging-face-examples) would demonstrate the effectiveness of autotuning across different models. +As we can see the DeepSpeed Autotuner can select a better than hand-tuned configuration with a reasonable number of experiments. Examples in [Autotuning Hugging Face Examples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/autotuning/hf#autotuning-hugging-face-examples) would demonstrate the effectiveness of autotuning across different models. diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index 3f2b01e14668..cd171f4c0aa7 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -50,7 +50,7 @@ def __init__(self, args, active_resources): assert tabulate is not None, "Missing required package `tabulate`, please install with `pip install deepspeed[autotuning]`." - logger.debug(f"autotunning args={args}") + logger.debug(f"autotuning args={args}") self.user_config = self._get_user_config(args.user_args) assert self.user_config is not None, "DeepSpeed configuration is not provided" @@ -69,9 +69,9 @@ def __init__(self, args, active_resources): try: os.makedirs(self.exps_dir, exist_ok=True) logger.info(f"Created autotuning experiments directory: {self.exps_dir}") - except: + except Exception: logger.error( - f"Failed to create {self.exps_dir}, please check `exps_dir` in the autotuning config file is accessible by all the nodes in the job." + f"Failed to create {self.exps_dir}, please check exps_dir in the autotuning config file is accessible by all the nodes in the job." ) exit(-1) @@ -81,10 +81,10 @@ def __init__(self, args, active_resources): if not os.path.exists(self.results_dir): try: os.makedirs(self.results_dir, exist_ok=True) - logger.info(f"Created autotuning resutls directory: {self.exps_dir}") - except: + logger.info(f"Created autotuning results directory: {self.results_dir}") + except Exception: logger.error( - f"Failed to create {self.results_dir}, please check `results_dir` in the autotuning config file is accessible by all the nodes in the job." + f"Failed to create {self.results_dir}, please check results_dir in the autotuning config file is accessible by all the nodes in the job." ) exit(-1) @@ -101,7 +101,7 @@ def __init__(self, args, active_resources): self.records = {} self.optimal_cmd = None - self.optmal_ds_config = None + self.optimal_ds_config = None self.mlflow_parent_id = None @@ -145,7 +145,7 @@ def print_tuning_results(self): f"{best_exp['name']} is the optimal setup after tuning. The exp result is at {best_exp['result_dir']}." ) else: - logger.info(f"No optimal setup is found. Please check that experiments were run successfully.") + logger.info("No optimal setup is found. Please check that experiments were run successfully.") tuning_duration = datetime.timedelta(seconds=(time.time() - self.start_time)) logger.info(f"Tuning completed in {tuning_duration}") @@ -248,8 +248,8 @@ def mp_size(self): return self.autotuning_config.mp_size def max_train_micro_batch_size_per_gpu(self): - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // ( self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1 return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size) @@ -410,7 +410,7 @@ def tune(self): self.start_time = time.time() if self.fast_enabled(): - logger.info(f"Fast mode is enabled. Tuning micro batch size only.") + logger.info("Fast mode is enabled. Tuning micro batch size only.") # model info profile run with DEFAULT_MIN_MEM_CONFIG model_info = self.model_info_profile_run() @@ -429,9 +429,8 @@ def tune(self): f"The model requires at least {memory_to_string(self.activation_mem, postfix='B')} activation memory for micro batch size 1." ) - #TODO: FIX THIS - stage = self.user_config.get(ZERO_OPTIMIZATION, {}).get(ZERO_OPTIMIZATION_STAGE, "all") - stage = "all" + stage = self.user_config.get(ZERO_OPTIMIZATION, {}).get(ZERO_OPTIMIZATION_STAGE, 0) + user_zero_stages = [stage] if not isinstance(stage, list) else stage logger.info(f"User-defined zero stages are {stage}.") @@ -638,7 +637,7 @@ def tune_space(self, tuning_space, prev_max_mbs=0, prev_best_mbs=0, prev_best_me logger.info(f"End tuning for space: {tuning_space_name}") return max_micro_batch_size, best_mbs, best_metric_val - def get_plauteu_mbs(self, tuning_space_name): + def get_plateau_mbs(self, tuning_space_name): if tuning_space_name not in self.records: return 0 space_records = self.records[tuning_space_name] @@ -662,7 +661,7 @@ def get_model_num_params(self): return self.model_info["num_params"] def model_info_profile_run(self): - """Does a model information profling experiment that collects the number of model parameters and activation memory.\ + """Does a model information profiling experiment that collects the number of model parameters and activation memory.\ The experiment produces a "profile_model_info" folder under self.results_dir. Returns: [dict]: a model information dictionary, e.g., {"num_params": 335144976, "trainable_num_params": 335144976, "activation_mem_per_gpu": 324358144, "rank": 0} @@ -684,6 +683,7 @@ def model_info_profile_run(self): exp_config[DS_CONFIG] = ds_config exp_config['num_gpus'] = self.exp_num_gpus exp_config['num_nodes'] = self.exp_num_nodes + exp_config['hostfile'] = self.args.hostfile exp_path = os.path.join(self.exps_dir, f'{exp_name}.json') with open(exp_path, 'w', buffering=BUFSIZE) as fd: @@ -762,6 +762,7 @@ def run_tuning_micro_batch_sizes(self, tuning_micro_batch_sizes, max_train_batch exp_config[DS_CONFIG] = ds_config exp_config['num_gpus'] = self.exp_num_gpus exp_config['num_nodes'] = self.exp_num_nodes + exp_config['hostfile'] = self.args.hostfile exp_path = os.path.join(self.exps_dir, f'{exp_name}.json') with open(exp_path, 'w', buffering=BUFSIZE) as fd: @@ -802,7 +803,7 @@ def run_tuning_micro_batch_sizes(self, tuning_micro_batch_sizes, max_train_batch if tuning_micro_batch_sizes_overwritten: return tuning_micro_batch_sizes - # in a auto-detected tuning_micro_batch_sizs list, max_micro_batch_size might not be performant as the memory consumption is close to max + # in a auto-detected tuning_micro_batch_sizes list, max_micro_batch_size might not be performant as the memory consumption is close to max # try smaller values while gas stays the same # if finding a more performant mbs value, use it to replace max_micro_batch_size in the list min_micro_batch_size_with_same_gas = (tuning_micro_batch_sizes[-2] + @@ -963,8 +964,8 @@ def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_m low = mid + 1 self.update_records(tuning_space_name, exp, metric_val, 1) used_micro_batch_sizes.append(mid) - if prev_metric_val and ( - (metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST: + if prev_metric_val and ((metric_val - prev_metric_val) / + prev_metric_val) < METRIC_PERCENT_DIFF_CONST: logger.info(f"performance plateaus at mbs = {low}") break prev_metric_val = metric_val @@ -986,7 +987,7 @@ def get_gas_from_user_config(self): if isinstance(gas_in_config, int): gas = gas_in_config elif gas_in_config == "auto": # GRADIENT_ACCUMULATION_STEPS: "auto" - val = self.get_val_from_config(GRADIENT_ACCUMULATION_STEPS) + val = self.get_val_from_user_args(GRADIENT_ACCUMULATION_STEPS) if val: gas = int(val) elif isinstance(gas_in_config, list): @@ -1025,8 +1026,8 @@ def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} )) # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) )) # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) )) - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus * self.exp_num_nodes) else: @@ -1056,6 +1057,7 @@ def run_ds_config(self, ds_config, exp_name): exp_config[DS_CONFIG] = ds_config exp_config['num_gpus'] = self.exp_num_gpus exp_config['num_nodes'] = self.exp_num_nodes + exp_config['hostfile'] = self.args.hostfile exp_path = os.path.join(self.exps_dir, f'{exp_name}.json') logger.debug(f'run_ds_config exp_name = {exp_name}') @@ -1093,14 +1095,14 @@ def write_optimal_config(self): fd.write("\n") fd.flush() self.optimal_cmd = cmd - self.optmal_ds_config = ds_config + self.optimal_ds_config = ds_config logger.info( f"Wrote the optimal DeepSpeed configuration found by autotuning to {ds_config_path}, and the corresponding DeepSpeed command to {cmd_path}" ) def run_after_tuning(self): """ Launches the training with the optimal DeepSpeed configuration found through the autotuning process. - "ds_config_optimal.json" describing the optmimal DeepSpeed configuration as well the command used to launch training "cmd_optimal.txt" are saved to self.results_dir. + "ds_config_optimal.json" describing the optimal DeepSpeed configuration as well the command used to launch training "cmd_optimal.txt" are saved to self.results_dir. """ if self.optimal_cmd: result = subprocess.Popen(self.optimal_cmd) @@ -1108,4 +1110,4 @@ def run_after_tuning(self): logger.info(f"Done running with the optimal DeepSpeed configuration using {self.optimal_cmd}") else: - logger.info(f"No optimal DeepSpeed configuration found by autotuning.") + logger.info("No optimal DeepSpeed configuration found by autotuning.") diff --git a/deepspeed/autotuning/constants.py b/deepspeed/autotuning/constants.py index e6a62e32fde3..9f828608c852 100644 --- a/deepspeed/autotuning/constants.py +++ b/deepspeed/autotuning/constants.py @@ -4,7 +4,7 @@ # DeepSpeed Team ######################################### -# autotunner implementation constants +# autotuner implementation constants ######################################### import os @@ -117,7 +117,7 @@ MODEL_INFO_PROFILE_DEFAULT = False MODEL_INFO_NUM_PARAMS = "num_params" MODEL_INFO_NUM_PARAMS_DEFAULT = None -MODEL_INFO_HIDDEN_SIZE = "hideen_size" +MODEL_INFO_HIDDEN_SIZE = "hidden_size" MODEL_INFO_HIDDEN_SIZE_DEFAULT = None MODEL_INFO_NUM_LAYERS = "num_layers" MODEL_INFO_NUM_LAYERS_DEFAULT = None @@ -130,7 +130,7 @@ } ######################################### -# autotunner search space constants +# autotuner search space constants ######################################### DEFAULT_HF_CONFIG = { @@ -144,7 +144,7 @@ "zero_optimization": { "stage": 3 }, - "memory_break_down": False + "memory_breakdown": False } DEFAULT_TUNING_SPACE_ZERO_0 = {"zero_optimization": {"stage": 0}} diff --git a/deepspeed/autotuning/scheduler.py b/deepspeed/autotuning/scheduler.py index 40978aa00ab9..14e9541d03a3 100755 --- a/deepspeed/autotuning/scheduler.py +++ b/deepspeed/autotuning/scheduler.py @@ -5,7 +5,6 @@ import copy -from numpy import BUFSIZE import json import subprocess import sys @@ -18,8 +17,8 @@ from tqdm import tqdm from ..utils import logger -from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH -from .utils import get_val_by_key, search_error, was_interruptted +from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH, BUFSIZE +from .utils import get_val_by_key, search_error, was_interrupted """ thread-0: loop over experiment queue dispatching experiments if they become available thread-N: start each experiment in its own thread @@ -77,7 +76,7 @@ def schedule_experiments(self, exp_paths): # skip existing experiments (except for the ones that were interrupted) if os.path.exists(result_dir) and os.path.exists(stderr_file): - if not was_interruptted(stderr_file): + if not was_interrupted(stderr_file): err = search_error(stderr_file) exp_id = exp["exp_id"] self.finished_experiments[exp_id] = (exp, err) @@ -316,7 +315,10 @@ def run_experiment(exp: dict, reservations, user_script, user_args): include_str += f"{reservation.node.host}:{slots}@" include_str = include_str[:-1] master_port = exp["master_port"] + hostfile = exp["hostfile"] exp["launcher_args"] = [ + "--hostfile", + f"{hostfile}", "--include", f"{include_str}", "--master_port", diff --git a/deepspeed/autotuning/tuner/base_tuner.py b/deepspeed/autotuning/tuner/base_tuner.py index 3ac7389810fc..722af4748ffd 100755 --- a/deepspeed/autotuning/tuner/base_tuner.py +++ b/deepspeed/autotuning/tuner/base_tuner.py @@ -39,14 +39,14 @@ def tune(self, sample_size=1, n_trials=1000, early_stopping=None): i = 0 try: while i < n_trials and self.has_next(): - # Select the next batch of configuratiosn for evaluation + # Select the next batch of configuration for evaluation sampled_exps = self.next_batch(sample_size) # Generate experiments for measurement of performance exp_paths = write_experiments(sampled_exps, self.rm.exps_dir) self.rm.schedule_experiments(exp_paths) self.rm.run() exp, metric_val = self.rm.parse_results(self.metric) - if self.best_exp == None or self.best_metric_val == None or (metric_val + if self.best_exp is None or self.best_metric_val is None or (metric_val and metric_val > self.best_metric_val): # logger.info(f"tuner finds better = {exp}") self.best_exp = exp @@ -67,6 +67,6 @@ def tune(self, sample_size=1, n_trials=1000, early_stopping=None): ) break return i - except: - logger.info("Tunner Error:", sys.exc_info()[0]) + except Exception: + logger.info("Tuner Error:", sys.exc_info()[0]) return i diff --git a/deepspeed/autotuning/tuner/model_based_tuner.py b/deepspeed/autotuning/tuner/model_based_tuner.py index 23f224b5eba2..aec9264f9b7c 100755 --- a/deepspeed/autotuning/tuner/model_based_tuner.py +++ b/deepspeed/autotuning/tuner/model_based_tuner.py @@ -19,9 +19,9 @@ class ModelBasedTuner(BaseTuner): """Exploring the search space with a cost model""" - def __init__(self, exps: list, resource_manager, metric, tuning_sapce): + def __init__(self, exps: list, resource_manager, metric, tuning_space): super().__init__(exps, resource_manager, metric) - self.tuning_space = tuning_sapce + self.tuning_space = tuning_space self.best_iter = 0 self.all_configs = [e['ds_config'] for e in exps] diff --git a/deepspeed/autotuning/utils.py b/deepspeed/autotuning/utils.py index dec13ca7f621..53d5d68ac11e 100644 --- a/deepspeed/autotuning/utils.py +++ b/deepspeed/autotuning/utils.py @@ -26,7 +26,7 @@ def search_error(filename): return None -def was_interruptted(filename): +def was_interrupted(filename): if not os.path.exists(filename): return "stderr.log does not exist" with open(filename) as f: @@ -42,7 +42,7 @@ def find_replace_str(value, replace_dict): if not isinstance(value, str): return str(value) - matches = re.findall(r"\$[A-Za-z0-9_]+", value) + matches = re.findall(r"\$[\w]+", value) for var in matches: var_key = var.replace("$", "").lower() if var_key == "nvme_path": @@ -268,7 +268,7 @@ def prune_configs(configs, ignored_keys=[]): def get_tuning_keys(tuning_space: dict): - """Outputs the list of tunnable parameters in the tuning space dict. + """Outputs the list of tunable parameters in the tuning space dict. Args: tuning_space (dict): a configuration dictionary containing tunable parameters as lists of values. @@ -422,7 +422,7 @@ def memory_to_string(n, postfix="", units=None, precision=2): elif n // 10**6 > 0: return str(round(n / 1024**2, precision)) + " M" + postfix elif n // 10**3 > 0: - return str(round(n / 1014, precision)) + " K" + postfix + return str(round(n / 1024, precision)) + " K" + postfix else: return str(n) + " " else: diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py index c9822693867d..1f645a494ade 100644 --- a/deepspeed/checkpoint/__init__.py +++ b/deepspeed/checkpoint/__init__.py @@ -15,6 +15,6 @@ from .zero_checkpoint import ZeROCheckpoint -from .universal_checkpoint import enable_universal_checkpoint +from .universal_checkpoint import enable_universal_checkpoint, SubparamShape from .constants import * diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 7735d763f598..dde5b16bd946 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -14,12 +14,15 @@ FP32_FLAT_GROUPS = 'fp32_flat_groups' BASE_OPTIMIZER_STATE = 'base_optimizer_state' +BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step' SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" +PARAM_GROUPS = 'param_groups' GROUP_PADDINGS = 'group_paddings' PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' CLIP_GRAD = 'clip_grad' FP32_WEIGHT_KEY = "fp32" +LOSS_SCALER = 'loss_scaler' ######################################### # Module checkpoint keys @@ -27,6 +30,8 @@ PARAM = 'param' PARAM_SHAPES = 'param_shapes' BUFFER_NAMES = 'buffer_names' +FROZEN_PARAM_SHAPES = 'frozen_param_shapes' +FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments' ######################################### # Checkpoint naming constants @@ -51,18 +56,34 @@ UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version' # Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2 +# Attribute name used to store AutoTP universal-checkpoint metadata on torch Parameters. +DS_AUTOTP_UC_META = "ds_autotp_universal_checkpoint_meta" # Vocabulary padding -VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor' +VOCAB_TENSOR = 'vocab_tensor' PADDED_VOCAB_SIZE = 'padded_vocab_size' ORIGINAL_VOCAB_SIZE = 'original_vocab_size' # Parameter splitting/merging PARAM_SLICE_MAPPINGS = 'param_slice_mappings' CAT_DIM = "cat_dim" +# Following is a special case where a parameter effectively contains sub parameters. +# As an example, consider Megatron-DeepSpeed GPT SWIGLU implementation (mlp.h_to_4h). +# In this case, a single parameter ia allocated contiguously, but used as separate parameters. +# When using universal checkpoint, we have to normalize the representation of the full parameter. +# We normalize it by concatenating all slices of the sub params and then concatenating the sub params. +# All concat operations are done on CAT_DIM (currently, no support for different concat dims sub params and TP slicing). +# Similarly, load_hp_checkpoint_state has to take the needed actions when loading from universal. +PARAM_N_SUB_PARAMS = "param_n_sub_params" + +SUB_PARAM_SHAPE = "sub_param_shape" # Regex list of parameters that require special handling VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns' PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns' PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns' PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns' +TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns' +PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0' +PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params' +SUB_PARAMS_SHAPE = 'sub_params_shape' diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index ef36b0c5ef3f..3f97ec067dbd 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import re from typing import Dict import torch @@ -21,6 +22,7 @@ ARGS_KEY = 'args' CHECKPOINT_INFO_KEY = 'checkpoint_info' ITERATION_KEY = 'iteration' +LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*' SEQUENTIAL_LAYERS = [ 'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight', @@ -32,9 +34,18 @@ class DeepSpeedCheckpoint(object): - def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): + def __init__(self, + dir, + tp_degree=None, + pp_degree=None, + dp_degree=None, + final_layer_norm_idx=FINAL_LAYER_NORM_INDEX): + self.final_layer_norm_idx = final_layer_norm_idx self.dir = dir - self._validate_folder(dir) + + pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0 + + self._validate_folder(dir, pipeline_parallel) self.zero_checkpoint = ZeROCheckpoint(dir) @@ -70,7 +81,7 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): self.pp_to_transformer_map = self._build_pp_transformer_map() self.transformer_file_map = self._build_transformer_file_map() self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) - self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX) + self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx) self._build_global_state() def is_change_tp_degree(self): @@ -83,14 +94,14 @@ def is_change_dp_degree(self): return self.dp_degree != self.zero_checkpoint.get_src_dp_degree() def show_2d_mapping(self): - print(f'reshaped 2d map ---- begin') + print('reshaped 2d map ---- begin') for i in range(self.pp_degree): for j in range(self.tp_degree): file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j) print(f'[{i}, {j}] = {file_list}') - print(f'reshaped 2d map ---- end') + print('reshaped 2d map ---- end') def show_tp_embedding_map(self): self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') @@ -98,22 +109,23 @@ def show_tp_embedding_map(self): def show_tp_final_norm_map(self): self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers') - def show_pp_tranformer_map(self): - self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers') + def show_pp_transformer_map(self): + self._dump_mapping(self.pp_to_transformer_map, 'pp_to_transformer_layers') def show_transformer_file_map(self): - self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files') + self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files') def _build_global_state(self): - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False) self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) - def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict: + def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index, strip_tensor_paddings: bool = True) -> dict: return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index, - keys_to_ignore=[PARAM_SHAPES]) + keys_to_ignore=[PARAM_SHAPES], + strip_tensor_paddings=strip_tensor_paddings) def get_zero_files(self, pp_index, tp_index, dp_index) -> list: return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) @@ -122,18 +134,21 @@ def get_embedding_layer_id(self): return self.layer_keys[EMBEDDING_LAYER_INDEX] def get_final_norm_layer_id(self): - return self.layer_keys[FINAL_LAYER_NORM_INDEX] + return self.layer_keys[self.final_layer_norm_idx] def get_iteration(self): - if not ITERATION_KEY in self.global_state: - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + if ITERATION_KEY not in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False) self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) return self.global_state[ITERATION_KEY] def get_embedding_state(self, tp_index: int) -> Dict: assert tp_index in self.tp_to_embedding_map.keys() - sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]] + sd_list = [ + torch.load(fname, map_location=torch.device('cpu'), weights_only=False) + for fname in self.tp_to_embedding_map[tp_index] + ] sd = self._merge_state_dicts(sd_list) return sd @@ -142,8 +157,8 @@ def get_embedding_files(self, tp_index: int) -> list: return self.tp_to_embedding_map[tp_index] def _get_checkpoint_value(self, key): - if not key in self.global_state: - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + if key not in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False) self.global_state[key] = sd.get(key, None) return self.global_state[key] @@ -158,7 +173,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict: assert tp_index < self.tp_degree assert pp_index < self.pp_degree fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index) - sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] + sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list] merged_sd = None for sd in sd_list: @@ -174,7 +189,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list: assert pp_index < self.pp_degree t_list = [] for fname_list in self.transformer_file_map[(tp_index, pp_index)]: - sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] + sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list] sd = self._merge_state_dicts(sd_list) t_list.append(sd) return t_list @@ -185,7 +200,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list: def get_final_norm_state(self, tp_index: int) -> Dict: assert tp_index in self.tp_to_final_norm_map.keys() - sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu')) + sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False) return sd def get_final_norm_files(self, tp_index: int) -> list: @@ -193,7 +208,10 @@ def get_final_norm_files(self, tp_index: int) -> list: return self.tp_to_final_norm_map[tp_index] def _build_tp_other_layer_map(self, layer_index: int): - assert layer_index < len(self.layer_files) + data_map = {} + if len(self.layer_files) < 1: + return data_map + assert layer_index <= len(self.layer_files) layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index]) layer_file_partitions = partition_data(layer_files, self.tp_degree) data_map = {i: flist for i, flist in enumerate(layer_file_partitions)} @@ -207,9 +225,13 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list: def _build_pp_transformer_map(self): data_map = {} - transformer_layers = self.layer_keys[1:-1] - layers_per_pp = len(transformer_layers) // self.pp_degree - data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)} + if self.pp_degree > 0: + transformer_layers = self.layer_keys[1:self.final_layer_norm_idx] + layers_per_pp = len(transformer_layers) // self.pp_degree + data_map = { + i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] + for i in range(0, self.pp_degree) + } return data_map def _dump_mapping(self, data_map, map_tag=None): @@ -219,20 +241,20 @@ def _dump_mapping(self, data_map, map_tag=None): print(f'{k} = {v}') def _build_transformer_file_map(self): - transformer_layer_keys = self.layer_keys[1:-1] + transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx] file_map = {} # XXX: this is not guaranteed - layers_per_pp = len(transformer_layer_keys) // self.pp_degree - if layers_per_pp == 0: - layers_per_pp = 1 + layers_per_pp = 1 + if self.pp_degree > 0: + layers_per_pp = len(transformer_layer_keys) // self.pp_degree #print(f"{transformer_layer_keys} {layers_per_pp}") for key_index, layer_key in enumerate(transformer_layer_keys): pp_index = key_index // layers_per_pp - layer_files = get_files_with_prefix(self.layer_files, layer_key) + layer_files = get_files_with_prefix(self.layer_files, layer_key + '-') layer_file_partitions = partition_data(layer_files, self.tp_degree) for tp_index in range(self.tp_degree): map_key = (tp_index, pp_index) - if not map_key in file_map.keys(): + if map_key not in file_map.keys(): file_map[map_key] = [] file_map[map_key].append(layer_file_partitions[tp_index]) @@ -240,8 +262,8 @@ def _build_transformer_file_map(self): def _sanity_check(self): assert len(self.mp_rank_files) % self.tp_degree == 0 - assert len(self.layer_keys) > 2 assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0 + assert self.zero_checkpoint.num_files % (self.tp_degree) == 0 # XXX: fix me - isn't always the case # only true with --pp-partition-method 'type:transformer|embedding' \ # assert (len(self.layer_keys) - 2) % self.pp_degree == 0 @@ -253,16 +275,18 @@ def validate_files(self): def _get_layer_keys(self): key_set = set() - key_len = len(LAYER_FILE_PREFIX) + 2 for file_path in self.layer_files: _, fname = os.path.split(file_path) - key_set.add(fname[:key_len]) - return sorted(list(key_set)) + layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1) + key_set.add(layer_id) + sorted_ids = sorted(list(key_set), key=int) + layer_keys = [LAYER_FILE_PREFIX + str(layer_id) for layer_id in sorted_ids] + return layer_keys def _merge_state_dicts(self, sd_list): merged_sd = {} for key in sd_list[0].keys(): - if not key in SEQUENTIAL_LAYERS: + if key not in SEQUENTIAL_LAYERS: cat_dim = LAYER_CONCAT_DIM.get(key, 0) merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim) else: @@ -270,12 +294,14 @@ def _merge_state_dicts(self, sd_list): return merged_sd - def _validate_folder(self, dir): + def _validate_folder(self, dir, pipeline_parallel): basic_folder_validation(dir) file_list = get_files(dir) - - for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']: + file_prefix_list = [MODEL_FILE_PREFIX] + if pipeline_parallel: + file_prefix_list.extend([LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']) + for file_prefix in file_prefix_list: ckpt_files = get_files_with_prefix(file_list, file_prefix) assert len( ckpt_files diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py new file mode 100755 index 000000000000..8a39f6bb4c31 --- /dev/null +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from functools import partial +from itertools import chain +import argparse +import glob +import itertools +import math +from concurrent.futures import ProcessPoolExecutor +import os +import re +import shutil +import torch +import tqdm +#from pprint import pprint + +from deepspeed.checkpoint import DeepSpeedCheckpoint +from deepspeed.checkpoint import ( + OPTIMIZER_STATE_DICT, + ZERO_STAGE, + BASE_OPTIMIZER_STATE, + SINGLE_PARTITION_OF_FP32_GROUPS, + PARAM_GROUPS, + PARAM_SLICE_MAPPINGS, + PARAM_SHAPES, + PARAM, + CAT_DIM, + PARAM_N_SUB_PARAMS, + SUB_PARAM_SHAPE, + VOCAB_TENSOR, + UNIVERSAL_CHECKPOINT_INFO, + UNIVERSAL_CHECKPOINT_VERSION_KEY, + UNIVERSAL_CHECKPOINT_VERSION_VALUE, + VOCABULARY_PARAMETER_PATTERNS, + PIPELINE_REPLICATED_PARAMETER_PATTERNS, + TP_REPLICATED_PARAMETER_PATTERNS, + PARAMETER_TO_AVERAGE_PATTERNS, + PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, + PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, + PARAMETER_WITH_SUB_PARAMS, + SubparamShape, +) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder') + parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder') + parser.add_argument('--num_extract_workers', + default=4, + type=int, + help='How many parallel processes to extract zero shards') + parser.add_argument( + '--num_merge_workers', + default=2, + type=int, + help= + 'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers))' + ) + parser.add_argument('--keep_temp_folder', + action='store_true', + help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.') + parser.add_argument('--no_strict', + dest='strict', + action='store_false', + help='Do not perform validity checks on converted checkpoint.') + parser.add_argument('--inject_missing_state', + action='store_true', + help='Inject missing checkpoint state into the checkpoint if it is absent.') + args = parser.parse_args() + print(f'args = {args}') + return args + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): + path_list = [] + iter_folder = f'iter_{iteration:07d}' + for i in range(0, tp_degree): + path_list.append([]) + for j in range(0, pp_degree): + rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' + ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') + path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) + + return path_list + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + + +def extract_zero_shards(dir, ds_checkpoint, indices_3D): + pp_index, tp_index, dp_index = indices_3D + sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) + + # pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}") + + optim_sd = sd[OPTIMIZER_STATE_DICT] + param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS] + universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) + pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, []) + # print(f'{pipeline_replicated_params=}') + + # dict + state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"] + # list + fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS] + param_groups_cnt = len(state_groups) + + for param_group_id in range(param_groups_cnt): + + flat_state = dict( + exp_avg=state_groups[param_group_id]["exp_avg"], + exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"], + fp32=fp32_groups[param_group_id], + ) + + if "step" in state_groups[param_group_id]: + flat_state["step"] = state_groups[param_group_id]["step"] + + for name, fragment_mapping in param_slice_mappings[param_group_id].items(): + if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params): + # Skip tied weights that are replicated in first and last pp stages + continue + + # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}") + for state_key in flat_state.keys(): + dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name, + fragment_mapping.start, fragment_mapping.numel) + + +def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index): + state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False) + + for idx, sub_group_shape in enumerate(param_shapes): + flat_state = dict( + exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg"], + exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg_sq"], + fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][idx], + ) + offset = 0 + for name, shape in sub_group_shape.items(): + unpartitioned_numel = shape.numel() + partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree) + padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel)) + for state_key in flat_state.keys(): + dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset, + padding_free_numel) + offset += partitioned_numel + + +cnt = 0 + + +def dp_index_to_str(dp_index): + return f"{dp_index:0>2d}" + + +def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): + + global cnt # temp hack + + param_base_path = os.path.join(dir, param_name, str(tp_index)) + os.makedirs(param_base_path, exist_ok=True) + + cnt += 1 + + path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") + + #print(f"{param_name}: {offset}: {numel} => {path}") + + # State might be a python int or a tensor + if state_name != "step" and torch.is_tensor(state_flat_tensor): + state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone() + _save_checkpoint(path, state_flat_tensor) + + +def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None): + slices = [] + for tp_index in range(tp_degree): + prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") + paths = glob.glob(f"{prefix_path}.*") + + if len(paths) == 0: + continue + + pattern = re.compile(f"{prefix_path}\\.([0-9]+)") + dp_indices = set() + for p in paths: + m = pattern.match(p) + if m: + dp_indices.add(int(m.group(1))) + else: + raise ValueError(f"Cannot parse dp_rank from {p}") + + paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] + shards = [torch.load(p, weights_only=False) for p in paths] + + if state == "step": + assert all(v == shards[0] for v in shards), "All shards must have the same step value" + slice = shards[0] + else: + if slice_shape is None: + slice = torch.cat(shards, dim=0) + else: + slice = torch.cat(shards, dim=0).reshape(slice_shape) + + slices.append(slice) + return slices + + +def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape): + + name, shape = name_and_shape + slice_base_path = os.path.join(slice_dir, name) + param_base_path = os.path.join(dir, name) + + universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) + replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, []) + parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, []) + parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, []) + vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, []) + parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, []) + parameter_with_sub_params = universal_checkpoint_info.get(PARAMETER_WITH_SUB_PARAMS, []) + + unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism + + vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0) + unmatched_patterns.update(chain.from_iterable(SubparamShape(**s).patterns for s in parameter_with_sub_params)) + + def get_matched_pattern(patterns_, name_): + matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)] + assert len(matched_) <= 1, f'Got more than one matching patterns={matched_} for {name_}' + if matched_: + pattern_ = matched_[0] + unmatched_patterns.discard(pattern_) + return pattern_ + return None + + def get_matched_sub_params_pattern(name_): + for subparam_shape_dict in parameter_with_sub_params: + subparam_shape = SubparamShape(**subparam_shape_dict) + for pattern_ in subparam_shape.patterns: + if re.match(pattern_, name_): + unmatched_patterns.discard(pattern_) + return subparam_shape + return None + + matched_sub_params_shape = get_matched_sub_params_pattern(name) + + step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape) + if step_merged: + _save_checkpoint(os.path.join(param_base_path, "step.pt"), step_merged[0]) + + for state in ("fp32", "exp_avg", "exp_avg_sq"): + slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape) + final_path = os.path.join(param_base_path, f"{state}.pt") + + #print(f"Expected shape: {shape}") + #print(f"Fragment sizes:", list(frag.shape for frag in slices)) + ckpt_dict = {} + if get_matched_pattern(replicated_parameters, name): + if len(slices) > 1: + assert all([slices[0].equal(other_slice) for other_slice in slices[1:]]) + param = slices[0] + # print(f'replicate {name} using first slice') + elif get_matched_pattern(parameters_to_average, name): + param = sum(slices) / len(slices) + # print(f'merge {name} using average') + elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name): + cat_dim = 0 + chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices] + merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim) + merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim) + param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim) + ckpt_dict[CAT_DIM] = cat_dim + ckpt_dict[PARAM_N_SUB_PARAMS] = 2 + elif matched_sub_params_shape: + merged_chunks = [] + partition_dim = matched_sub_params_shape.partition_dim + + sub_dim_sizes = matched_sub_params_shape.shape[partition_dim] + if not isinstance(sub_dim_sizes, tuple): + sub_dim_sizes = (sub_dim_sizes, ) + + partition_shape = [sum(d) if isinstance(d, tuple) else d for d in matched_sub_params_shape.shape] + partition_shape = [d // tp_degree if i == partition_dim else d for i, d in enumerate(partition_shape)] + slices = [s.view(partition_shape) for s in slices] + + offset = 0 + for sub_dim_size in sub_dim_sizes: + part_sub_dim_size = sub_dim_size // tp_degree + merged_chunks.append( + torch.cat([s.narrow(partition_dim, offset, part_sub_dim_size) for s in slices], dim=partition_dim)) + offset += part_sub_dim_size + param = torch.cat(merged_chunks, dim=partition_dim) + ckpt_dict[SUB_PARAM_SHAPE] = matched_sub_params_shape + else: + cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0 + # print(f"merge {name} with CAT DIM: {cat_dim}") + param = torch.cat(slices, dim=cat_dim) + ckpt_dict[CAT_DIM] = cat_dim + + if get_matched_pattern(vocabulary_parameters, name): + #print(f"Before {param.shape=}") + # strip padding + original_vocab_size = universal_checkpoint_info['original_vocab_size'] + param = param[:original_vocab_size, :] + ckpt_dict[VOCAB_TENSOR] = True + #print(f"After {param.shape=}") + + #print(f"Final shape: {param.shape}") + ckpt_dict[PARAM] = param + _save_checkpoint(final_path, ckpt_dict) + + return unmatched_patterns + + +def merge_zero3_slices(dp_degree, dir, slice_dir, name): + slice_base_path = os.path.join(slice_dir, name) + param_base_path = os.path.join(dir, name) + + for state in ("fp32", "exp_avg", "exp_avg_sq"): + slices = _merge_zero_shards(slice_base_path, state, 1) + final_path = os.path.join(param_base_path, f"{state}.pt") + _save_checkpoint(final_path, slices[0]) + + +def _do_parallel_work(do_work, work_chunks, num_workers): + results = [] + if num_workers > 1: + with ProcessPoolExecutor(max_workers=num_workers) as executor: + future_list = [executor.submit(do_work, work) for work in work_chunks] + for f in tqdm.tqdm(future_list): + results.append(f.result()) + else: + # No parallel pass for unit testing + # We can't create child processes in tests + for work in tqdm.tqdm(work_chunks): + results.append(do_work(work)) + return results + + +def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): + _3d_range_list = list( + itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), + range(ds_checkpoint.dp_degree))) + #pprint(f'{_3d_range_list=}') + + do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) + _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) + + +def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir): + do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir) + _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers) + + +def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): + zero_output_folder = os.path.join(args.output_folder, "zero") + do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree) + unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers) + + # verify that all patterns were used + # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert + sets = [set(lst) for lst in unmatched_patterns_lists] + unmatched_patterns = list(set.intersection(*sets)) + if args.strict: + assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices' + elif unmatched_patterns: + print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices') + + +def _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir): + zero_output_folder = os.path.join(args.output_folder, "zero") + do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir) + _do_parallel_work(do_work, param_keys, args.num_merge_workers) + + +def _zero_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _parse_model_states_stage3(files): + return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES] + + +def _save_optimizer_state(args, ds_checkpoint): + sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS] + sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0) + + optim_sd = sd[OPTIMIZER_STATE_DICT] + output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states} + output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS] + zero_output_folder = os.path.join(args.output_folder, "zero") + output_file_path = os.path.join(zero_output_folder, "optimizer_state.pt") + _save_checkpoint(output_file_path, output_sd) + + +def _save_optimizer_state_stage3(args, optim_files): + sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False) + output_sd = sd[OPTIMIZER_STATE_DICT] + output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS] + zero_output_folder = os.path.join(args.output_folder, "zero") + output_file_path = os.path.join(zero_output_folder, "optimizer_state.pt") + _save_checkpoint(output_file_path, output_sd) + + +def _get_optim_files(checkpoint_dir): + return _get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def _get_model_state_files(checkpoint_dir): + return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def _get_checkpoint_files(checkpoint_dir, glob_pattern): + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def _get_zero_stage(optim_files): + state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False) + optimizer_state = state_dict[OPTIMIZER_STATE_DICT] + zero_stage = optimizer_state.get(ZERO_STAGE, 1) + return zero_stage + + +def _inject_missing_state(ds_checkpoint): + if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state: + sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False) + if UNIVERSAL_CHECKPOINT_INFO not in sd: + ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {} + ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][ + UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE + + +def _check_for_required_state(ds_checkpoint): + universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) + assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.' + + +def main(args): + print('Convert DeepSpeed Checkpoint to Universal Checkpoint') + + print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}') + + optim_files = _get_optim_files(args.input_folder) + zero_stage = _get_zero_stage(optim_files) + + if zero_stage <= 2: + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) + if args.inject_missing_state: + _inject_missing_state(ds_checkpoint) + else: + _check_for_required_state(ds_checkpoint) + + iteration = ds_checkpoint.get_iteration() + #_create_latest_file(args.output_folder, iteration) + checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, + ds_checkpoint.pp_degree) + + slice_shapes = [] + for mp_rank_file in ds_checkpoint.mp_rank_files: + mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False) + slice_shapes += mp_sd[PARAM_SHAPES] + + # fix back to normal flat dict, merge duplicates for tp>1 + slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items()) + temp_dir = os.path.join(args.output_folder, 'tmp') + + print('*** 1. Extracting ZeRO fragments') + _extract_zero_shard_files(args, ds_checkpoint, temp_dir) + + print('*** 2. Merging slices .....') + _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) + + print('*** 3. Saving common optimizer states') + _save_optimizer_state(args, ds_checkpoint) + + if not args.keep_temp_folder: + shutil.rmtree(temp_dir, ignore_errors=True) + + # Copy mp* files into output folder + for f in glob.glob(os.path.join(args.input_folder, 'mp*')): + shutil.copy2(f, args.output_folder) + + else: + model_files = _get_model_state_files(args.input_folder) + param_shapes = _parse_model_states_stage3(model_files) + dp_degree = len(model_files) + + temp_dir = os.path.join(args.output_folder, 'tmp') + + print('*** 1. Extracting ZeRO fragments') + _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir) + + print('*** 2. Merging slices .....') + param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()} + _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir) + + print('*** 3. Saving common optimizer states') + _save_optimizer_state_stage3(args, optim_files) + + if not args.keep_temp_folder: + shutil.rmtree(temp_dir, ignore_errors=True) + + # Copy *model_states files into output folder + for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')): + shutil.copy2(f, args.output_folder) + + # Update latest to output folder + checkpoint_root_folder, step_folder = os.path.split(args.output_folder) + latest_file = os.path.join(checkpoint_root_folder, 'latest_universal') + with open(latest_file, "w") as f: + f.write(step_folder) + + print('*** Done!') + + +if __name__ == "__main__": + args = parse_arguments() + main(args) diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py index b5bf41e2d160..02b3947624a1 100644 --- a/deepspeed/checkpoint/reshape_3d_utils.py +++ b/deepspeed/checkpoint/reshape_3d_utils.py @@ -81,7 +81,7 @@ def get_model_3d_descriptor(dir): else: tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) dp_degree = max(1, len(zero_file_list) // tp_degree) - pp_degree = 0 + pp_degree = 1 return model_3d_desc(pp_degree, tp_degree, dp_degree) diff --git a/deepspeed/checkpoint/reshape_meg_2d.py b/deepspeed/checkpoint/reshape_meg_2d.py index 3bff87f4344f..cef8fa5cc569 100644 --- a/deepspeed/checkpoint/reshape_meg_2d.py +++ b/deepspeed/checkpoint/reshape_meg_2d.py @@ -24,7 +24,7 @@ def add_data(self, pp_index, tp_index, data): assert type(data) is list key = self._make_key(pp_index, tp_index) - if not key in self.map.keys(): + if key not in self.map.keys(): self.map[key] = [] self.map[key] += data @@ -84,14 +84,14 @@ def reshape_meg_2d_parallel(old_pp_degree, old_tp_degree, new_pp_degree, new_tp_ old_2d_map = meg_2d_parallel_map(old_pp_degree, old_tp_degree) old_2d_map.simple_init() if verbose: - old_2d_map.print_data(f'original_2d_map:') + old_2d_map.print_data('original_2d_map:') if old_tp_degree != new_tp_degree: new_tp_map = _reshape_tp_dimension(old_2d_map, new_tp_degree) else: new_tp_map = old_2d_map if verbose: - new_tp_map.print_data(f'after_tp_reshape:') + new_tp_map.print_data('after_tp_reshape:') if old_pp_degree != new_pp_degree: final_map = _reshape_pp_dimension(new_tp_map, new_pp_degree) @@ -99,7 +99,7 @@ def reshape_meg_2d_parallel(old_pp_degree, old_tp_degree, new_pp_degree, new_tp_ final_map = new_tp_map if verbose: - final_map.print_data(f'final_2d_map:') + final_map.print_data('final_2d_map:') return final_map @@ -159,7 +159,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None): ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_dp_group_ranks] all_pp_group_ranks.append(list(ranks)) - print(f"PP", all_pp_group_ranks) + print("PP", all_pp_group_ranks) # Build the tensor model-parallel groups. all_tp_group_ranks = [] @@ -167,7 +167,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) all_tp_group_ranks.append(list(ranks)) - print(f"TP", all_tp_group_ranks) + print("TP", all_tp_group_ranks) return all_tp_group_ranks, all_pp_group_ranks, all_dp_group_ranks diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py index 15b6ce28b2fd..137607721ebf 100644 --- a/deepspeed/checkpoint/reshape_utils.py +++ b/deepspeed/checkpoint/reshape_utils.py @@ -4,9 +4,10 @@ # DeepSpeed Team import os +import re import torch from collections import OrderedDict -from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX) +from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX, MODEL_FILE_PREFIX) def basic_folder_validation(dir): @@ -38,12 +39,28 @@ def get_files(dir): return file_list +def sort_zero_files(files, prefix): + pattern = f"{prefix}([0-9]+)_{MODEL_FILE_PREFIX}([0-9]+)" + rank_pairs = [] + for f in files: + m = re.search(pattern, f) + if m: + dp_rank = int(m.group(1)) + mp_rank = int(m.group(2)) + rank_pairs.append((dp_rank, mp_rank, f)) + else: + raise ValueError(f"Cannot parse dp_rank and mp_rank from {f}") + + sorted_files = sorted(rank_pairs, key=lambda x: (x[0], x[1])) + return [f for _, _, f in sorted_files] + + def get_zero_files(dir): file_list = get_files(dir) for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]: zero_files = get_files_with_prefix(file_list, prefix) if len(zero_files) > 0: - return zero_files + return sort_zero_files(zero_files, prefix) return [] diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py index cd4d7d51a4c2..7a9c2bcb068b 100644 --- a/deepspeed/checkpoint/universal_checkpoint.py +++ b/deepspeed/checkpoint/universal_checkpoint.py @@ -4,23 +4,119 @@ # DeepSpeed Team import os +import re import torch import types -from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM) +from typing import List, Tuple, Union +from dataclasses import dataclass +from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE, + DS_AUTOTP_UC_META) + + +@dataclass +class SubparamShape: + patterns: List[str] + shape: Tuple[Union[Tuple[int], int]] + partition_dim: int + + +def _get_param_uc_restore_meta(param): + """Return the restore-facing view of AutoTP UC metadata for a parameter. + + AutoTP parameter metadata intentionally serves two separate consumers: + - restore-time fields at the top level, consumed here by UC loading + - conversion-time fields under `conversion`, consumed by + `collect_autotp_universal_checkpoint_info()` in `layers.py` + """ + return getattr(param, DS_AUTOTP_UC_META, None) + + +def _resolve_autotp_partition(current_param, ckpt_dict, full_hp_param, tp_rank, tp_world_size): + meta = _get_param_uc_restore_meta(current_param) + if not meta: + return None + + partition_dim = meta.get('partition_dim') + logical_shape = meta.get('logical_shape') + sub_param_shape = meta.get('sub_param_shape') + sub_param_sizes = meta.get('sub_param_sizes') + replicated = meta.get('replicated', False) + + if replicated: + assert partition_dim is None + slice_tensor = full_hp_param + return slice_tensor.flatten() + + if partition_dim is None: + return None + + if logical_shape is None: + return None + + full_view = full_hp_param.view(logical_shape) + + if sub_param_shape is not None: + if hasattr(sub_param_shape, "shape") and hasattr(sub_param_shape, "partition_dim"): + shape_spec = sub_param_shape.shape + partition_dim = sub_param_shape.partition_dim + else: + shape_spec = sub_param_shape + + sub_dim_sizes = shape_spec[partition_dim] + if not isinstance(sub_dim_sizes, tuple): + sub_dim_sizes = (sub_dim_sizes, ) + + offset = 0 + merged_chunks = [] + for sub_dim_size in sub_dim_sizes: + sub_slice = full_view.narrow(partition_dim, offset, sub_dim_size) \ + .chunk(tp_world_size, dim=partition_dim)[tp_rank] + merged_chunks.append(sub_slice) + offset += sub_dim_size + + slice_tensor = torch.cat(merged_chunks, dim=partition_dim) + return slice_tensor.flatten() + + if sub_param_sizes is not None: + if not isinstance(sub_param_sizes, (tuple, list)): + sub_param_sizes = (sub_param_sizes, ) + + offset = 0 + merged_chunks = [] + for sub_dim_size in sub_param_sizes: + sub_slice = full_view.narrow(partition_dim, offset, sub_dim_size) \ + .chunk(tp_world_size, dim=partition_dim)[tp_rank] + merged_chunks.append(sub_slice) + offset += sub_dim_size + + slice_tensor = torch.cat(merged_chunks, dim=partition_dim) + return slice_tensor.flatten() + + slice_tensor = full_view.chunk(tp_world_size, dim=partition_dim)[tp_rank] + return slice_tensor.flatten() def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): hp_mapping = self._hp_mapping - optim_state_keys = hp_mapping.get_optim_state_keys() - hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys - checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} + hp_mapping.optim_fragment = {} - for file in checkpoint_files.values(): - assert os.path.isfile(file), f'{file} is not a valid file' + hp_keys = [] + for file in os.listdir(folder): + # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt" + pattern = r'(.+).pt' + match = re.search(pattern, file) + if match: + hp_keys.append(match.group(1)) + step = None for key in hp_keys: - ckpt_file = checkpoint_files[key] - ckpt_dict = torch.load(ckpt_file) + ckpt_file = os.path.join(folder, f"{key}.pt") + ckpt_dict = torch.load(ckpt_file, weights_only=False) + + if key == "step": + step = ckpt_dict + continue + full_hp_param = ckpt_dict[PARAM] # need to deal with slices that were averaged. @@ -43,51 +139,83 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # the converter to universal currently strips the original padding completely so the saved # weight is padding-free and we just need to add new padding depending on the target TP # degree - vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None) - if vocab_divisibility_padding_tensor is not None: + is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False) + if is_vocab_tensor: # In the absence of data passed from the user wrt new padded vocab specific to tp degree # we can again derive that data by reverse engineering the target shapes like so: padded_target_vocab_size = self.shape[0] * tp_world_size + assert padded_target_vocab_size >= full_hp_param.shape[0], \ + f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}' if padded_target_vocab_size > full_hp_param.shape[0]: - # Need to expand padding_size = padded_target_vocab_size - full_hp_param.shape[0] - # Implement the following concat in efficient way using pad - #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0) - full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor - else: - # Need to shrink or keep the same - full_hp_param = full_hp_param[:padded_target_vocab_size, :] - - full_param_numel = full_hp_param.numel() - tp_slice_numel = self.numel() - # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: - # print_rank_0(f'{full_hp_param[:10]=}', force=True) + autotp_tp_hp_slice = _resolve_autotp_partition(self, ckpt_dict, full_hp_param, tp_rank, tp_world_size) + if autotp_tp_hp_slice is not None: + tp_hp_slice = autotp_tp_hp_slice + else: + full_param_numel = full_hp_param.numel() + tp_slice_numel = self.numel() + assert full_param_numel == tp_world_size * tp_slice_numel, \ + f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' + + # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") + # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") + + sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None) + # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse + # special case is when a single parameter is effectively a container for multiple sub parameters + # (more details at PARAM_N_SUB_PARAMS definition) + chunk_dim = ckpt_dict.get(CAT_DIM, 0) + n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1) + if sub_param_shape: + partition_dim = sub_param_shape.partition_dim + sub_dim_sizes = sub_param_shape.shape[partition_dim] + if not isinstance(sub_dim_sizes, tuple): + sub_dim_sizes = (sub_dim_sizes, ) + + partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape] + full_hp_param = full_hp_param.view(partition_shape) + + offset = 0 + merged_chunks = [] + for sub_dim_size in sub_dim_sizes: + sub_params_tp_slice = full_hp_param.narrow(partition_dim, + offset, sub_dim_size).chunk(tp_world_size, + dim=partition_dim)[tp_rank] + merged_chunks.append(sub_params_tp_slice) + offset += sub_dim_size + tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim) + + elif n_sub_params > 1: + sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim) + sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params] + tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim) + else: + # this performs the opposite of cat when merging TP slices + tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] - assert full_param_numel == tp_world_size * tp_slice_numel, \ - f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' - dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key) - - # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") - # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") - - # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse - chunk_dim = ckpt_dict.get(CAT_DIM, 0) - - # this performs the opposite of cat when merging TP slices - tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] - tp_hp_slice = tp_hp_slice.flatten() + tp_hp_slice = tp_hp_slice.flatten() lp_frag_address = hp_mapping.lp_fragment_address tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel) - assert dst_tensor.numel() == lp_frag_address.numel, \ - f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}' # print(f"{key} SHAPE: {tp_hp_slice.shape=}") # print(f"{key} SHAPE: {dst_tensor.shape=}") # print(f"{key} SHAPE: {tp_hp_fragment.shape=}") - dst_tensor.data.copy_(tp_hp_fragment.data) + + if key == FP32_WEIGHT_KEY: + dst_tensor = hp_mapping.get_hp_fragment() + assert dst_tensor.numel() == lp_frag_address.numel, \ + f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}' + dst_tensor.data.copy_(tp_hp_fragment.data) + else: + assert tp_hp_fragment.numel() == lp_frag_address.numel, \ + f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}' + + hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach() + + return step def enable_universal_checkpoint(param_list): diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py index 3707aa0eb419..5964da00728e 100644 --- a/deepspeed/checkpoint/utils.py +++ b/deepspeed/checkpoint/utils.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import torch from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX) @@ -29,3 +30,38 @@ def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank): ckpt_file = f'{layer_id}-model_{tp_rank:02d}{MODEL_FILE_SUFFIX}' ckpt_path = os.path.join(base_folder, ckpt_file) return ckpt_path + + +# We pass cloned tensors to torch.save() to avoid checkpoint bloat that occurs when torch.save() +# saves the underlying storage rather than the slice of the storage corresponding to individual tensors. +# This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers. +# Tensor cloning helps to avoid this problem because the storage of cloned tensors are closer to the true size. +# It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat. +# See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing +def clone_tensors_for_torch_save(item, device=torch.device('cpu')): + """ + Returns a copy of ``item`` with all enclosed tensors replaced by clones on a specified device. + Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts. + + Parameters: + - ``item``: tensor to clone or (possibly nested) container of tensors to clone. + - ``device``: target device (defaults to 'cpu') + + Returns: + - copy of ``item`` with cloned tensors on target device + """ + if torch.is_tensor(item): + if type(device) is str: + device = torch.device(device) + if device == item.device: + return item.detach().clone() + else: + return item.detach().to(device) + elif isinstance(item, list): + return [clone_tensors_for_torch_save(v, device) for v in item] + elif isinstance(item, tuple): + return tuple([clone_tensors_for_torch_save(v, device) for v in item]) + elif isinstance(item, dict): + return type(item)({k: clone_tensors_for_torch_save(v, device) for k, v in item.items()}) + else: + return item diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py index c65745d3dd0c..c85f0241005d 100644 --- a/deepspeed/checkpoint/zero_checkpoint.py +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index) merged_sd = None for state_file in state_file_list: - sd = torch.load(state_file, map_location=torch.device('cpu')) + sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False) for key in keys_to_ignore: sd.pop(key, None) @@ -105,9 +105,11 @@ def _strip_tensor_paddings(self, sd): if group_paddings[key] == 0: continue for state_name, state_value in group_state.items(): - if torch.is_tensor(state_value): + if state_name != "step" and torch.is_tensor(state_value): raw_length = state_value.numel() - group_paddings[key] group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone() + else: + group_state[state_name] = state_value def _clear_group_paddings(self, sd): group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS) diff --git a/deepspeed/comm/__init__.py b/deepspeed/comm/__init__.py index 82c86113fd30..2c26ef491308 100644 --- a/deepspeed/comm/__init__.py +++ b/deepspeed/comm/__init__.py @@ -3,48 +3,5 @@ # DeepSpeed Team -import torch from .utils import * -from deepspeed import utils - -supported_torch_version = False - -# See more details at: https://github.com/pytorch/pytorch/pull/48767 -# The PG API in torch versions lesser than 1.8 are different so it is -# non-trivial to support both in the same API. We will just use the -# DS comm. backend in deepspeed/comm/comm.py if torch version if 1.8+. - -if older_torch(): - # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm - # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them. - supported_torch_version = False - from torch.distributed import * - - def get_world_group(): - return group.WORLD - - def get_global_rank(group, group_rank): - if hasattr(torch.distributed.distributed_c10d, "get_global_rank"): - from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank - else: - from torch.distributed.distributed_c10d import _get_global_rank - return _get_global_rank(group, group_rank) - - def allgather_fn(output_tensor, input_tensor, group=None, async_op=False): - from torch.distributed import all_gather, get_world_size - from torch import chunk - output_tensors = list(chunk(output_tensor, get_world_size(group))) - return all_gather(output_tensors, input_tensor, group=group, async_op=async_op) - - def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False): - from torch.distributed import reduce_scatter, get_world_size - from torch import chunk - input_tensor_lst = list(chunk(input_tensor, get_world_size(group))) - return reduce_scatter(output_tensor, input_tensor_lst, group=group) - - def configure(deepspeed_config=None, enabled=None, prof_all=None, prof_ops=None, verbose=None): - utils.logger.warn("Communication logging is not supported in torch versions older than 1.8") - -else: - supported_torch_version = True - from .comm import * +from .comm import * diff --git a/deepspeed/comm/backend.py b/deepspeed/comm/backend.py index 4e02ad89719c..15dc96c40521 100644 --- a/deepspeed/comm/backend.py +++ b/deepspeed/comm/backend.py @@ -28,8 +28,8 @@ def __init__(self, name='backend', rank=0, size=1): self.name = name # The world size and rank of the world process group self.world_group = None - self.world_size = rank - self.world_rank = size + self.world_size = size + self.world_rank = rank # Single process group (pg) implementation for now but keep a list for future self.process_groups = [] self.initialized = False diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py new file mode 100644 index 000000000000..e95e40a03087 --- /dev/null +++ b/deepspeed/comm/ccl.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +''' +Copyright 2021 The Microsoft DeepSpeed Team +''' + +import torch +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import NotImplementedBuilder +from .reduce_op import ReduceOp +from .torch import TorchBackend + + +def build_ccl_op(): + builder = get_accelerator().create_op_builder("CCLCommBuilder") + if builder is None or isinstance(builder, NotImplementedBuilder): + return None + ccl_cpp_module = builder.load() + print(f'DeepSpeed {builder.absolute_name()} built successfully') + return ccl_cpp_module + + +class CCLHandler(): + + def __init__(self, ccl_comm_op=None): + self.ccl_comm_op = ccl_comm_op + + def wait(self): + # backend covered it + pass + + +class CCLBackend(TorchBackend): + + def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None): + self.ccl_comm_op = build_ccl_op() + if self.ccl_comm_op is None: + # set CCLBackend to uninitialized state if CCLCommBuilder cannot be loaded + self.initialized = False + return + super(CCLBackend, self).__init__(backend='ccl', + name='torch', + rank=rank, + world_size=world_size, + timeout=timeout, + init_method=init_method) + self.name = 'ccl' + size = self.get_world_size() + rank = self.get_rank() + main_kvs = self.ccl_comm_op.get_kvs_addr(rank) + main_kvs = torch.tensor(main_kvs).to(torch.uint8).to(get_accelerator().current_device_name()) + super(CCLBackend, self).broadcast(main_kvs, 0) + self.ccl_comm_op.initialize(size, rank, main_kvs) + self.initialized = True + self.groups = [tuple(range(self.get_world_size()))] + self.available_coll = self.ccl_comm_op.get_available_coll() + + def is_initialized(self): + return self.initialized + + def run_collective(self, name, **kwargs): + if name in self.available_coll: + if 'group' in kwargs: + kwargs['group'] = self.get_all_ranks_from_group(kwargs['group']) + if 'dst' in kwargs: + kwargs['dst'] = kwargs['group'].index(kwargs['dst']) + if 'src' in kwargs: + kwargs['src'] = kwargs['group'].index(kwargs['src']) + func = "self.ccl_comm_op." + name + eval(func)(*(kwargs.values())) + return CCLHandler(self.ccl_comm_op) + else: + func = "super(CCLBackend, self)." + name + eval(func)(*(kwargs.values())) + return CCLHandler(self.ccl_comm_op) + + def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False): + name = "all_reduce" + if name in self.available_coll: + group = self.get_all_ranks_from_group(group) + return self.ccl_comm_op.all_reduce(tensor, op, group, async_op) + else: + return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op) + + def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None): + name = "inference_all_reduce" + if name in self.available_coll: + return self.ccl_comm_op.inference_all_reduce(tensor, op) + else: + return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=False) + + def broadcast(self, tensor, src, group=None, async_op=False): + return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op) + + def all_gather(self, tensor_list, tensor, group=None, async_op=False): + return self.run_collective(name="all_gather", + tensor_list=tensor_list, + tensor=tensor, + group=group, + async_op=async_op) + + def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False): + return self.run_collective(name="reduce_scatter_tensor", + output_tensor=output_tensor, + input_tensor=input_tensor, + op=op, + group=group) + + def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False): + return self.run_collective(name="all_gather_into_tensor", + output_tensor=output_tensor, + input_tensor=input_tensor, + group=group) + + def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False): + return self.run_collective(name="all_to_all_single", + output=output, + input=input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group) + + def send(self, tensor, dst, group=None, tag=0): + return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, tag=tag) + + def recv(self, tensor, src, group=None, tag=0): + return self.run_collective(name="recv", tensor=tensor, src=src, group=group, tag=tag) + + def gather(self, tensor, gather_list, dst, group=None, async_op=False): + return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group) + + def scatter(self, tensor, gather_list, dst, group=None, async_op=False): + return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group) + + def barrier(self, group=None, async_op=False): + return self.run_collective(name="barrier", group=group, async_op=async_op) + + def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False): + return self.run_collective(name="monitored_barrier", group=group) + + def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False): + return self.run_collective(name="reduce_scatter", + output=output, + input_list=input_list, + op=op, + group=group, + async_op=async_op) + + def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): + return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op) + + def new_group(self, ranks): + return super(CCLBackend, self).new_group(ranks) + + def _new_group(self, ranks, group): + size = len(ranks) + rank = self.get_rank() + sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0]) + sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to(get_accelerator().current_device_name()) + super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group) + self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks) + self.groups.append(tuple(ranks)) + + def get_all_ranks_from_group(self, group): + if group is None: + return list(range(self.get_world_size())) + rank = 0 + results = [] + try: + while True: + results.append(super(CCLBackend, self).get_global_rank(group, rank)) + rank += 1 + except (ValueError, RuntimeError): + pass + if tuple(results) not in self.groups: + self._new_group(results, group) + return results diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py old mode 100644 new mode 100755 index 61b001ad0662..f9e94f0175e2 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -5,7 +5,7 @@ """ DeepSpeed Communication Package: deepspeed.comm deepspeed.comm - -- import and use deepspeeed.ops.comm + -- import and use deepspeed.ops.comm -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version -- use custom collectives -- can either use torch.dist or ds.ops.comm? @@ -21,26 +21,15 @@ -- deepspeed groups API should be brought into ds.comm """ -from enum import Enum import torch +from torch.distributed import GradBucket # noqa: F401 import os from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout from .constants import * +from .reduce_op import * from deepspeed.accelerator import get_accelerator - - -class ReduceOp(Enum): - SUM = 0 - PRODUCT = 1 - MIN = 2 - MAX = 3 - BAND = 4 - BOR = 5 - BXOR = 6 - AVG = 7 - UNUSED = 8 - +from deepspeed.comm.ccl import CCLBackend from deepspeed.utils.comms_logging import CommsLogger from deepspeed.utils import timer, get_caller_func @@ -50,7 +39,6 @@ class ReduceOp(Enum): from datetime import timedelta # Current deepspeed.comm backend (cdb) global object for simple access by client code -use_ds_backend = False cdb = None # Create global timer for ops @@ -62,11 +50,21 @@ class ReduceOp(Enum): # Maintain objects of all initialized ds backends and assign them using the API functions in this file nccl_backend = None mpi_backend = None +ccl_backend = None +hccl_backend = None # This should be set here so all rank/size information from the launcher can be propagated from deepspeed.comm.utils import * +class ProcessGroup(): + + def __init__(self, comm_id, ranks=[]): + self.ranks = ranks + self.comm_id = comm_id + self.size = len(ranks) + + def _configure_using_config_file(config): if config.comms_logger_enabled: comms_logger.configure(config) @@ -143,20 +141,29 @@ def log_wrapper(*args, **kwargs): # UNUSED: Future helper function to initialize DS backends -def init_deepspeed_backend(ds_backend): +def init_deepspeed_backend(ds_backend, timeout, init_method): global cdb global nccl_backend global mpi_backend - global use_ds_backend + global ccl_backend + global hccl_backend + + rank = int(os.getenv('RANK', '-1')) + size = int(os.getenv('WORLD_SIZE', '-1')) if ds_backend == NCCL_BACKEND: - utils.logger.warn("NCCL backend in DeepSpeed not yet implemented") + utils.logger.debug("NCCL backend in DeepSpeed not yet implemented") elif ds_backend == MPI_BACKEND: - utils.logger.warn("MPI backend in DeepSpeed not yet implemented") + utils.logger.debug("MPI backend in DeepSpeed not yet implemented") elif ds_backend == GLOO_BACKEND: - utils.logger.warn("Gloo backend in DeepSpeed not yet implemented") + utils.logger.debug("Gloo backend in DeepSpeed not yet implemented") + elif ds_backend == CCL_BACKEND: + ccl_backend = CCLBackend(rank=rank, world_size=size, timeout=timeout, init_method=init_method) + utils.logger.info(f"Initialize {ds_backend} backend") + elif ds_backend == HCCL_BACKEND: + utils.logger.debug("HCCL backend in DeepSpeed not yet implemented") else: - utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend") + utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend") def is_initialized(): @@ -189,26 +196,27 @@ def is_available() -> bool: return True -def set_backend(backend_name): - if not use_ds_backend: - utils.logger.error( - "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality" - ) - raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.') - +def set_backend(): global cdb global nccl_backend global mpi_backend + global ccl_backend + global hccl_backend - try: - if backend_name == NCCL_BACKEND: - if nccl_backend is not None and nccl_backend.is_initialized(): - cdb = nccl_backend - elif backend_name == MPI_BACKEND: - if mpi_backend is not None and mpi_backend.is_initialized(): - cdb = mpi_backend - except Exception as inst: - print(inst) + backend_name = get_accelerator().communication_backend_name() + + if backend_name == NCCL_BACKEND: + if nccl_backend is not None and nccl_backend.is_initialized(): + cdb = nccl_backend + elif backend_name == MPI_BACKEND: + if mpi_backend is not None and mpi_backend.is_initialized(): + cdb = mpi_backend + elif backend_name == CCL_BACKEND: + if ccl_backend is not None and ccl_backend.is_initialized(): + cdb = ccl_backend + elif backend_name == HCCL_BACKEND: + if hccl_backend is not None and hccl_backend.is_initialized(): + cdb = hccl_backend @timed_op @@ -217,6 +225,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) +@timed_op +def broadcast_object_list(object_list, src, group=None, device=None): + global cdb + return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device) + + @timed_op def all_gather(tensor_list, tensor, @@ -229,6 +243,12 @@ def all_gather(tensor_list, return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op) +@timed_op +def all_gather_object(object_list, obj, group=None, prof=False, log_name='all_gather_object', debug=get_caller_func()): + global cdb + return cdb.all_gather_object(object_list=object_list, obj=obj, group=group) + + def has_reduce_scatter_tensor(): global cdb assert cdb is not None and cdb.is_initialized( @@ -257,7 +277,7 @@ def reduce_scatter_fn(output_tensor, else: if get_rank() == 0: utils.logger.warning_once("unable to find torch.distributed.reduce_scatter_tensor. will fall back to " - "torch.distributed.all_gather which will result in suboptimal performance. " + "torch.distributed.reduce_scatter which will result in suboptimal performance. " "please consider upgrading your pytorch installation.") input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group))) return reduce_scatter(output_tensor, @@ -339,6 +359,12 @@ def all_to_all_single(output, async_op=async_op) +@timed_op +def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): + global cdb + return cdb.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op) + + @timed_op def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()): global cdb @@ -392,7 +418,7 @@ def scatter(tensor, @timed_op def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()): global cdb - return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids) + return cdb.barrier(group=group, async_op=async_op) @timed_op @@ -403,15 +429,162 @@ def monitored_barrier(group=None, log_name='monitored_barrier', debug=get_caller_func()): global cdb - return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks) + return cdb.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks) + + +def log_summary(show_straggler=False, return_dict=False): + """ + Print and/or return communication operation statistics across all ranks. + This function synchronizes all ranks and logs communication statistics. + Only rank 0 prints to console by default, but all ranks can return the dictionary. -def log_summary(): + Args: + show_straggler (bool, optional): Whether to include straggler effect analysis. + When True, computes the time difference between the fastest and slowest ranks + for each communication operation. Defaults to False. + return_dict (bool, optional): Whether to return statistics as a dictionary. + When True, returns a comprehensive dictionary with communication metrics. + Defaults to False. + + Returns: + dict or None: If return_dict=True, returns communication statistics dictionary. + The structure is identical to CommsLogger.log_all() return value. + Returns None if return_dict=False. + + Dictionary structure (when return_dict=True): + { + "summary": { + "operation_name": { + message_size_bytes: { + "count": int, + "total_latency_ms": float, + "avg_latency_ms": float, + "tput_avg_gbps": float, + "busbw_avg_gbps": float, + "msg_size_bytes": int, + "msg_size_str": str + } + } + }, + "straggler_analysis": {...} if show_straggler else None, + "metadata": { + "world_size": int, + "rank": int, + "timestamp": str + } + } + + Note: + - This function includes barriers for synchronization across all ranks + - Straggler analysis requires additional all_reduce operations + - All ranks return the same data when return_dict=True + - Only rank 0 prints to console when print_log=True (default behavior) + + Example: + # Print summary only (backward compatible) + deepspeed.comm.log_summary() + + # Get dictionary and print summary + stats = deepspeed.comm.log_summary(return_dict=True) + + # Include straggler analysis + stats = deepspeed.comm.log_summary(show_straggler=True, return_dict=True) + + # Access specific operation data + if stats and "all_reduce" in stats["summary"]: + all_reduce_stats = stats["summary"]["all_reduce"] + """ global cdb barrier(log_name='log_summary_barrier') + + result = None if cdb.get_rank() == 0: - comms_logger.log_all() + result = comms_logger.log_all(print_log=True, show_straggler=show_straggler, return_dict=return_dict) + else: + # Non-rank-0 processes: don't print but may still return dict if requested + result = comms_logger.log_all(print_log=False, show_straggler=show_straggler, return_dict=return_dict) + barrier(log_name='log_summary_barrier') + return result + + +def reset_log(): + """ + Clear all accumulated communication logging data. + + This function clears the communication logger's internal data dictionary, + allowing for epoch-by-epoch or interval-based logging. After calling this + function, subsequent log_summary() calls will only show statistics for + communication operations that occur after the reset. + + Note: + - This affects the global communication logger + - All accumulated statistics will be lost + - This function is useful for getting per-epoch or per-interval statistics + + Example: + # Training loop with per-epoch communication logging + for epoch in range(num_epochs): + # Reset logger at start of epoch + deepspeed.comm.reset_log() + + # Train for one epoch + train_one_epoch(model, dataloader) + + # Get communication stats for this epoch only + epoch_stats = deepspeed.comm.log_summary(return_dict=True) + print(f"Epoch {epoch} communication stats: {epoch_stats}") + """ + global comms_logger + comms_logger.reset_data() + + +def has_comm_data(): + """ + Check if any communication data has been logged. + + Returns: + bool: True if communication operations have been logged, False otherwise + + Example: + if deepspeed.comm.has_comm_data(): + stats = deepspeed.comm.log_summary(return_dict=True) + else: + print("No communication operations logged yet") + """ + global comms_logger + return comms_logger.has_data() + + +def get_comm_operation_count(): + """ + Get the total number of communication operations logged. + + Returns: + int: Total count of all communication operations across all types + + Example: + total_ops = deepspeed.comm.get_comm_operation_count() + print(f"Total communication operations this epoch: {total_ops}") + """ + global comms_logger + return comms_logger.get_total_operations() + + +def get_logged_comm_ops(): + """ + Get list of communication operation types that have been logged. + + Returns: + list: List of operation names that have been logged (e.g., ['all_reduce', 'broadcast']) + + Example: + ops = deepspeed.comm.get_logged_comm_ops() + print(f"Communication operations used: {ops}") + """ + global comms_logger + return comms_logger.get_operation_names() @timed_op @@ -440,6 +613,30 @@ def reduce_scatter(output, return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op) +def has_all_reduce_coalesced(): + """""" + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined' + return cdb.has_all_reduce_coalesced + + +def has_coalescing_manager(): + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined' + return cdb.has_coalescing_manager + + +def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False): + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op) + + @timed_op def all_reduce(tensor, op=ReduceOp.SUM, @@ -457,6 +654,30 @@ def all_reduce(tensor, return cdb.all_reduce(tensor, op, group, async_op) +@timed_op +def inference_all_reduce(tensor, + op=ReduceOp.SUM, + group=None, + async_op=False, + prof=False, + log_name='all_reduce', + debug=get_caller_func()): + global cdb + return cdb.inference_all_reduce(tensor, op, group) + + +@timed_op +def all_reduce_coalesced(tensors, + op=ReduceOp.SUM, + group=None, + async_op=False, + prof=False, + log_name='all_reduce', + debug=get_caller_func()): + global cdb + return cdb.all_reduce_coalesced(tensors, op, group, async_op) + + def get_world_group(): global cdb assert cdb is not None and cdb.is_initialized( @@ -522,6 +743,47 @@ def get_global_rank(group=None, group_rank=0): return cdb.get_global_rank(group, group_rank) +def get_all_ranks_from_group(group=None): + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + rank = 0 + group_ranks = [] + try: + while True: + group_ranks.append(cdb.get_global_rank(group, rank)) + rank += 1 + except (RuntimeError, ValueError): + pass + return group_ranks + + +def initialize_mesh_device(mesh_shape, mesh_dim_names): + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + mesh_device = None + if hasattr(cdb, 'init_device_mesh'): + utils.logger.info(f"Initializing mesh device with backend {cdb.name} \ + with shape {mesh_shape} and dim names {mesh_dim_names}") + mesh_device = cdb.init_device_mesh(mesh_shape, mesh_dim_names) + else: + if get_rank() == 0: + utils.logger.warning_once(f"Backend {cdb.name} does not support mesh device initialization") + return mesh_device + + +def enable_symm_mem_for_group(group_name: str): + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + + if hasattr(cdb, 'enable_symm_mem_for_group'): + cdb.enable_symm_mem_for_group(group_name) + else: + raise RuntimeError(f"Backend {cdb.name} does not support symmetric memory initialization") + + # Main DeepSpeed Comms. public API. def init_distributed(dist_backend=None, auto_mpi_discovery=True, @@ -536,14 +798,14 @@ def init_distributed(dist_backend=None, ''' Initialize dist backend, potentially performing MPI discovery if needed Arguments: - dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo + dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo, hccl auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI distributed_port: Optional (int). torch distributed backend port verbose: Optional (bool). verbose logging - timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes. - init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. + timeout: Optional (timedelta). Timeout for operations executed against the process group. The default value of 30 minutes can be overridden by the environment variable `DEEPSPEED_TIMEOUT`. + init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is "env://" if no init_method or store is specified. config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling) - rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization) + rank: Optional (int). The current manually specified rank. Some init_method like "tcp://" need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization) world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization. ''' global cdb @@ -553,6 +815,10 @@ def init_distributed(dist_backend=None, if dist_init_required is None: dist_init_required = cdb is None or not cdb.is_initialized() + if cdb is None: + init_deepspeed_backend(get_accelerator().communication_backend_name(), timeout, init_method) + set_backend() + utils.logger.info(f'cdb={cdb}') if cdb is None and torch.distributed.is_initialized(): # The user initialized torch.dist themselves, create cdb and short-circuit cdb = TorchBackend(dist_backend, timeout, init_method) @@ -580,7 +846,7 @@ def init_distributed(dist_backend=None, utils.logger.info('Distributed backend already initialized') else: assert isinstance(timeout, timedelta) - if dist_backend == None: + if dist_backend is None: dist_backend = get_accelerator().communication_backend_name() if int(os.getenv('RANK', '0')) == 0: utils.logger.info('Initializing TorchBackend in DeepSpeed with backend {}'.format(dist_backend)) @@ -600,9 +866,14 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True) master_addr = None if rank == 0: - hostname_cmd = ["hostname -I"] - result = subprocess.check_output(hostname_cmd, shell=True) - master_addr = result.decode('utf-8').split()[0] + import shlex + try: + hostname_cmd = shlex.split("hostname -I") + result = subprocess.check_output(hostname_cmd) + master_addr = result.decode('utf-8').split()[0] + except subprocess.CalledProcessError: # hostname -I not available (e.g. on macOS) + import socket + master_addr = socket.gethostbyname(socket.gethostname()) master_addr = comm.bcast(master_addr, root=0) # Determine local rank by assuming hostnames are unique diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py index 138badebe5a9..57501c9dd237 100644 --- a/deepspeed/comm/config.py +++ b/deepspeed/comm/config.py @@ -3,20 +3,12 @@ # DeepSpeed Team -from pydantic import BaseModel -from .constants import * - +from deepspeed.runtime.config_utils import DeepSpeedConfigModel -class CommsConfig(BaseModel): - - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - extra = 'forbid' +from .constants import * -class CommsLoggerConfig(CommsConfig): +class CommsLoggerConfig(DeepSpeedConfigModel): enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py index ab309247befe..50d234c93fa0 100644 --- a/deepspeed/comm/constants.py +++ b/deepspeed/comm/constants.py @@ -4,9 +4,11 @@ # DeepSpeed Team NCCL_BACKEND = 'nccl' +CCL_BACKEND = 'ccl' MPI_BACKEND = 'mpi' GLOO_BACKEND = 'gloo' SCCL_BACKEND = 'sccl' +HCCL_BACKEND = 'hccl' DEFAULT_AML_MASTER_PORT = "54965" DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo" diff --git a/deepspeed/comm/mori.py b/deepspeed/comm/mori.py new file mode 100644 index 000000000000..508c94090173 --- /dev/null +++ b/deepspeed/comm/mori.py @@ -0,0 +1,228 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""mori SDMA backend, plugged into ``TorchBackend.all_gather_into_tensor``. + +When the user opts in, ``deepspeed.comm`` routes ``all_gather_into_tensor`` +on the WORLD process group through ``mori_cpp.AllGatherIntoTensor`` +(intra-node SDMA copy on AMD MI300). Any failure (mori missing, +non-AMD/ROCm runtime, shmem init error, oversized call, non-WORLD group) +yields ``None`` and the caller falls back to the underlying RCCL/NCCL +allgather. + +User-visible controls (env vars, no ``ds_config`` field): + +* ``DS_SDMA_ALLGATHER=1`` opt in to the SDMA path. Required: + even when mori is installed, the + SDMA fast-path stays off unless + the user sets this explicitly. + When set, ``MORI_ENABLE_SDMA=1`` is + auto-exported on the user's behalf + so mori allocates uncached transit + buffers. +* ``DS_SDMA_ALLGATHER_MAX_NUMEL=N`` override the transit buffer size in + elements (default 64M = 256 MiB + per-rank input, ~2 GiB output on 8 + ranks) +""" + +import os +from typing import Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import logger + +_handle = None +_dtype_map = None +_max_numel = 0 +_init_attempted = False +_call_failed_warned = False + + +class _SdmaWork: + """Duck-type compatible with ``torch.distributed.Work``. + + ``wait()`` issues a stream-level dependency only and does NOT block the + CPU, mirroring RCCL ``Work.wait()`` semantics. ZeRO-3's prefetch + pipeline relies on the CPU staying free so the next bucket can be + queued ahead of time while bucket N is in flight. + """ + + def __init__(self, event): + self._event = event + + def wait(self): + get_accelerator().current_stream().wait_event(self._event) + + def is_completed(self) -> bool: + return self._event.query() + + +def _ensure_default_pg_registered(): + """Register the WORLD process group as 'default' in PyTorch's C++ GroupRegistry. + + mori's shmem layer looks up the PG by the name "default"; the standard + DeepSpeed init path doesn't register WORLD under that label. + """ + world_group = torch.distributed.group.WORLD + assert world_group is not None, "torch.distributed must be initialized before SDMA allgather" + torch._C._distributed_c10d._register_process_group("default", world_group) + + +def _build_dtype_map(): + """torch.dtype -> mori_cpp.DataType (NCCL-style enum).""" + from mori.ccl import DataType + return { + torch.uint8: DataType.Uint8, + torch.int8: DataType.Int8, + torch.int16: DataType.Int16, + torch.int32: DataType.Int32, + torch.int64: DataType.Int64, + torch.float16: DataType.Float16, + torch.bfloat16: DataType.BFloat16, + torch.float32: DataType.Float32, + torch.float64: DataType.Float64, + } + + +_TRUTHY = {"1", "true", "True", "TRUE", "yes", "Yes", "YES", "on", "On", "ON"} + + +def _is_enabled_by_env() -> bool: + """User must explicitly opt in via ``DS_SDMA_ALLGATHER=1``. + + Default is off even when mori happens to be importable: mori is an + external dependency and we don't want DeepSpeed's collective backend + to silently change behaviour based on which extra packages are + installed. Keeping this opt-in also makes A/B baselines against the + stock RCCL path trivial without having to uninstall mori. + """ + return os.environ.get("DS_SDMA_ALLGATHER", "0") in _TRUTHY + + +def _resolve_max_numel(default: int) -> int: + raw = os.environ.get("DS_SDMA_ALLGATHER_MAX_NUMEL") + if raw is None: + return default + try: + return max(int(raw), 0) + except ValueError: + return default + + +def init(max_numel: int = 64 * 1024 * 1024) -> None: + """Best-effort, idempotent SDMA handle construction. + + Builds one ``mori_cpp.AllGatherIntoTensor`` (NCCL/RCCL-style C++ + dispatcher) sized for the largest expected per-rank shard. All + subsequent allgather calls reuse this handle. Safe to call + unconditionally: any failure leaves ``_handle`` unset and logs a + single rank-0 info line, so callers transparently fall back to + RCCL/NCCL. + """ + global _handle, _dtype_map, _max_numel, _init_attempted + if _init_attempted: + return + _init_attempted = True + + is_rank0 = torch.distributed.is_initialized() and torch.distributed.get_rank() == 0 + if not _is_enabled_by_env(): + # Silent no-op: SDMA stays off and dist.allgather is used. We + # don't log here because most users never set DS_SDMA_ALLGATHER and + # rank-0 spam on every backend init is noise. + return + + max_numel = _resolve_max_numel(max_numel) + # mori's SymmMemManager only allocates the uncached transit buffers + # required by the SDMA kernel when MORI_ENABLE_SDMA is set; setdefault + # so users who already exported it (or want to override) win. + os.environ.setdefault("MORI_ENABLE_SDMA", "1") + + try: + _ensure_default_pg_registered() + import mori.shmem as shmem + from mori.ccl import AllGatherIntoTensor + + shmem.shmem_torch_process_group_init("default") + my_pe = shmem.shmem_mype() + npes = shmem.shmem_npes() + # Per-rank input transit buffer must hold the largest shard we'll + # ever see; output buffer = npes * input. 4 B/element is the SDMA + # kernel's uint32 lane width. + input_bytes = max_numel * 4 + _handle = AllGatherIntoTensor( + my_pe=my_pe, + npes=npes, + input_buffer_size=input_bytes, + output_buffer_size=input_bytes * npes, + copy_output_to_user=True, + ) + _dtype_map = _build_dtype_map() + _max_numel = max_numel + if is_rank0: + logger.info(f"SDMA allgather enabled via mori_cpp.AllGatherIntoTensor " + f"(max_numel={max_numel})") + except Exception as e: + _handle = None + _dtype_map = None + _max_numel = 0 + if is_rank0: + logger.info(f"SDMA allgather unavailable ({type(e).__name__}: {e}); " + f"using RCCL/NCCL allgather") + + +def is_enabled() -> bool: + return _handle is not None + + +def supports(input_tensor: torch.Tensor, group=None) -> bool: + """Cheap pre-check used by ``TorchBackend.all_gather_into_tensor``. + + SDMA is only safe when: + - the backend is initialised (``_handle`` set), + - the call is on the WORLD process group (mori's shmem layer was + bound to "default"/WORLD at init time), + - the per-rank shard fits inside the pre-allocated transit buffer, + - the input dtype is in ``_dtype_map``. + """ + if _handle is None: + return False + if group is not None and group is not torch.distributed.group.WORLD: + return False + if input_tensor.numel() > _max_numel: + return False + if _dtype_map is None or input_tensor.dtype not in _dtype_map: + return False + return True + + +def allgather_into_tensor(input_tensor: torch.Tensor, output_tensor: torch.Tensor, group=None) -> Optional[_SdmaWork]: + """Run one allgather_into_tensor through the SDMA handle. + + Returns an ``_SdmaWork`` (Work-compatible) on success. Returns + ``None`` when SDMA is not applicable for this call (uninitialised, + non-WORLD group, dtype not supported, shard larger than the transit + buffer) or the call fails for any reason — the caller falls back to + ``dist.allgather_fn``. + """ + global _call_failed_warned + if not supports(input_tensor, group): + return None + try: + stream = get_accelerator().current_stream() + dtype = _dtype_map[input_tensor.dtype] + ok = _handle(input_tensor.data_ptr(), output_tensor.data_ptr(), input_tensor.numel(), dtype, + stream.cuda_stream) + if not ok: + return None + event = get_accelerator().Event() + event.record(stream) + return _SdmaWork(event) + except Exception as e: + if (not _call_failed_warned and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0): + logger.warning(f"SDMA allgather failed ({e}); falling back to dist.allgather") + _call_failed_warned = True + return None diff --git a/deepspeed/comm/reduce_op.py b/deepspeed/comm/reduce_op.py new file mode 100644 index 000000000000..49393f0c7227 --- /dev/null +++ b/deepspeed/comm/reduce_op.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum + + +class ReduceOp(Enum): + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + BAND = 4 + BOR = 5 + BXOR = 6 + AVG = 7 + UNUSED = 8 diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py old mode 100644 new mode 100755 index 18e18500a4e9..39e3f65fbe92 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -3,11 +3,96 @@ # DeepSpeed Team +import deepspeed from deepspeed import utils +from packaging import version +import inspect from .utils import * from .backend import * from .comm import * +from ..runtime import compiler +from deepspeed.utils.torch import required_torch_version +import os + +DS_COMM_ALL_GATHER_OFF = False +DS_COMM_REDUCE_SCATTER_OFF = False +DS_COMM_BROADCAST_OFF = False +DS_COMM_ALL_REDUCE_OFF = False +DS_COMM_REDUCE_OFF = False + + +def disable_compiler_collective(func): + if required_torch_version(min_version=2.3): + return func + return compiler.disable(func) + + +def build_shm_op(): + builder = get_accelerator().create_op_builder("ShareMemCommBuilder") + if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]: + return None + shm_cpp_module = builder.load() + print(f'DeepSpeed {builder.absolute_name()} built successfully') + return shm_cpp_module + + +def has_coalescing_manager(): + has_c10d = hasattr(torch.distributed, 'distributed_c10d') + return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager') + + +def has_all_reduce_coalesced(): + return hasattr(torch.distributed, "all_reduce_coalesced") and required_torch_version(min_version=1.13) + + +def get_coalescing_manager(group, device, reqs, async_op): + if required_torch_version(min_version=2.0, max_version=2.0): + return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs) + elif required_torch_version(min_version=2.1): + return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, async_ops=async_op) + else: + return torch.distributed.distributed_c10d._coalescing_manager(group, reqs) + + +##Utilities to turn comm off +##TODO: move to base comm (wrapper) +def all_gather_comm_off(flag=False): + global DS_COMM_ALL_GATHER_OFF + DS_COMM_ALL_GATHER_OFF = flag + + +def reduce_scatter_comm_off(flag=False): + global DS_COMM_REDUCE_SCATTER_OFF + DS_COMM_REDUCE_SCATTER_OFF = flag + + +def broadcast_comm_off(flag=False): + global DS_COMM_BROADCAST_OFF + DS_COMM_BROADCAST_OFF = flag + + +def all_reduce_comm_off(flag=False): + global DS_COMM_ALL_REDUCE_OFF + DS_COMM_ALL_REDUCE_OFF = flag + + +def reduce_comm_off(flag=False): + global DS_COMM_REDUCE_OFF + DS_COMM_REDUCE_OFF = flag + + +#assumption: all_gather and reduce scatter +## are what we care about +def backward_comm_off(flag=False): + all_gather_comm_off(flag) + reduce_scatter_comm_off(flag) + + +class Noop: + + def wait(self): + return None class TorchBackend(Backend): @@ -21,7 +106,9 @@ class TorchBackend(Backend): def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'): super(TorchBackend, self).__init__() - self.torch_version_before_18 = older_torch() + self.shm_comm_op = build_shm_op() + self.has_all_reduce_coalesced = has_all_reduce_coalesced() + self.has_coalescing_manager = has_coalescing_manager() self.all_gather_function = self.get_all_gather_function() self.reduce_scatter_function = self.get_reduce_scatter_function() self.initialized = True @@ -31,8 +118,28 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name=' # it is not so we can run on a single GPU without doing any init_process_group self.single_gpu_mode = True self.init_process_group(backend, timeout, init_method, rank, world_size) + if self.shm_comm_op != None: + self.shm_comm_op.initialize(self.get_world_size(), self.get_rank()) + # Best-effort SDMA (mori) backend acquisition. Stays None on + # non-AMD/ROCm or when mori is unavailable; in that case + # all_gather_into_tensor below transparently falls through to + # torch.distributed.all_gather_into_tensor. + self._init_sdma_backend() + + def _init_sdma_backend(self): + """Try to enable the mori SDMA path for ``all_gather_into_tensor``. + + Failure (non-AMD, mori not installed, shmem init error) leaves the + handle unset and the standard torch.distributed allgather is used. + """ + try: + from . import mori as _mori + _mori.init() + except Exception: + pass @classmethod + @disable_compiler_collective def get_all_gather_function(self): if hasattr(torch.distributed, "all_gather_into_tensor"): return torch.distributed.all_gather_into_tensor @@ -41,6 +148,7 @@ def get_all_gather_function(self): return None @classmethod + @disable_compiler_collective def get_reduce_scatter_function(self): if hasattr(torch.distributed, "reduce_scatter_tensor"): return torch.distributed.reduce_scatter_tensor @@ -56,45 +164,154 @@ def has_reduce_scatter_tensor(self): def init_process_group(self, backend, timeout, init_method, rank, world_size): if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend, - timeout=timeout, - init_method=init_method, - rank=rank, - world_size=world_size) + kwargs = dict(timeout=timeout, init_method=init_method, rank=rank, world_size=world_size) + + # 1. device_id arg was added in torch==2.3 + # 2. setting device_id leads to hanging in 2.6.0= 3.14 +except ImportError: + annotationlib = None + +# Deal with annotations in python versions >=3.14. See: +# - Python 3.14 release notes: https://docs.python.org/3/whatsnew/3.14.html +# Porting annotations: https://docs.python.org/3/whatsnew/3.14.html#whatsnew314-porting-annotations +# - PEP649: https://peps.python.org/pep-0649/ +# - PEP749: https://peps.python.org/pep-0749/ +# Backwards compatible, applies best practices (use annotationlib) from python 3.14 onwards. + + +def get_annotations_from_namespace(namespace: Mapping[str, object]) -> Dict[str, Any]: + if annotationlib: + annotate_func = annotationlib.get_annotate_from_class_namespace(namespace) + if annotate_func is not None: + return annotationlib.call_annotate_function(annotate_func, annotationlib.Format.VALUE) + return namespace.get("__annotations__", {}) + + +def get_annotations(obj: Any) -> Dict[str, Any]: + """ + Retrieves annotations from a Python object. + + In python >=3.14 this is a thin wrapper around the `annotationlib.get_annotations` function + with the added convenience to automatically infer the type for non module, class, function + or customly annotated objects. + """ + if annotationlib: + has_annotations = hasattr(obj, "__annotations__") or hasattr(obj, "__annotate__") + if not isinstance(obj, type) and not ismodule(obj) and not callable(obj) and not has_annotations: + obj = type(obj) + return annotationlib.get_annotations(obj) + try: + return obj.__annotations__ + except AttributeError: + return {} diff --git a/deepspeed/compile/__init__.py b/deepspeed/compile/__init__.py new file mode 100644 index 000000000000..ab2a908812f4 --- /dev/null +++ b/deepspeed/compile/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .util import pad_tensors diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py new file mode 100644 index 000000000000..3858c2f20993 --- /dev/null +++ b/deepspeed/compile/backend.py @@ -0,0 +1,405 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Dict, List, Callable, Tuple, Set +import time +import gc +from collections import OrderedDict, deque + +import torch +from torch.fx import Graph, GraphModule + +try: + import torch._dynamo + from functorch.compile import make_boxed_func + from torch._functorch.aot_autograd import aot_module_simplified + from torch._functorch.partitioners import min_cut_rematerialization_partition + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch._subclasses.fake_tensor import is_fake +except ImportError: + pass + +from deepspeed.accelerator import get_accelerator + +from .fx import add_free_activations +from .graph_param import DSGraphParamManager +from .profilers import ProfilingResult +from .profilers.graph_profile import MemoryProfilingInterpreter +from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs +from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0, is_backend_inductor +from .partitioner import get_wrapped_partitioner +from .inductor import register_custom_ops, patch_create_aot_dispatcher_function +from .input_storage import InputStorage + +remaining_schedule = None +next_pass_step = -1 +next_passes = None +current_passes = None + +param_manager: Dict[int, DSGraphParamManager] = {} + + +class GraphOrder: + + def __init__(self): + self.frames = OrderedDict() + + def __len__(self): + return len(self.frames) + + def add_graph(self, graph_id: int, frame_id: int): + if frame_id not in self.frames: + self.frames[frame_id] = (graph_id, None) + + def set_needs_backward(self, frame_id: int, needs_backward: bool): + if frame_id in self.frames: + self.frames[frame_id] = (self.frames[frame_id][0], needs_backward) + + def get_graph_order(self) -> List[Tuple[int, bool]]: + assert all(isinstance(needs_backward, bool) for _, needs_backward in self.frames.values()) + return list(self.frames.values()) + + def clear(self): + self.frames.clear() + + +graph_order_with_frame_id = GraphOrder() + +frames_needing_bwd = set() +frames_partitioned: Set[int] = set() +profiling_results: Dict[int, ProfilingResult] = {} +opt_pass_times = [] +opt_passes = {} + +fwd_real_inputs = [] + + +def register_compile_pass(name: str, opt_pass_fn): + opt_passes[name] = opt_pass_fn + + +def init_schedule(schedule): + + assert isinstance(schedule, list), f"schedule should be a list, but got {type(schedule)}" + + for step, passes in schedule: + assert isinstance(step, int), f"Each step in schedule should be an integer, but got {type(step)}" + assert isinstance(passes, list), f"Passes at a certain step should be a list, but got {type(passes)}" + + global remaining_schedule + remaining_schedule = deque(schedule) + + +def launch_compile_passes(global_steps: int): + global next_pass_step, next_passes + + if len(remaining_schedule) > 0 and global_steps == remaining_schedule[0][0]: + _, next_passes = remaining_schedule.popleft() + log_rank0(f"Launching compile passes: global_steps={global_steps} passes={next_passes}", True) + + torch._dynamo.reset() + get_deepcompile_handle().reset() + graph_order_with_frame_id.clear() + profiling_results.clear() + param_manager.clear() + frames_partitioned.clear() + + +def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results): + node_time = [] + tensor_sizes = [] + + for n in graph.nodes: + node_time.append((n.name, n.meta["device_time"] if "device_time" in n.meta else 0.0, + n.meta["wall_time"] if "wall_time" in n.meta else 0.0)) + tensor_sizes.append((n.name, n.meta["tensor_size"] if "tensor_size" in n.meta else 0)) + + if bwd: + profiling_results[graph_id].bwd_graph = graph + profiling_results[graph_id].bwd_time = node_time + profiling_results[graph_id].bwd_tensor_sizes = tensor_sizes + profiling_results[graph_id].bwd_mem = mem + else: + profiling_results[graph_id].fwd_graph = graph + profiling_results[graph_id].fwd_time = node_time + profiling_results[graph_id].fwd_tensor_sizes = tensor_sizes + profiling_results[graph_id].fwd_mem = mem + + +def evaluate_symint_from_shape_env(sym_int_v): + assert isinstance(sym_int_v, torch.SymInt) + # shape_env = sym_int_v.node.shape_env + # v = shape_env.evaluate_sym_node(sym_int_v.node) + return sym_int_v.node.hint + + +def set_example_values_to_symints(real_inputs, param_indices=None): + real_inputs_ret = [] + + # Create a set of parameter indices for quick lookup + param_idx_set = set() + if param_indices is not None: + param_idx_set = {i for i, _, _ in param_indices} + + for i, v in enumerate(real_inputs): + if isinstance(v, torch.Tensor): + if is_fake(v): + shape = [] + for fs in v.shape: + if isinstance(fs, torch.SymInt): + shape.append(evaluate_symint_from_shape_env(fs)) + else: + shape.append(fs) + stride = [] + for fs in v.stride(): + if isinstance(fs, torch.SymInt): + stride.append(evaluate_symint_from_shape_env(fs)) + else: + stride.append(fs) + with unset_fake_temporarily(): + dummy_v = torch.empty_strided(shape, + stride, + dtype=v.dtype, + layout=v.layout, + device=v.device, + requires_grad=v.requires_grad).zero_() + + # Create Parameter if this input index corresponds to a parameter + if i in param_idx_set: + dummy_v = torch.nn.Parameter(dummy_v) + # Copy any additional attributes from the original if they exist + if hasattr(v, 'ds_id'): + dummy_v.ds_id = v.ds_id + + real_inputs_ret.append(dummy_v) + else: + real_inputs_ret.append(v) + else: + if isinstance(v, torch.SymInt): + real_inputs_ret.append(evaluate_symint_from_shape_env(v)) + else: + real_inputs_ret.append(v) + + return tuple(real_inputs_ret) + + +def run_opt_passes(opt_passes: List[Callable], + gm: GraphModule, + graph_id: int, + graph_order: List[Tuple[int, bool]], + profiling_results, + create_inputs_fn, + mem_budget: float, + param_manager, + bwd: bool, + debug_log=False) -> None: + + with unset_fake_temporarily(): + get_accelerator().synchronize() + gc.collect() + get_accelerator().empty_cache() + + for i, opt_pass_fn in enumerate(opt_passes): + log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log) + + gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager, + bwd) + if gm_new is not None: + gm = gm_new + gm.graph.lint() + gm.recompile() + + mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log) + mem_prof.run(*create_inputs_fn()) + mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record] + + set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results) + + with unset_fake_temporarily(): + get_accelerator().synchronize() + gc.collect() + get_accelerator().empty_cache() + + +def make_backend(backend, compile_config, compile_kwargs={}): + + register_custom_ops() + + # Extract values from compile_config + debug_log = compile_config.debug_log + free_activation = compile_config.free_activation and not is_backend_inductor(backend) + + def backend_fn(gm: GraphModule, real_inputs): + graph_id = id(gm.graph) + + # Checking the existence of input tensors requiring grad alone is insufficient to determine `need_backward`. + # AOT autograd also checks the graph data flow and skips the backward pass if no output requires grad and no + # input requiring grad is mutated. + # + # Instead of replicating AOT autograd's backward pass determination (which is too costly), we infer whether + # backward pass is needed by checking if the joint graph is partitioned (into a forward and a backward module). + # This check cannot be placed here because autograd creates the fw/bw compiler callables before graph + # partitioning. It is thus postponed to the point where the fw compiler is called. + frame_id = gm.meta["dynamo_compile_id"].frame_id + graph_order_with_frame_id.add_graph(graph_id, frame_id) + + z3_partition = any(hasattr(v, "ds_id") for v in real_inputs) + if z3_partition: + param_indices = [(i, input_val.ds_id, input_val.ds_shape) for i, input_val in enumerate(real_inputs) + if isinstance(input_val, torch.nn.Parameter)] + else: + assert all(hasattr(v, "param_id") for v in real_inputs + if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id" + param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs) + if isinstance(input_val, torch.nn.Parameter)] + + global fwd_real_inputs + + # Create an InputStorage instance for this specific graph + # It will be captured by the make_fw_graph closure, eliminating the need for graph ID management + input_storage = InputStorage(keep_int_input_tensors=compile_config.keep_int_input_tensors, + keep_all_input_tensors=compile_config.keep_all_input_tensors) + + # Store in both list (for backward compatibility) and storage (for persistence) + # The input_storage keeps tensor metadata to handle cases where + # backend_fn is called once but make_fw_graph is called multiple times + fwd_real_inputs.append(real_inputs) + input_storage.put(real_inputs) + + global profiling_results + if graph_id not in profiling_results: + profiling_results[graph_id] = ProfilingResult() + profiling_results[graph_id].param_indices = param_indices + + def make_fw_graph(gm, sample_inputs): + time_start = time.time() + graph_index = len(graph_order_with_frame_id) - 1 + + needs_backward = frame_id in frames_partitioned + graph_order_with_frame_id.set_needs_backward(frame_id, needs_backward) + profiling_results[graph_id].needs_backward = needs_backward + + if needs_backward: + if len(frames_needing_bwd) == 0: + patch_compiled_func() + frames_needing_bwd.add(frame_id) + + # Try to get real_inputs from the list first, then from storage + if fwd_real_inputs: + real_inputs = fwd_real_inputs.pop(0) + elif input_storage.has_data(): + # Note: input_storage is captured from the enclosing backend_fn scope + # Materialize tensors from storage when list is empty + log_rank0(f"Retrieving real inputs from storage for graph_id={graph_id}", enable=debug_log) + real_inputs = input_storage.get() + else: + raise RuntimeError(f"No real inputs available for graph_id {graph_id}. " + f"List size: {len(fwd_real_inputs)}, Storage has data: {input_storage.has_data()}") + + real_inputs = set_example_values_to_symints(real_inputs) + + param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices) + + real_inputs_with_rng = real_inputs + tuple(sample_inputs[len(real_inputs):]) + run_opt_passes( + opt_passes=next_passes, + gm=gm, + graph_id=graph_id, + graph_order=graph_order_with_frame_id.get_graph_order(), + profiling_results=profiling_results, + create_inputs_fn=lambda: real_inputs_with_rng, + mem_budget=.0, # unused + param_manager=param_manager, + bwd=False, + debug_log=debug_log) + + opt_pass_times.append(("fwd", graph_index, graph_id, time.time() - time_start)) + + log_rank0(f"Fwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()}", + enable=debug_log) + + return gm.graph + + def make_bw_graph(gm, sample_inputs): + time_start = time.time() + + graph_order = graph_order_with_frame_id.get_graph_order() + graph_index = get_index_by_graph_id(graph_order, graph_id) + log_rank0( + f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}", + enable=debug_log) + + bwd_inputs_stack = get_backward_inputs() + + param_nodes_bw, _ = param_manager[graph_id].get_bwd_mapping(gm.graph) + if len(bwd_inputs_stack) == 0: + # dynamo calls bw compiler ahead of time when symints are saved for backward. See the details for aot_dispatch_autograd in jit_compile_runtime_wrappers. + # As we currently use actually bwd input values in bw compiler, we make dummy data for profiling. + # Replace fake tensors with real parameters before calling set_example_values_to_symints + log_rank0(f"Generating dummy backward inputs for profiling. graph_id={graph_id}", enable=True) + sample_inputs_with_real_params = param_manager[graph_id].replace_fake_tensors_with_real_params( + sample_inputs, gm.graph) + bwd_real_inputs = set_example_values_to_symints(sample_inputs_with_real_params) + else: + bwd_real_inputs = bwd_inputs_stack.pop() + + run_opt_passes( + opt_passes=next_passes, + gm=gm, + graph_id=graph_id, + graph_order=graph_order, + profiling_results=profiling_results, + create_inputs_fn=lambda: tuple(bwd_real_inputs), + mem_budget=.0, # unused + param_manager=param_manager, + bwd=True, + debug_log=debug_log) + + # assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager" + + if free_activation: + param_names = [n.name for n in param_nodes_bw] + non_param_input_names = [n.name for n in get_input_nodes(gm.graph) if n.name not in param_names] + add_free_activations(graph_id, gm.graph, + get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names)) + + frames_needing_bwd.remove(frame_id) + if len(frames_needing_bwd) == 0: + unpatch_compiled_func() + + log_rank0( + f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}", + enable=debug_log) + + opt_pass_times.append(("bwd", graph_index, graph_id, time.time() - time_start)) + + return gm.graph + + if backend == "eager": + + def make_compiler_fn(make_graph_fn): + + def compiler_fn(gm, sample_inputs): + return None if make_graph_fn(gm, sample_inputs) is None else make_boxed_func(gm.forward) + + return compiler_fn + + partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition, + frame_id, frames_partitioned) + aot_mod = aot_module_simplified(gm, + real_inputs, + fw_compiler=make_compiler_fn(make_fw_graph), + bw_compiler=make_compiler_fn(make_bw_graph), + partition_fn=partition_fn) + return torch._dynamo.optimize(**compile_kwargs)(aot_mod) + elif backend == "inductor": + patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs, + param_indices, param_manager, frame_id, frames_partitioned) + + return torch._inductor.compile(gm, real_inputs) + + raise ValueError(f"Unsupported backend {backend}") + + return backend_fn diff --git a/deepspeed/compile/config.py b/deepspeed/compile/config.py new file mode 100644 index 000000000000..2137b94722f2 --- /dev/null +++ b/deepspeed/compile/config.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Optional, Literal +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + +PassName = Literal["z1", "z3", "autosp"] + + +class CompileConfig(DeepSpeedConfigModel): + """ Configure compile settings """ + + deepcompile: bool = False + """ Turn on/off the DeepCompile mode """ + + free_activation: bool = False + """ Turn on/off the free activation mode """ + + free_activation_threshold: int = 10 * 1024 * 1024 + """ In free activation mode, activations no less than this threshold (in byte) are eagerly freed """ + + offload_activation: bool = False + """ Turn on/off the activation offloading """ + + offload_opt_states: bool = False + """ Turn on/off the optimizer states offloading """ + + double_buffer: bool = True + """ Turn on/off the double buffering """ + + symmetric_memory: bool = False + """ Turn on/off the symmetric memory """ + + debug_log: bool = False + """ Turn on/off the graph dumping """ + + offload_parameters: bool = False + """ Turn on/off the parameter offloading """ + + sync_before_reduce: bool = False + """ Turn on/off the sync before reduce """ + + sync_after_reduce: bool = False + """ Turn on/off the sync after reduce """ + + sync_before_allgather: bool = False + """ Turn on/off the sync before allgather """ + + sync_after_allgather: bool = False + """ Turn on/off the sync after allgather """ + + keep_int_input_tensors: bool = True + """ Keep real values for int tensors in InputStorage instead of using dummy values """ + + keep_all_input_tensors: bool = False + """ Keep real values for all input tensors in InputStorage instead of using dummy values """ + + passes: Optional[List[PassName]] = None + """ Composes different optimizations. """ diff --git a/deepspeed/compile/constants.py b/deepspeed/compile/constants.py new file mode 100644 index 000000000000..e365b692a7d8 --- /dev/null +++ b/deepspeed/compile/constants.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######################################### +# AUTOSP +######################################### +AUTOSP_INPUT_ID_KEY = "input_id" +AUTOSP_LABEL_ID_KEY = "label_id" +AUTOSP_POSITION_ID_KEY = "position_id" diff --git a/deepspeed/compile/custom_ops/__init__.py b/deepspeed/compile/custom_ops/__init__.py new file mode 100644 index 000000000000..e5fc593a2e7e --- /dev/null +++ b/deepspeed/compile/custom_ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .all_to_all import all_to_all +from . import sp_dp_registry + +__all__ = ["all_to_all", "sp_dp_registry", "sp_compat"] diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py new file mode 100644 index 000000000000..3307bbc527ff --- /dev/null +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed.comm as dist +from torch.utils._sympy.functions import FloorDiv +from .sp_dp_registry import get_group, is_setup, sp_size + + +@torch.library.custom_op("autosp::all_to_all", mutates_args=()) +def all_to_all( + input: torch.Tensor, + scatter_idx: int, + gather_idx: int, + name: str, +) -> torch.Tensor: + """ + All-to-all collective for SDPA tensors [B, N, S, H]. + + For QKV (scatter_idx=1, gather_idx=2): + [B, N, S/P, H] -> [B, N/P, S, H] + For O (scatter_idx=2, gather_idx=1): + [B, N/P, S, H] -> [B, N, S/P, H] + """ + assert is_setup(), 'Incorrect initialization of SP/DP mesh.' + B, dim1, dim2, H = input.shape + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + if scatter_idx == 1: + N, local_S = dim1, dim2 + input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H) + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + output = output.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(B, N // sp_size(), sp_size() * local_S, H) + else: + local_N, S = dim1, dim2 + input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H) + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + output = output.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(B, sp_size() * local_N, S // sp_size(), H) + + return output + + +@torch.library.register_fake("autosp::all_to_all") +def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): + + def maybe_restore_sharded_dim(dim: torch.SymInt, factor: int): + # Torch 2.9 may keep `P * (s // P)` distinct from the original `s` during + # fake shape propagation. When the local dim is exactly `FloorDiv(s, P)`, + # restore the original symbol so downstream ops see a consistent sequence dim. + node = getattr(dim, "node", None) + if node is None: + return dim * factor + + expr = node.expr + if isinstance(expr, FloorDiv) and expr.args[1] == factor: + hint = node.hint * factor if node.has_hint() else None + return node.shape_env.create_symintnode(expr.args[0], hint=hint) + + return dim * factor + + B, dim1, dim2, H = input.shape + if scatter_idx == 1: + return input.new_empty(B, dim1 // sp_size(), maybe_restore_sharded_dim(dim2, sp_size()), H) + else: + return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) + + +def _all_to_all_backward_setup(ctx, inputs, output): + _, scatter_idx, gather_idx, name = inputs + ctx.scatter_idx = gather_idx + ctx.gather_idx = scatter_idx + ctx.name = name + "_grad" + + +def _all_to_all_backward(ctx, grad): + return (all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), None, None, None) + + +torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup) diff --git a/deepspeed/compile/custom_ops/sp_compat.py b/deepspeed/compile/custom_ops/sp_compat.py new file mode 100644 index 000000000000..136d01edce0b --- /dev/null +++ b/deepspeed/compile/custom_ops/sp_compat.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from packaging.version import Version + + +def _check_autosp_compatibility(): + # Strip the local version segment (e.g. +cu128) so CUDA builds don't sort + # above the max bound when using packaging's local-version ordering rules. + torch_version = Version(torch.__version__.split("+")[0]) + if torch_version < Version("2.9"): + raise RuntimeError("AutoSP requires PyTorch >= 2.9, found " + f"{torch.__version__}.") + + try: + import transformers + if Version(transformers.__version__) > Version("4.50.3"): + raise RuntimeError("AutoSP requires transformers <= 4.50.3, found " + f"{transformers.__version__}.") + except ImportError: + pass # transformers not installed; skip the check diff --git a/deepspeed/compile/custom_ops/sp_dp_registry.py b/deepspeed/compile/custom_ops/sp_dp_registry.py new file mode 100644 index 000000000000..a93707032959 --- /dev/null +++ b/deepspeed/compile/custom_ops/sp_dp_registry.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed.comm as dist + +GROUP_REGISTRY = {} # int -> dist.ProcessGroup + + +def register_groups(groups): + """groups: List[List[int]], e.g. [[0,1],[2,3]]""" + for gid, ranks in enumerate(groups): + if gid not in GROUP_REGISTRY: + GROUP_REGISTRY[gid] = dist.new_group(ranks) + + +def get_group(gid: int): + return GROUP_REGISTRY[gid] if gid is not None else dist.get_world_group() + + +def get_registry(): + return GROUP_REGISTRY + + +def is_setup(): + return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False + + +def extract_mesh_size(param_dict): + sp_size = param_dict.get('sequence_parallel_size', 1) + assert dist.get_world_size() % sp_size == 0, 'World mesh-size should be divisible by SP_SIZE' + dp_size = dist.get_world_size() // sp_size + + return sp_size, dp_size + + +def sp_size(): + assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' + + return GROUP_REGISTRY['SP_SIZE'] + + +def dp_size(): + assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' + + return GROUP_REGISTRY['DP_SIZE'] + + +def populate_registry(SP_SIZE, DP_SIZE): + """ Populate rank to SP/DP mesh index. """ + + if GROUP_REGISTRY.get('is_reg', False): + return + + group_listing = [] + offset = 0 + for _ in range(DP_SIZE): + group_listing.append([i + offset for i in range(SP_SIZE)]) + offset += SP_SIZE + + register_groups(group_listing) + + ## Extraneous metadata required for proper instatiation. ## + GROUP_REGISTRY['SP_SIZE'] = SP_SIZE + GROUP_REGISTRY['DP_SIZE'] = DP_SIZE + GROUP_REGISTRY['is_reg'] = True diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py new file mode 100644 index 000000000000..51a2147ab7c2 --- /dev/null +++ b/deepspeed/compile/fx.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Callable, Any, List, Dict, Optional +from collections import defaultdict + +import torch +from torch.fx import Node, Graph, GraphModule +from torch.fx.node import map_aggregate + +from .util import get_last_uses + + +def get_output_node(graph: Graph): + for v in graph.nodes: + if v.target == "output": + return v + raise ValueError("No output node found") + + +def add_end_backward(graph: Graph, graph_id: int): + reduce_nodes = [n for n in graph.nodes if n.target == torch.ops.dc.reduce_grad.default] + if len(reduce_nodes) == 0: + return + + with graph.inserting_before(get_output_node(graph)): + graph.create_node("call_function", torch.ops.dc.end_backward.default, (reduce_nodes, graph_id)) + + +def replace_reduce_outputs_with_none(graph: Graph): + output_node = get_output_node(graph) + new_outputs = map_aggregate( + output_node.args[0], lambda n: None + if isinstance(n, Node) and n.target == torch.ops.dc.reduce_grad.default else n) + output_node.args = (new_outputs, ) + + +def move_primals_to_head(graph: Graph): + + # Move primals to the head of the graph + primals = [n for n in graph.nodes if n.op == "placeholder"] + non_primals = [n for n in graph.nodes if n.op != "placeholder"] + all_nodes = primals + non_primals + + new_graph = Graph() + env = {} + for node in all_nodes: + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + new_graph.lint() + + return new_graph + + +def add_args_process(graph: Graph, + node: Node, + fn: Callable[..., Any], + extra_args: List[int] = [], + name=None, + meta={}) -> List[Node]: + # Apply fn to all args of node + new_nodes = [] + with graph.inserting_before(node): + target_args = [arg for arg in node.args if isinstance(arg, Node)] + + for arg in target_args: + new_node = graph.create_node('call_function', fn, (arg, ) + tuple(extra_args), name=name) + for k, v in meta.items(): + new_node.meta[k] = v + node.replace_input_with(arg, new_node) + new_nodes.append(new_node) + + return new_nodes + + +def add_postprocess(graph: Graph, + node: Node, + fn: Callable[..., Any], + extra_args: List[Any] = [], + extra_kwargs: Dict[str, Any] = {}, + name=None, + meta={}) -> Node: + # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py + with graph.inserting_after(node): + args = (node, ) + for a in extra_args: # To add ds_id + args += (a, ) + + node_users = node.users.keys() + new_node = graph.create_node('call_function', fn, args, extra_kwargs, name=name) + users = {} + for u in node_users: + if u != new_node: + users[u] = (node, new_node) + for u, (old_in, new_in) in users.items(): + u.replace_input_with(old_in, new_in) + + for k, v in meta.items(): + new_node.meta[k] = v + + return new_node + + +def _make_node_meta(node: Node, ds_id: int, comm: bool): + meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm} + if "tensor_meta" in node.meta: + meta["tensor_meta"] = node.meta["tensor_meta"] + return meta + + +def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]): + node_to_last_use, _ = get_last_uses(graph) + activation_nodes_set = set([n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names]) + + offload_id_to_node = {} + node_to_wait_reload = {} + for node in graph.nodes: + if node.target == torch.ops.dc.reload_tensor.default: + offload_act = node.args[0] + # node_to_offload_id[offload_act] = node.args[2] + offload_id_to_node[node.args[2]] = offload_act + elif node.target == torch.ops.dc.wait_reload.default: + offload_id = node.args[2] + node_to_wait_reload[offload_id_to_node[offload_id]] = node + + activation_nodes_set = set(node_to_wait_reload[n] if n in node_to_wait_reload else n for n in activation_nodes_set) + + last_user_to_uses = defaultdict(list) + for node, last_user in node_to_last_use.items(): + last_user_to_uses[last_user].append(node) + + def _should_free(node: Node) -> bool: + if not hasattr(node, "meta"): + return False + if "tensor_meta" not in node.meta: + return False + return True + + def free_tensors(tensors: List[torch.Tensor]): + for a in tensors: + if a.numel() > 10_000_000: + a.data = torch.empty([0], device=a.device, dtype=a.dtype) + + for last_user, used_nodes in last_user_to_uses.items(): + activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)] + + if len(activation_args) == 0: + continue + + node_name = f"free_activations_{[n.name for n in used_nodes]}" + with graph.inserting_after(last_user): + args = (activation_args, ) + graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name) + + # Python version for debugging + # graph.create_node('call_function', free_tensors, args, {}, name=node_name) + + +def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]: + for node in gm.graph.nodes: + if node.name == name: + return node + return None + + +def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]: + return node.meta.get("val") or node.meta.get("example_value") + + +def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]: + input_id_node = None + for node in gm.graph.nodes: + # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048 + tensor_dict = node.meta.get('tensor_dict') + if tensor_dict and tensor_dict.get('tag') == tag: + input_id_node = node + break + return input_id_node + + +def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None): + exclude = exclude or [] + to_replace = [u for u in node.users if u not in exclude] + for user in to_replace: + user.replace_input_with(node, replacement) diff --git a/deepspeed/compile/graph_param.py b/deepspeed/compile/graph_param.py new file mode 100644 index 000000000000..5a330562b39c --- /dev/null +++ b/deepspeed/compile/graph_param.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple +from functools import reduce + +import torch +from torch.fx import Graph, Node + +from .fx import get_output_node +from .util import get_param_nodes, get_input_nodes + + +@dataclass +class DSGraphParam: + name: str + shape: torch.Size + dtype: torch.dtype + device: torch.device + node: Node + allgather_node: Node + release_node: Node + param: torch.Tensor + numel: int = field(init=False) + + def __post_init__(self): + self.numel = reduce(lambda x, y: x * y, self.shape) + + +class DSGraphParamManager: + + def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, int]]): + self._fw_graph = fw_graph + self._bw_graph = None + self._params: Dict[str, DSGraphParam] = {} + self._param_name_to_grad: Dict[str, Node] = {} + self._ds_ids: Dict[str, int] = {} + + param_nodes = get_param_nodes(fw_graph, index_to_ds_ids) + self._param_names = [pn.name for pn in param_nodes] + self._param_indices = [i for i, _, _ in index_to_ds_ids] + + param_inputs = [sample_inputs[i] for i, _, _ in index_to_ds_ids] + ds_ids = [ds_id for _, ds_id, _ in index_to_ds_ids] + ds_shapes = [ds_shape for _, _, ds_shape in index_to_ds_ids] + + for pn, pi, ds_id, ds_shape in zip(param_nodes, param_inputs, ds_ids, ds_shapes): + self._params[pn.name] = DSGraphParam(name=pn.name, + shape=ds_shape, + dtype=pi.dtype, + device=pi.device, + node=pn, + allgather_node=None, + release_node=None, + param=pi) + self._ds_ids[pn.name] = ds_id + + def get_bwd_mapping(self, bw_graph: Graph): + self._bw_graph = bw_graph + + output_node = get_output_node(bw_graph) + param_nodes_bw = [n for n in self._bw_graph.nodes if n.name in self.param_names] + grad_outputs = [output_node.args[0][i] for i in self._param_indices] + param_name_to_grad = {param_name: grad for param_name, grad in zip(self.param_names, grad_outputs)} + return param_nodes_bw, param_name_to_grad + + @property + def param_names(self) -> List[str]: + return self._param_names + + @property + def params(self) -> Dict[str, DSGraphParam]: + return self._params + + @property + def ds_ids(self) -> Dict[str, int]: + return self._ds_ids + + def get_grad_name(self, param_name) -> str: + assert self._param_name_to_grad is not None, "Backward graph is not added yet" + return self._param_name_to_grad[param_name] + + def replace_fake_tensors_with_real_params(self, sample_inputs: List[Any], bw_graph: Graph) -> List[Any]: + """Replace fake tensors in sample_inputs with real parameters from DSGraphParamManager + + Args: + sample_inputs: The input tensors that may contain fake tensors + bw_graph: The backward graph to get parameter mapping from (if in backward pass) + """ + replaced_inputs = list(sample_inputs) + + # For backward pass, get the parameter nodes and their mapping + param_nodes_bw, _ = self.get_bwd_mapping(bw_graph) + param_names_bw = [n.name for n in param_nodes_bw] + + for i, inp in enumerate(get_input_nodes(bw_graph)): + if inp.name in param_names_bw: + replaced_inputs[i] = self._params[inp.name].param + + return replaced_inputs diff --git a/deepspeed/compile/inductor.py b/deepspeed/compile/inductor.py new file mode 100644 index 000000000000..0fe880260439 --- /dev/null +++ b/deepspeed/compile/inductor.py @@ -0,0 +1,227 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Set + +import torch + +try: + import torch.utils._pytree as pytree + from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs + from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode + from torch._inductor.virtualized import V + from torch._inductor.scheduler import Scheduler +except ImportError: + pass + +from deepspeed.utils.torch import required_torch_version +from .util import get_input_nodes +from .graph_param import DSGraphParamManager +from .partitioner import get_wrapped_partitioner + + +def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool): + + def wrapped_compiler(gm, fake_inputs): + mod_graph = dc_compiler(gm, fake_inputs) + + # For symint case + if mod_graph is None: + return None + + if z3_partition: + # Inductor validates input size estimated by the first trace, where ds tensor is materialized. + # We need to patch the input tensors to avoid the validation error. + patched_inputs = [] + if bwd: + param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph) + param_names = [n.name for n in param_nodes_bw] + else: + param_names = graph_param_manager[graph_id].param_names + input_nodes = get_input_nodes(gm.graph) + + for in_node, in_v in zip(input_nodes, fake_inputs): + ds_param = in_node.name in param_names + if ds_param: + from torch._subclasses.fake_tensor import is_fake + from torch._dynamo.utils import to_fake_tensor + assert is_fake(in_v), f"Input {in_v} should be fake tensor" + patched_inputs.append( + to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode)) + else: + patched_inputs.append(in_v) + + patched_inputs = tuple(patched_inputs) + else: + patched_inputs = fake_inputs + + return original_compiler(gm, patched_inputs) + + return wrapped_compiler + + +def wrap_partition_fn(z3_partition: bool, partition_fn, real_inputs, param_indices, frame_id: int, + frames_partitioned: Set[int]): + + def wrapped_partition_fn(*args, **kwargs): + + fn = get_wrapped_partitioner(z3_partition, + param_indices, + partition_fn=partition_fn, + frame_id=frame_id, + frames_partitioned=frames_partitioned) + fw_module, bw_module = fn(*args, **kwargs) + + if z3_partition: + # get parameter names + pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices) + + def fix_placeholder_meta(graph): + for n in graph.nodes: + if n.op == "placeholder" and n.name in pm.param_names: + n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device) + + fix_placeholder_meta(fw_module.graph) + fix_placeholder_meta(bw_module.graph) + + return fw_module, bw_module + + return wrapped_partition_fn + + +def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs, + param_indices, param_manager, frame_id: int, frames_partitioned: Set[int]): + + from torch._dynamo.backends.common import AotAutograd + import functools + + def patch_aotautograd(): + # Unpatch if it was already patched + if hasattr(AotAutograd, "__original_init"): + AotAutograd.__init__ = AotAutograd.__original_init + + original_init = AotAutograd.__init__ + + @functools.wraps(original_init) + def patched_init(self, **kwargs): + kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"], + make_fw_graph, + z3_partition, + graph_id, + param_manager, + bwd=False) + kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"], + make_bw_graph, + z3_partition, + graph_id, + param_manager, + bwd=True) + kwargs["inference_compiler"] = kwargs["fw_compiler"] + + kwargs["partition_fn"] = wrap_partition_fn(z3_partition, kwargs["partition_fn"], real_inputs, + param_indices, frame_id, frames_partitioned) + + original_init(self, **kwargs) + + AotAutograd.__original_init = original_init + AotAutograd.__init__ = patched_init + + patch_aotautograd() + + +def register_custom_ops(): + + def fallback_handler_no_reuse(kernel, + never_reuse_input, + never_reuse_output, + force_free_input, + add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + + def wrap_tensors(x): + out = TensorBox.create(x) if isinstance(x, torch._inductor.ir.IRNode) else x + if out is not None and never_reuse_output: + V.graph.never_reuse_buffers.add(out.get_name()) + return out + + class CustomDCKernel(FallbackKernel): + + def __init__(self, op, *args, **kwargs): + super().__init__(op, *args, **kwargs) + + def add_to_never_reuse(x): + if isinstance(x, IRNode): + assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}" + V.graph.never_reuse_buffers.add(x.get_name()) + + if never_reuse_input: + pytree.tree_map(add_to_never_reuse, args) + + def get_var_name_for_arg(self, arg: str): + if arg.isidentifier(): + return arg + + import re + match = re.match(r"reinterpret_tensor\((\w+),", arg) + if match: + return match.group(1) + return None + + def codegen(self, wrapper): + if not force_free_input: + return super().codegen(wrapper) + + kernel = self.op_overload + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + + if required_torch_version(min_version=2.8): + V.graph.wrapper_code.generate_fallback_kernel(self) + else: + V.graph.wrapper_code.generate_fallback_kernel(self, args) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + var_name = self.get_var_name_for_arg(args[0]) + if var_name: + wrapper.writeline(f"{var_name} = None") + + self.codegen_unbacked_symbol_defs(wrapper) + + kernel_cls = CustomDCKernel if force_free_input else FallbackKernel + return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs)) + + return handler + + def register_fallback_no_reuse(op_overload, + never_reuse_input=False, + never_reuse_output=False, + force_free_input=False): + add_needs_realized_inputs(op_overload) + return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse( + op_overload, + never_reuse_input=never_reuse_input, + never_reuse_output=never_reuse_output, + force_free_input=force_free_input)) + + # Inductor tries to reuse output buffer when possible. We need to disable this behavior for some custom ops. + # -> It seems that memory region is still reused in some cases. So we clone the inputs for some ops. + register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True) + register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True) + register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False) + register_fallback_no_reuse(torch.ops.dc.reduce_grad.default, + never_reuse_input=True, + never_reuse_output=True, + force_free_input=True) + register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True) + register_fallback_no_reuse(torch.ops.dc.end_backward.default, never_reuse_input=True, never_reuse_output=False) + + if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched: + Scheduler.is_dc_patched = True + Scheduler.dead_node_elimination = lambda _: None diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py new file mode 100644 index 000000000000..23ffd28cae4c --- /dev/null +++ b/deepspeed/compile/init_sp.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from torch.fx import GraphModule +from .passes.sp_compile import apply_autosp +from .passes.long_context_checkpointing import register_long_context_checkpointing +from .custom_ops.sp_dp_registry import extract_mesh_size +from .custom_ops.sp_compat import _check_autosp_compatibility + + +def init_autosp(config): + _check_autosp_compatibility() + sp_size, dp_size = extract_mesh_size(config._param_dict) + register_long_context_checkpointing() + + def backend_fn(gm: GraphModule, real_inputs): + apply_autosp(gm, real_inputs, debug=False, sp_size=sp_size, dp_size=dp_size) + return torch._inductor.compile(gm, real_inputs) + + return backend_fn diff --git a/deepspeed/compile/init_z1.py b/deepspeed/compile/init_z1.py new file mode 100644 index 000000000000..f73a1953f7e4 --- /dev/null +++ b/deepspeed/compile/init_z1.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy + +import torch + +from deepspeed.accelerator import get_accelerator +from .passes import zero1_compile, zero3_compile +from .backend import make_backend, launch_compile_passes, init_schedule +from .util import get_deepcompile_handle, add_pre_backward_hook + +WARMUP = 5 + + +def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_z2=False): + + optimizer = engine.optimizer + optimizer.contiguous_gradients = False # Avoid creating unnecessary buffer + for hook in optimizer._grad_acc_hooks: + hook.remove() + optimizer._grad_acc_hooks.clear() + + dc = get_deepcompile_handle() + dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size()) + + grad_buffer = {} + + # Save original all_grad_tensors state as we temporarily modify it + original_all_grad_tensors = optimizer.all_grad_tensors.copy() if hasattr(optimizer, 'all_grad_tensors') else {} + + for i, group in enumerate(optimizer.bit16_groups): + # Temporarily populate all_grad_tensors for get_flat_partition call + # This is needed because get_flat_partition accesses all_grad_tensors[param_group_idx][i] + # but it's empty during initialization + if i not in optimizer.all_grad_tensors or optimizer.all_grad_tensors[i] is None: + optimizer.all_grad_tensors[i] = optimizer.get_all_grad_tensors(optimizer.params_in_partition[i], + optimizer.gradient_accumulation_dtype) + + grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i], + optimizer.first_offset[i], + optimizer.partition_size[i], + dtype=optimizer.gradient_accumulation_dtype, + device=get_accelerator().current_device_name(), + param_group_idx=i, + return_tensor_list=True) + grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]] # Maybe not necessary + + index_in_partition = 0 + first_in_partition = True + for p in group: + param_id = optimizer.get_param_id(p) + p.param_id = param_id + in_partition = optimizer.is_param_in_current_partition[param_id] + + if in_partition: + buf = grad_buffer[i][index_in_partition] + offset = optimizer.first_offset[i] if first_in_partition else 0 + # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf={buf.shape} partition_offset={offset}") + dc.register_param(p.param_id, p.shape, p, buf, int(offset)) + index_in_partition += 1 + first_in_partition = False + else: + # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None") + dc.register_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0) + + # Restore original all_grad_tensors state + optimizer.all_grad_tensors = original_all_grad_tensors + + def set_grad_buffer(): + optimizer.averaged_gradients = copy.copy(grad_buffer) + + add_pre_backward_hook(set_grad_buffer) + + if schedule is None: + schedule = [] + if use_z2: + schedule.append((0, [zero1_compile.add_z2_reduce])) + else: + schedule.append((0, [zero1_compile.add_z1_reduce])) + else: + for opt in schedule: + # avoid typical misconfiguration + if zero3_compile.add_z3_gather_release in opt[1]: + raise ValueError("A pass for ZeRO3 is not specified though ZeRO1 is enabled") + + init_schedule(schedule) + + engine.launch_compile_passes = launch_compile_passes + return make_backend(backend, compile_config, compile_kwargs=compile_kwargs) diff --git a/deepspeed/compile/init_z3.py b/deepspeed/compile/init_z3.py new file mode 100644 index 000000000000..2b9d404b3781 --- /dev/null +++ b/deepspeed/compile/init_z3.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload + +from .passes import zero3_compile, prefetch, selective_gather, offload_parameters +from .backend import make_backend, launch_compile_passes, init_schedule +from .patch_fake_tensor import patch_fake_tensor +from .util import get_deepcompile_handle, add_pre_backward_hook + +WARMUP = 5 + + +def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None): + + optimizer = engine.optimizer + use_opt = not isinstance(optimizer, DeepSpeedZeRoOffload) + + if use_opt and hasattr(optimizer, "ipg_buckets"): + optimizer.ipg_buckets.clear() + get_accelerator().empty_cache() + + dc = get_deepcompile_handle() + dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size()) + + # Unset hooks + for m in engine.module.modules(): + m._parameters = m._original_parameters + + if use_opt: + optimizer.parameter_offload._remove_module_hooks() + + for hook in optimizer._grad_acc_hooks: + hook.remove() + optimizer._grad_acc_hooks.clear() + + # Unpatch linear + if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"): + torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk + + if compile_config.symmetric_memory: + group_name = engine.data_parallel_group.group_name + dist.enable_symm_mem_for_group(group_name) + + for p in engine.module.parameters(): + grad_buffer = torch.Tensor() + if use_opt: + grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id] + + # Disable persistent param + p.ds_persist = False + dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist) + + if schedule is None: + schedule = [] + if (compile_config.offload_parameters): + schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd])) + else: + schedule.append((0, [zero3_compile.add_z3_gather_release])) + schedule.append( + (WARMUP, + [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather])) + + init_schedule(schedule) + + if use_opt: + + def set_grad_buffer(): + for i, sub_group in enumerate(optimizer.fp16_groups): + optimizer.averaged_gradients[i] = [ + optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group + ] + + add_pre_backward_hook(set_grad_buffer) + + # offloading opt states need additional setup + from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states + for _, passes in schedule: + if move_opt_states in passes or move_opt_states_sync in passes: + init_offload_opt_states(optimizer, dc) + + engine.launch_compile_passes = launch_compile_passes + + patch_fake_tensor() + torch._inductor.config.size_asserts = False + + return make_backend(backend, compile_config, compile_kwargs=compile_kwargs) diff --git a/deepspeed/compile/input_storage.py b/deepspeed/compile/input_storage.py new file mode 100644 index 000000000000..3866f1afb278 --- /dev/null +++ b/deepspeed/compile/input_storage.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Tuple, Optional +from dataclasses import dataclass + +import torch + + +@dataclass +class TensorMetadata: + """Metadata for a tensor to be stored in CPU memory""" + shape: Tuple[int, ...] + dtype: torch.dtype + device: torch.device + stride: Tuple[int, ...] + storage_offset: int + requires_grad: bool + layout: torch.layout + memory_format: torch.memory_format = torch.contiguous_format + real_data: Optional[torch.Tensor] = None # Store actual tensor data when configured + + +class InputStorage: + """Storage class to keep real inputs in CPU memory with tensor metadata""" + + def __init__(self, keep_int_input_tensors: bool = False, keep_all_input_tensors: bool = False): + self._stored_inputs: Any = None + self._has_data: bool = False + self._keep_int_input_tensors: bool = keep_int_input_tensors + self._keep_all_input_tensors: bool = keep_all_input_tensors + + def _is_int_tensor(self, tensor: torch.Tensor) -> bool: + """Check if tensor has integer dtype""" + return tensor.dtype in [ + torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.uint16, torch.uint32, torch.uint64, + torch.bool + ] + + def _extract_tensor_metadata(self, tensor: torch.Tensor) -> TensorMetadata: + """Extract metadata from a tensor""" + # Get memory format safely + try: + memory_format = tensor.memory_format() if hasattr(tensor, 'memory_format') else torch.contiguous_format + except Exception: + memory_format = torch.contiguous_format + + # Store real data for tensors if configured to do so + real_data = None + if self._keep_all_input_tensors or (self._keep_int_input_tensors and self._is_int_tensor(tensor)): + # Move to CPU to save GPU memory + real_data = tensor.detach().cpu() + + return TensorMetadata(shape=tuple(tensor.shape), + dtype=tensor.dtype, + device=tensor.device, + stride=tuple(tensor.stride()), + storage_offset=tensor.storage_offset(), + requires_grad=tensor.requires_grad, + layout=tensor.layout, + memory_format=memory_format, + real_data=real_data) + + def _store_value(self, value: Any) -> Any: + """ + Recursively store a value, converting tensors to metadata and keeping non-tensors as-is + """ + if isinstance(value, torch.Tensor): + return self._extract_tensor_metadata(value) + elif isinstance(value, (list, tuple)): + stored_items = [self._store_value(item) for item in value] + return type(value)(stored_items) if isinstance(value, tuple) else stored_items + elif isinstance(value, dict): + return {k: self._store_value(v) for k, v in value.items()} + else: + # For non-tensor values (int, float, str, bool, etc.), store as-is + return value + + def _materialize_value(self, stored_value: Any) -> Any: + """ + Recursively materialize a stored value, creating tensors from metadata and keeping non-tensors as-is + """ + if isinstance(stored_value, TensorMetadata): + # If we have real data stored, use it + if stored_value.real_data is not None: + try: + # Use the stored real data + tensor = stored_value.real_data.clone() + + # Set stride if different from default and tensor is contiguous + if tensor.stride() != stored_value.stride and len(stored_value.shape) > 0: + try: + # Create tensor with specific stride + tensor = torch.as_strided(tensor, stored_value.shape, stored_value.stride, + stored_value.storage_offset) + except RuntimeError: + # If stride setting fails, use default stride + pass + + # Move to target device and set requires_grad + tensor = tensor.to(device=stored_value.device) + tensor.requires_grad_(stored_value.requires_grad) + + return tensor + + except Exception as e: + # Fallback to dummy data if real data fails + pass + + # Create a tensor with the stored metadata (original behavior for non-int tensors) + # Use CPU first to avoid GPU memory issues, then move to target device + try: + tensor = torch.empty(stored_value.shape, + dtype=stored_value.dtype, + layout=stored_value.layout, + device='cpu') + + # Fill with dummy data (ones) for profiling purposes + tensor.fill_(1.0) + + # Set stride if different from default and tensor is contiguous + if tensor.stride() != stored_value.stride and len(stored_value.shape) > 0: + try: + # Create tensor with specific stride + tensor = torch.as_strided(tensor, stored_value.shape, stored_value.stride, + stored_value.storage_offset) + except RuntimeError: + # If stride setting fails, use default stride + pass + + # Move to target device and set requires_grad + tensor = tensor.to(device=stored_value.device) + tensor.requires_grad_(stored_value.requires_grad) + + return tensor + + except Exception as e: + # Fallback: create a simple tensor if anything fails + tensor = torch.ones(stored_value.shape, dtype=stored_value.dtype, device=stored_value.device) + tensor.requires_grad_(stored_value.requires_grad) + return tensor + + elif isinstance(stored_value, (list, tuple)): + materialized_items = [self._materialize_value(item) for item in stored_value] + return type(stored_value)(materialized_items) if isinstance(stored_value, tuple) else materialized_items + elif isinstance(stored_value, dict): + return {k: self._materialize_value(v) for k, v in stored_value.items()} + else: + # Non-tensor values are returned as-is + return stored_value + + def put(self, real_inputs: Any) -> None: + """ + Store real inputs + + Args: + real_inputs: The real inputs to store (can be tensors, lists, tuples, etc.) + """ + stored_inputs = self._store_value(real_inputs) + self._stored_inputs = stored_inputs + self._has_data = True + + def get(self) -> Any: + """ + Retrieve and materialize stored real inputs + + Returns: + Materialized real inputs with actual tensors + + Raises: + RuntimeError: If no inputs are stored + """ + if not self._has_data: + raise RuntimeError("No inputs stored in InputStorage") + + return self._materialize_value(self._stored_inputs) + + def has_data(self) -> bool: + """ + Check if storage contains inputs + + Returns: + True if inputs are stored, False otherwise + """ + return self._has_data + + def clear(self) -> None: + """Clear stored inputs""" + self._stored_inputs = None + self._has_data = False diff --git a/deepspeed/compile/list_schedule.py b/deepspeed/compile/list_schedule.py new file mode 100644 index 000000000000..dbccd4b1ebf1 --- /dev/null +++ b/deepspeed/compile/list_schedule.py @@ -0,0 +1,432 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from collections import defaultdict +from typing import List, Dict +from copy import copy +from dataclasses import dataclass + +import torch +from torch.fx import Graph, Node +from torch.fx.node import map_arg + +try: + from torch.utils._pytree import tree_iter +except ImportError: + pass + +from .util import get_last_uses, is_release_node +from .fx import get_output_node + + +def make_graph_from_schedule(scheduled: List[Node]): + new_graph = Graph() + env = {} + for node in scheduled: + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + + return new_graph + + +def get_original_args_num(node: Node): + if node.name.startswith("allgather_ds_param") \ + or node.name.startswith("release_ds_param") \ + or node.name.startswith("wait_allgather_ds_param") \ + or node.name.startswith("reduce_ds_param"): + return 1 + + return len(node.args) + + +def flat_nodes_in_args(args: List[Node]): + return [a for a in tree_iter(args) if isinstance(a, Node)] + + +def filter_args(node: Node): + args = node.args[:get_original_args_num(node)] + return flat_nodes_in_args(args) + + +def init_schedule(graph: Graph): + mem_table = create_mem_table(graph) + remaining_users = defaultdict(set) + user_to_producer = {} + + scheduled = [] + unscheduled = [] + edges = defaultdict(list) + for node in graph.nodes: + filtered_args = filter_args(node) + # print(f"Node: {node} args: {node.args}") + if len(filtered_args) == 0: + scheduled.append(node) + + remaining_users[node] = set(node.users.keys()) + for user in node.users.keys(): + user_to_producer[user] = node + else: + unscheduled.append(node) + for a in filtered_args: + for elem_a in tree_iter(a): + if isinstance(elem_a, Node): + if node not in edges[elem_a]: + edges[elem_a].append(node) + + return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer + + +def get_runnable_nodes(scheduled: List[Node], unscheduled: List[Node]): + scheduled = set(scheduled) + return [node for node in unscheduled if all(arg in scheduled for arg in filter_args(node))] + + +def choose_next_node(scheduled: List[Node], unscheduled: List[Node], mem_table: Dict[str, int]): + runnable_nodes = get_runnable_nodes(scheduled, unscheduled) + + # sort by memory usage + runnable_nodes = sorted(runnable_nodes, key=lambda n: mem_table[n.name]) + return runnable_nodes[0] + + +def create_mem_table(graph: Graph) -> Dict[str, int]: + mem_table = {} + for node in graph.nodes: + if node.name.startswith("allgather_ds_param"): + mem_table[node.name] = node.meta["tensor_size"] + elif node.name.startswith("release_ds_param") or node.name.startswith("reduce_ds_param"): + mem_table[node.name] = -node.meta["tensor_size"] + else: + mem_table[node.name] = 0 + + return mem_table + + +def list_schedule(graph: Graph) -> Graph: + + scheduled, unscheduled, mem_table = init_schedule(graph) + + while len(unscheduled) > 0: + next_node = choose_next_node(scheduled, unscheduled, mem_table) + scheduled.append(next_node) + unscheduled.remove(next_node) + + return make_graph_from_schedule(scheduled) + + +############################### + + +def get_new_runnable_nodes_with(scheduled: List[Node], edges: Dict[Node, List[Node]], new_scheduled: Node): + scheduled = set(scheduled) + new_runnables = [] + for node in edges[new_scheduled]: + if all(arg in scheduled for arg in filter_args(node) if arg != new_scheduled): + new_runnables.append(node) + + return new_runnables + + +def _do_schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]], + non_ag_runnable: List[Node]): + + while len(non_ag_runnable) > 0: + next_node = non_ag_runnable.pop() + + new_runnables = get_new_runnable_nodes_with(scheduled, edges, next_node) + non_ag_runnable += [n for n in new_runnables if not n.name.startswith("allgather_ds_param")] + + scheduled.append(next_node) + unscheduled.remove(next_node) + + return scheduled, unscheduled + + +def schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]]): + runnable = get_runnable_nodes(scheduled, unscheduled) + non_ag_runnable = [n for n in runnable if not n.name.startswith("allgather_ds_param")] + + tmp_scheduled = copy(scheduled) + tmp_unscheduled = copy(unscheduled) + + return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable) + + +def try_schedule_with_new_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]], + new_scheduled: Node): + new_runnables = get_new_runnable_nodes_with(scheduled, edges, new_scheduled) + non_ag_runnable = [n for n in new_runnables if not n.name.startswith("allgather_ds_param")] + + tmp_scheduled = copy(scheduled) + tmp_unscheduled = copy(unscheduled) + + tmp_scheduled.append(new_scheduled) + tmp_unscheduled.remove(new_scheduled) + + return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable) + + +def simple_prefetch(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph: + + scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule(graph) + tmp_scheduled, tmp_unscheduled = schedule_without_allgather(scheduled, unscheduled, edges) + + while len(tmp_unscheduled) > 0: + + runnable = get_runnable_nodes(tmp_scheduled, tmp_unscheduled) + ag_with_unblock_time = [] + + for ag_node in runnable: + ag_scheduled, ag_unscheduled = try_schedule_with_new_allgather(tmp_scheduled, tmp_unscheduled, edges, + ag_node) + unblock_time = sum(n.meta["device_time"] for n in ag_scheduled[len(tmp_scheduled) + 1:]) + ag_with_unblock_time.append((ag_node, unblock_time, ag_scheduled, ag_unscheduled)) + + ag_with_unblock_time = sorted(ag_with_unblock_time, key=lambda x: x[1], reverse=True) + best_ag_node = ag_with_unblock_time[0][0] + best_ag_scheduled = ag_with_unblock_time[0][2] + + no_ag_runnables = tmp_scheduled[len(scheduled):] + after_ag_runnables = best_ag_scheduled[len(tmp_scheduled) + 1:] + + scheduled.append(best_ag_node) + unscheduled.remove(best_ag_node) + for n in no_ag_runnables: + scheduled.append(n) + unscheduled.remove(n) + + tmp_scheduled = copy(scheduled) + tmp_unscheduled = copy(unscheduled) + for n in after_ag_runnables: + tmp_scheduled.append(n) + tmp_unscheduled.remove(n) + + return make_graph_from_schedule(tmp_scheduled) + + +############################### + + +def init_schedule_with_placeholders(graph: Graph): + mem_table = create_mem_table(graph) + remaining_users = defaultdict(set) + user_to_producer = {} + + scheduled = [] + unscheduled = [] + edges = defaultdict(list) + for node in graph.nodes: + if node.op == 'placeholder': + scheduled.append(node) + + remaining_users[node] = set(node.users.keys()) + for user in node.users.keys(): + user_to_producer[user] = node + else: + unscheduled.append(node) + + return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer + + +def get_node_requirements(target_node: Node, scheduled: List[Node]): + scheduled = set(scheduled) + visited = set() + ordered_nodes = [] + + def dfs(node: Node): + if node in scheduled: + return + if node in visited: + return + visited.add(node) + + args = [] + + def register_arg(n: Node): + args.append(n) + + map_arg(node.args, register_arg) + + for arg in args: + dfs(arg) + ordered_nodes.append(node) + + dfs(target_node) + + return ordered_nodes + + +@dataclass +class AllgatherTask: + node: Node + allgather_cost: float + free_cost: float + allgathered_mem: int + allgather_acc_mem: int + free_acc_mem: int + last_use: Node + n_scheduled_ags: int + schedule_until_ag: List[Node] + schedule_until_free: List[Node] + + +def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph: + node_to_last_use, user_to_last_uses = get_last_uses(graph) + + # check tensor size + for node in graph.nodes: + if "tensor_size" not in node.meta: + # Our profiler may not visit all nodes because of the control flow. + node.meta["tensor_size"] = 0 + + scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule_with_placeholders( + graph) + + unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.dc.allgather_param.default] + + release_nodes = defaultdict(list) + for n in unscheduled: + if is_release_node(n): + release_nodes[n.args[2]].append(n) + + ag_nodes_in_path = {} + for ag_node in unscheduled_ags: + last_use = node_to_last_use[ag_node] + required_nodes = get_node_requirements(last_use, scheduled) + ag_nodes_in_path[ag_node] = set(n for n in required_nodes if n.target == torch.ops.dc.allgather_param.default) + + reduce_nodes = [n for n in unscheduled if n.target == torch.ops.dc.reduce_grad.default] + ag_nodes_in_path_to_reduce_nodes = {} + for reduce_node in reduce_nodes: + ag_nodes_in_path_to_reduce_nodes[reduce_node] = set(n for n in get_node_requirements(reduce_node, scheduled) + if n.target == torch.ops.dc.allgather_param.default) + + output_nodes = [ + n for n in get_output_node(graph).args[0] + if isinstance(n, Node) and n.target != torch.ops.dc.reduce_grad.default + ] + ag_nodes_in_path_to_output_nodes = {} + for output_node in output_nodes: + ag_nodes_in_path_to_output_nodes[output_node] = set(n for n in get_node_requirements(output_node, scheduled) + if n.target == torch.ops.dc.allgather_param.default) + + while len(unscheduled_ags) > 0: + + ag_nodes_count = {ag_node: len(nodes) for ag_node, nodes in ag_nodes_in_path.items()} + count_list = sorted(set(ag_nodes_count.values())) + + runnable_ags = [] + for ag_count in count_list: + + target_unscheduled_ags = [ag for ag in unscheduled_ags if ag_nodes_count[ag] == ag_count] + + for node in target_unscheduled_ags: + ds_id = node.args[2] + + schedule_until_ag = get_node_requirements(node, scheduled) + if schedule_until_ag is None: + continue + + last_use = node_to_last_use[node] + + diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag) + + allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag) + free_cost = sum(n.meta["device_time"] for n in diff_required_nodes) + allgathered_mem = node.meta["tensor_size"] + allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag + if n.target == torch.ops.dc.allgather_param.default) + free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes + if n.target == torch.ops.dc.allgather_param.default) + + schedule_until_free = schedule_until_ag + diff_required_nodes + for release_node in release_nodes[ds_id]: + for release_dep_node in get_node_requirements(release_node, scheduled + schedule_until_free): + if release_dep_node not in schedule_until_free: + schedule_until_free.append(release_dep_node) + + n_scheduled_ags = len( + [n for n in schedule_until_free if n.target == torch.ops.dc.allgather_param.default]) + + task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem, + last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free) + + # print(f" ag_count {ag_count} allgather runnable {i}: {node} last_use: {node_to_last_use[node]} t: {t2-t1:.2f}") + runnable_ags.append(task) + + if len(runnable_ags) > 0: + break + + assert len(runnable_ags) > 0, "No runnable allgather nodes" + + # Criteria of the choice: + # We want to choose allgather that does not require additional allgather until releasing the param. + # When we can find such a node, free_acc_mem will be zero. In that case, we choose the one with the smallest cost until free to minimize the period of occupying memory for the gathered param. + # If there is no such node, we choose the one with the smallest free_cost to minimize the period of occupying memory for the gathered param. + ags_with_no_additional_ag = [ag for ag in runnable_ags if ag.free_acc_mem == 0] + if len(ags_with_no_additional_ag) > 0: + sorted_ags = sorted(runnable_ags, key=lambda x: x.free_cost) + next_ag = sorted_ags[0] + nodes_to_schedule = next_ag.schedule_until_free + else: + # sorted_ags = sorted(runnable_ags, key=lambda x: x.allgathered_mem) + sorted_ags = sorted(runnable_ags, key=lambda x: x.free_acc_mem) + next_ag = sorted_ags[0] + nodes_to_schedule = next_ag.schedule_until_ag + + # print(f" next_ag {next_ag}") + for n in nodes_to_schedule: + scheduled.append(n) + unscheduled.remove(n) + + unscheduled_ags.remove(next_ag.node) + + ag_nodes_in_path.pop(next_ag.node) + for ag_node, nodes in ag_nodes_in_path.items(): + if next_ag.node in nodes: + nodes.remove(next_ag.node) + + # Schedule reduce nodes when possible to free memory earlier + reduces_to_schedule = [] + for reduce_node in reduce_nodes: + if next_ag.node in ag_nodes_in_path_to_reduce_nodes[reduce_node]: + ag_nodes_in_path_to_reduce_nodes[reduce_node].remove(next_ag.node) + if len(ag_nodes_in_path_to_reduce_nodes[reduce_node]) == 0: + reduces_to_schedule.append(reduce_node) + + for n in reduces_to_schedule: + need_to_schedule = get_node_requirements(n, scheduled) + for nn in need_to_schedule: + scheduled.append(nn) + unscheduled.remove(nn) + + # Do the same for output nodes + outputs_to_schedule = [] + for output_node in output_nodes: + if next_ag.node in ag_nodes_in_path_to_output_nodes[output_node]: + ag_nodes_in_path_to_output_nodes[output_node].remove(next_ag.node) + if len(ag_nodes_in_path_to_output_nodes[output_node]) == 0: + outputs_to_schedule.append(output_node) + + for n in outputs_to_schedule: + need_to_schedule = get_node_requirements(n, scheduled) + for nn in need_to_schedule: + scheduled.append(nn) + unscheduled.remove(nn) + + # print(f"After ag scheduled: scheduled: {scheduled}") + + scheduled_set = set(scheduled) + for node in graph.nodes: + if node in scheduled_set: + continue + scheduled.append(node) + unscheduled.remove(node) + + assert len(unscheduled) == 0, f"There are unscheduled nodes: {unscheduled}" + + ret_graph = make_graph_from_schedule(scheduled) + ret_graph.lint() + return ret_graph diff --git a/deepspeed/compile/partitioner.py b/deepspeed/compile/partitioner.py new file mode 100644 index 000000000000..9f1345da875a --- /dev/null +++ b/deepspeed/compile/partitioner.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple, List, Set + +import torch +from torch.fx import GraphModule, Graph, Node + +try: + from torch.utils.checkpoint import CheckpointPolicy + from torch._functorch.partitioners import _is_primal +except ImportError: + pass + +from .util import get_no_copy_ops, is_cast_op + + +def _recompute_param_aliases(joint_graph: Graph, param_indices: List[Tuple[int, int, torch.Size]]): + """Recompute nodes aliasing or downcasting any parameter + + In ZeRO3, sharded parameters are gathered before use and the gathered + parameters should be freed once they are no longer needed to save GPU + memory. + + When DeepCompile is active for ZeRO3, parameter gathering is done by custom + passes after the joint graph captured by Dynamo and AOT Autograd is + partitioned into fwd and bwd parts. Since the partitioner has no clue about + parameter sharding now, the partitioned graphs will save for backward all + intermediate activations including those aliasing the gathered parameters. + That essentially nullifies the memory reduction that ZeRO3 is designed to + bring. + + The solution is to recompute the parameter-aliasing activations in the + backward. It is done by marking such nodes as MUST_RECOMPUTE and reusing the + min-cut partitioner originally designed for checkpointing. If autocast is + enabled, parameter downcasts are also recomputed. + + This cannot be converted to a standalone pass because it must be applied + before partitioning the joint graph, but passes run after the partitioning. + + TODO(eternalNight) `min_cut_rematerialization_partition` may recompute more + nodes than required for ZeRO3. Need investigate its performance + implications. + """ + no_copy_ops = get_no_copy_ops() + + def need_recompute(n: Node) -> bool: + if n.op == "call_function": + is_cast, _ = is_cast_op(n) + return n.target in no_copy_ops or is_cast + return False + + primal_inputs = list(filter(_is_primal, joint_graph.nodes)) + ds_param_inputs = set([primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]) + recomputed_nodes = set() + + for node in joint_graph.nodes: + # The `ac_graph_id` tag tracks the checkpoint module that a node belongs + # to, and is for enforcing the saving of activations at the boundary of + # consecutive checkpointed blocks. It starts from 1 and increments by 1 + # each time a graph module is checkpointed. + # + # `min_cut_rematerialization_partition` requires every node to have + # `ac_graph_id`. If this graph is not checkpointed (and thus + # `ac_graph_id` is missing), we tag all nodes to 1 to prevent the + # partition function from modifying the recompute tag. + node.meta.setdefault("ac_graph_id", 1) + + # Arguments can be non-tensor types some of which are not hashable. So + # we must inspect the type of an argument before checking if it is in + # any set. + if need_recompute(node) and \ + any([(isinstance(a, Node) and (a in ds_param_inputs or a in recomputed_nodes)) for a in node.args]): + node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + recomputed_nodes.add(node) + else: + # If checkpointing is not enabled for this graph, assume all + # activations required by the backward pass should be saved. + node.meta.setdefault("recompute", CheckpointPolicy.MUST_SAVE) + + +def get_wrapped_partitioner( + z3_partition: bool, + param_indices: List[Tuple[int, int, torch.Size]], + partition_fn, + frame_id: int, + frames_partitioned: Set[int], +): + + def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *, num_fwd_outputs, + **kwargs) -> Tuple[GraphModule, GraphModule]: + frames_partitioned.add(frame_id) + if z3_partition: + _recompute_param_aliases(joint_module.graph, param_indices) + return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs) + + return partition_recompute_ds_params diff --git a/deepspeed/compile/passes/__init__.py b/deepspeed/compile/passes/__init__.py new file mode 100644 index 000000000000..620e99147647 --- /dev/null +++ b/deepspeed/compile/passes/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..profilers.graph_profile import MemoryProfilingInterpreter + +import deepspeed.comm as dist + + +def run_opt_passes(nz3, + graph_index, + graph_id, + gm, + create_inputs_fn, + opt_passes, + graph_order, + profiling_results, + param_manager, + bwd, + debug_log=False): + profile = profiling_results[graph_id] + rank = dist.get_rank() + + for i, opt_pass in enumerate(opt_passes): + + opt_pass_fn, mem_budget = opt_pass + + graph = opt_pass_fn(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd) + graph.lint() + gm.graph = graph + gm.recompile() + + if debug_log: + print(f"Prefetching enabled for {'bwd' if bwd else 'fwd'} graph_id={graph_id} {graph}") + + mem_prof = MemoryProfilingInterpreter(nz3, gm) + mem_prof.run(*create_inputs_fn()) + if debug_log and rank == 0: + mem_prof.dump(f"mem_prof_r{rank}_{'bwd' if bwd else 'fwd'}_{graph_index}_{graph_id}_pass_{i}.csv") + + mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record] + if bwd: + profile.bwd_mem = mem + else: + profile.fwd_mem = mem + + return gm diff --git a/deepspeed/compile/passes/long_context_checkpointing.py b/deepspeed/compile/passes/long_context_checkpointing.py new file mode 100644 index 000000000000..0f72d94fdf9e --- /dev/null +++ b/deepspeed/compile/passes/long_context_checkpointing.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import inspect +import textwrap +import torch._functorch.partitioners as _partitioners + +# The custom should_ban_recomputation to splice into solve_min_cut. +# All names it references (aten, operator, config, op_types, min_cut_options, +# is_materialized_backwards, get_aten_target, _size_of, fx, torch, +# CheckpointPolicy) are either module-level in torch._functorch.partitioners +# or local variables already in scope when this function executes inside +# solve_min_cut. +_CUSTOM_SHOULD_BAN = """\ +def should_ban_recomputation(node): + \"\"\"Sequence-aware recomputation banning logic\"\"\" + if node.op != "call_function": + return False + if node.target == operator.getitem: + return False + if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: + return True + if config.recompute_views and op_types.is_view(node): + return False + if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: + return False + + must_save_set = [ + aten.convolution, + aten.convolution_backward, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_efficient_attention, + aten._flash_attention_forward, + aten._efficient_attention_forward, + aten.upsample_bilinear2d, + aten.native_dropout, + aten.rand_like, + aten.randn_like, + ] + + if get_aten_target(node) in must_save_set: + return True + + def heuristic(node): + if "val" in node.meta: + if isinstance(node.meta["val"], torch.Tensor) and node.meta["val"].dim() >= 2: + return node.meta["val"].shape[1] >= 4096 + return False + + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): + return False + + if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(node): + if heuristic(node): + return False + return True + + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return False + + if min_cut_options.ban_if_reduction: + input_tensors_size = sum( + _size_of(i) for i in node.args if isinstance(i, fx.Node) + ) + output_size = _size_of(node) + return output_size * 4 < input_tensors_size + return False +""" + + +def register_long_context_checkpointing(): + """Splice the custom should_ban_recomputation into solve_min_cut. + + Uses inspect.getsource to extract solve_min_cut's source, replaces the + original should_ban_recomputation with _CUSTOM_SHOULD_BAN, then execs the + result directly in _partitioners.__dict__. + + The exec'd function's __globals__ is the real partitioners module dict, so + every other nested function (is_fusible, is_materialized_backwards, + can_fuse_into_*, etc.) and every local/closure variable (op_types, + min_cut_options, node_info, config, …) is exactly as in the original — + nothing else changes. + + Backward compatible: if solve_min_cut gains new heuristics in a future + PyTorch version the exec automatically picks them up; only + _CUSTOM_SHOULD_BAN needs to stay in sync with any changes to the + original should_ban_recomputation signature/contract. + """ + src = inspect.getsource(_partitioners.solve_min_cut) + lines = src.split('\n') + + # Locate the original should_ban_recomputation and the function after it. + start = next(i for i, l in enumerate(lines) if l.startswith(' def should_ban_recomputation(')) + end = next(i for i, l in enumerate(lines) if i > start and l.startswith(' def ')) + + # Indent the replacement to the nesting level inside solve_min_cut (4 spaces). + replacement = textwrap.indent(_CUSTOM_SHOULD_BAN, ' ') + + new_src = '\n'.join(lines[:start]) + '\n' + replacement + '\n'.join(lines[end:]) + exec(new_src, _partitioners.__dict__) # redefines _partitioners.solve_min_cut diff --git a/deepspeed/compile/passes/offload_activation.py b/deepspeed/compile/passes/offload_activation.py new file mode 100644 index 000000000000..7443ecd1de90 --- /dev/null +++ b/deepspeed/compile/passes/offload_activation.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Dict, Set, Tuple +import random +from collections import defaultdict + +import torch +from torch.fx import Graph, Node + +from ..fx import get_output_node, move_primals_to_head +from ..graph_param import DSGraphParamManager + +value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict) +used_ids: Set[int] = set() + + +def get_random_id() -> int: + + def _gen(): + # generate random int + return random.randint(10000, 2**31) + + global used_ids + v = _gen() + while v in used_ids: + v = _gen() + used_ids.add(v) + return v + + +def _should_offload(node: Node) -> bool: + if not hasattr(node, "meta"): + return False + if "tensor_meta" not in node.meta: + return False + + return True + + +def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]], + graph_order: List[Tuple[int, bool]], mem_budget: float, + param_manager: DSGraphParamManager) -> Graph: + param_names = set(param_manager.param_names) + + import copy + cl_graph = copy.deepcopy(graph) + cl_graph.erase_node(get_output_node(cl_graph)) + + global value_to_id + for name, node in nodes_to_offload_with_names: + if node.name in param_names: + continue + + if not _should_offload(node): + continue + + val_id = get_random_id() + with graph.inserting_after(node): + offload_node = graph.create_node('call_function', + torch.ops.dc.offload_tensor.default, (node, graph_id, val_id), {}, + name=f"offload_{node.name}_{val_id}") + with graph.inserting_after(offload_node): + wait_node = graph.create_node('call_function', + torch.ops.dc.wait_offload.default, (offload_node, graph_id, val_id), {}, + name=f"wait_copy_{node.name}_{val_id}") + + output_node = get_output_node(graph) + output_node.replace_input_with(node, wait_node) + + value_to_id[graph_id][name] = val_id + + graph = move_primals_to_head(graph) + + graph.lint() + return graph + + +def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], mem_budget: float, + param_manager: DSGraphParamManager) -> Graph: + + graph_value_to_id = value_to_id[graph_id] + name_to_node = {n.name: n for n in graph.nodes} + act_nodes = [name_to_node[n] for n in graph_value_to_id.keys()] + + node_to_first_user = {} + for act in act_nodes: + for node in graph.nodes: + if act in node.args: + node_to_first_user[act] = node + break + + for node in act_nodes: + val_id = graph_value_to_id[node.name] + + with graph.inserting_before(node_to_first_user[node]): + reload_node = graph.create_node('call_function', + torch.ops.dc.reload_tensor.default, (node, graph_id, val_id), {}, + name=f"reload_{node.name}_{val_id}") + with graph.inserting_after(reload_node): + wait_node = graph.create_node('call_function', + torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {}, + name=f"wait_copy_{reload_node.name}_{val_id}") + + # replace all uses of node with wait_node + users = {} + for u in node.users.keys(): + if u != reload_node: + users[u] = (node, wait_node) + for u, (old_in, new_in) in users.items(): + u.replace_input_with(old_in, new_in) + + graph = move_primals_to_head(graph) + graph.lint() + return graph diff --git a/deepspeed/compile/passes/offload_adam_states.py b/deepspeed/compile/passes/offload_adam_states.py new file mode 100644 index 000000000000..21f03a4ad4eb --- /dev/null +++ b/deepspeed/compile/passes/offload_adam_states.py @@ -0,0 +1,549 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy +from typing import List, Tuple + +import torch +from torch.fx import Graph, GraphModule + +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.offload_states import _make_offload_state_key + +try: + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + # Unsupported torch version + pass + +from ..profilers import ProfilingResult +from ..graph_param import DSGraphParamManager +from ..fx import move_primals_to_head + +import deepspeed.comm as dist + +NAME = "offload_adam_states" + + +def print_r0(msg): + if dist.get_rank() == 0: + print(msg) + + +MARGIN = 0.2 + +copy_stream = None +offload_event = None +reload_event = None + +offload_key_events = {} +reload_key_events = {} + +max_memory = 0 + + +def lazy_init(): + global copy_stream + global offload_event + global reload_event + + if copy_stream is None: + + copy_stream = get_accelerator().Stream() + offload_event = get_accelerator().Event() + reload_event = get_accelerator().Event() + + +optimizer = None +device = None +nz3 = None + + +def move_key(state, key, key_event=None): + offload_buf_key = _make_offload_state_key(key) + if offload_buf_key not in state: + state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu")) + + if key not in state: + return + + with get_accelerator().stream(copy_stream): + state[offload_buf_key].copy_(state[key], non_blocking=True) + + if key_event is None: + offload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_back_key(state, key, key_event=None): + + with get_accelerator().stream(copy_stream): + state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device) + state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True) + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_hp_param(src_tensor, dest_buf, key_event=None): + with get_accelerator().stream(copy_stream): + dest_buf.copy_(src_tensor, non_blocking=True) + src_tensor.data = dest_buf + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_back_hp_param(src_tensor, dest_buf, key_event=None): + with get_accelerator().stream(copy_stream): + dest_buf.data = torch.empty_like(src_tensor, device=device) + dest_buf.copy_(src_tensor, non_blocking=True) + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def offload_adam_states_sync(): + + with unset_fake_temporarily(): + + if not hasattr(optimizer, "hp_params_pin_buffers"): + optimizer.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device="cpu")) + for t in optimizer.fp32_partitioned_groups_flat + ] + + for i, (k, state) in enumerate(optimizer.state.items()): + if "exp_avg" in state: + move_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_key(state, "exp_avg_sq") + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + del state["exp_avg"] + if "exp_avg_sq" in state: + del state["exp_avg_sq"] + + for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers): + move_hp_param(src_tensor, dest_buf) + + get_accelerator().synchronize() + + +def reload_adam_states_sync(): + + with unset_fake_temporarily(): + # print_r0("Reloading Adam states") + + for _, state in optimizer.state.items(): + if _make_offload_state_key("exp_avg") in state: + move_back_key(state, "exp_avg") + if _make_offload_state_key("exp_avg_sq") in state: + move_back_key(state, "exp_avg_sq") + + for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat): + move_back_hp_param(src, dest) + + get_accelerator().synchronize() + + +def sync_offload_states(event=None): + if nz3.is_profiling(): + offload_adam_states_sync() + else: + if event is None: + offload_event.wait(copy_stream) + else: + event.wait(copy_stream) + + +def sync_reload_states(event=None): + if nz3.is_profiling(): + reload_adam_states_sync() + else: + if event is None: + reload_event.wait(copy_stream) + else: + event.wait(copy_stream) + + +def make_offload_task(task): + + def run_offload_task(): + # if not nz3.is_profiling(): + # print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}") + + if offload_key_events.get(task[1]) is None: + offload_key_events[task[1]] = get_accelerator().Event() + + if task[2] == "hp_param": + move_hp_param(task[1][0], task[1][1], offload_key_events[task[1][0]]) + else: + assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer" + state = optimizer.state[task[1]] + # if offload_key_events.get(task[1]) is None: + # offload_key_events[task[1]] = get_accelerator().Event() + move_key(state, task[2], offload_key_events[task[1]]) + + return run_offload_task + + +def make_offload_sync(task): + + def run_offload_sync(): + # if not nz3.is_profiling(): + event = offload_key_events[task[1]] + event.synchronize() + + if task[2] != "hp_param": + state = optimizer.state[task[1]] + key = task[2] + if key in state: + del state[key] + # print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}") + + return run_offload_sync + + +def make_reload_task(task): + + def run_reload_task(): + if not nz3.is_profiling(): + if reload_key_events.get(task[1]) is None: + reload_key_events[task[1]] = get_accelerator().Event() + + if task[2] == "hp_param": + move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]]) + else: + state = optimizer.state[task[1]] + # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}") + move_back_key(state, task[2], reload_key_events[task[1]]) + + return run_reload_task + + +def update_max_memory(name): + + global max_memory + mem = get_accelerator().max_memory_allocated() + max_memory = max(max_memory, mem) + + +def empty_cache(): + get_accelerator().empty_cache() + + +offload_tasks = [] +offload_tasks_remaining = [] +offload_tasks_scheduled = [] +reload_task_remaining = [] +total_reload_mem = 0 + + +def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], + profiling_results: ProfilingResult, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> Graph: + + to_remove = [] + for node in graph.nodes: + if node.op == 'call_function' and \ + node.target in [offload_adam_states_sync, sync_offload_states, reload_adam_states_sync, sync_reload_states, update_max_memory]: + to_remove.append(node) + + for node in to_remove: + graph.erase_node(node) + + accelerator = get_accelerator() + total_mem = accelerator.total_memory() * (1 - MARGIN) + print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}") + + mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem + mem_dict = {name: peak for name, alloc_mem, delta, peak in mem} + + current_peak_mem = 0 + peak_mem = {} + + ordered_node = reversed(graph.nodes) if bwd else graph.nodes + for node in ordered_node: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if mem_dict[node.name] > current_peak_mem: + current_peak_mem = mem_dict[node.name] + peak_mem[node.name] = current_peak_mem + + # fwd_max_mem = max(m[3] for m in prof.fwd_mem) + # bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0 + # peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem) + + global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled + + if not bwd: + is_first_graph = graph_id == graph_order[0][0] + # print_r0( + # f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}" + # ) + + # At the beginning of the first graph, we schedule offload tasks to launch all offloading + if is_first_graph: + # print_r0( + # f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}" + # ) + + with unset_fake_temporarily(): + offload_adam_states_sync() + reload_adam_states_sync() + sync_reload_states() + + reload_size = 0 + + for i, ((k, state), hp_param, hp_param_cpu) in enumerate( + zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat, + optimizer.hp_params_pin_buffers)): + # print_r0( + # f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}") + + if _make_offload_state_key("exp_avg") in state: + key = _make_offload_state_key("exp_avg") + size = state[key].numel() * state[key].element_size() + + # if total_mem < max_memory + reload_size + size: + offload_tasks.append( + (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype)) + # print_r0( + # f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}" + # ) + + if _make_offload_state_key("exp_avg_sq") in state: + key = _make_offload_state_key("exp_avg_sq") + size = state[key].numel() * state[key].element_size() + + # if total_mem < max_memory + reload_size + size: + offload_tasks.append( + (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype)) + # print_r0( + # f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}" + # ) + + hp_param_size = hp_param.numel() * hp_param.element_size() + # if total_mem < max_memory + reload_size + hp_param_size: + offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param", + hp_param.numel() * hp_param.element_size(), hp_param.dtype)) + # print_r0( + # f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}" + # ) + + # print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") + + for node in graph.nodes: + # print_r0(f"checking sync node insert node: {node.name}") + + if node.name not in peak_mem \ + or node.op == 'placeholder' \ + or "offload_opt_" in node.name: + continue + + to_offload = [] + optim_size = sum([task[3] for task in offload_tasks]) + + # print_r0( + # f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}" + # ) + while total_mem - peak_mem[node.name] - optim_size < 0: + if len(offload_tasks) == 0: + break + + task = offload_tasks.pop(0) + to_offload.append(task) + optim_size = sum([task[3] for task in offload_tasks]) + # print_r0( + # f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}" + # ) + + for task in to_offload: + with graph.inserting_before(node): + graph.create_node('call_function', + make_offload_sync(task), (), {}, + name=f"offload_opt_sync_{task[0]}_{task[2]}") + print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}") + offload_tasks_scheduled.append(task) + + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op != 'placeholder': + print_r0(f"Inserting all offload tasks before {node.name}") + for task in offload_tasks_scheduled: + name = f"offload_opt_{task[0]}_{task[2]}" + with graph.inserting_before(node): + offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name) + break + + # print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}") + print_r0(f"offload_opt_states_inc finish graph {graph_id}") + else: + + graph_order_with_backward = [g[0] for g in graph_order if g[1]] + is_first_graph = graph_id == graph_order_with_backward[-1] + is_last_graph = graph_id == graph_order_with_backward[0] + + # print_r0( + # f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}" + # ) + + if is_first_graph: + inserted_sync = False + for node in graph.nodes: + if node.op != 'placeholder' and not inserted_sync: + # print(f"Inserting offload_sync before {node.name}") + with graph.inserting_before(node): + graph.create_node('call_function', empty_cache, (), {}, name="empty_cache") + + inserted_sync = True + reload_tasks_remaining = copy.copy(offload_tasks_scheduled) + + global total_reload_mem + for node in graph.nodes: + if node.name not in peak_mem \ + or node.op == 'placeholder' \ + or node.op == 'output' \ + or "offload_opt_sync_" in node.name: + continue + + if len(reload_tasks_remaining) > 0: + task = reload_tasks_remaining[0] + next_reload_mem = task[3] + + insert_pos = node + while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem: + expected_mem = peak_mem[node.name] + total_reload_mem + print_r0( + f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}" + ) + + with graph.inserting_after(insert_pos): + insert_pos = graph.create_node('call_function', + make_reload_task(task), (), {}, + name=f"reload_opt_{task[0]}_{task[2]}") + + total_reload_mem += next_reload_mem + reload_tasks_remaining.pop(0) + if len(reload_tasks_remaining) == 0: + break + + task = reload_tasks_remaining[0] + next_reload_mem = task[3] + + # prev_node = node + + if is_last_graph: + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op == 'output': + for task in reload_tasks_remaining: + with graph.inserting_before(node): + graph.create_node('call_function', + make_reload_task(task), (), {}, + name=f"reload_opt_{task[0]}_{task[2]}") + + sync_fn = lambda: copy_stream.synchronize() + with graph.inserting_before(node): + graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream") + + print_r0( + f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}" + ) + + return graph + + +def add_record_max_mem_nodes(graph: Graph): + + nodes = list(graph.nodes) + for node in nodes: + if node.op == "output" or node.op == "placeholder": + continue + + with graph.inserting_after(node): + name = f"update_max_memory_{node.name}" + graph.create_node('call_function', update_max_memory, (name, ), {}, name=name) + + +def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], + profiling_results: ProfilingResult, mem_budget: float, + param_manager: DSGraphParamManager, bwd: bool) -> Graph: + + if bwd: + graph_order_with_backward = [g[0] for g in graph_order if g[1]] + is_last_graph = graph_id == graph_order_with_backward[0] + + inserted_reload = False + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op == 'output' and not inserted_reload and is_last_graph: + # print(f"Inserting reload_opt before {node.name}") + with graph.inserting_before(node): + graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt") + inserted_reload = True + + # add_record_max_mem_nodes(graph) + + else: + is_first_graph = graph_id == graph_order[0][0] + + graph = move_primals_to_head(graph) + + inserted_offload = False + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op != 'placeholder' and not inserted_offload and is_first_graph: + print(f"Inserting offload_opt before {node.name}") + with graph.inserting_before(node): + graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt") + inserted_offload = True + + add_record_max_mem_nodes(graph) + + return graph + + +def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, + bwd) + return gm + + +def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> GraphModule: + gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, + bwd) + return gm + + +def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], + profiling_results, create_inputs_fn, mem_budget: float, + param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + if not bwd and graph_id == graph_order[0][0]: + with unset_fake_temporarily(): + offload_adam_states_sync() + # returns None, and profiling will be skipped + + +def init_offload_opt_states(adam_optimizer, _nz3): + lazy_init() + + global optimizer + optimizer = adam_optimizer + global device + device = torch.device(get_accelerator().current_device()) + global nz3 + nz3 = _nz3 diff --git a/deepspeed/compile/passes/offload_parameters.py b/deepspeed/compile/passes/offload_parameters.py new file mode 100644 index 000000000000..a8922f9eae25 --- /dev/null +++ b/deepspeed/compile/passes/offload_parameters.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple + +import torch +from torch.fx import Node, GraphModule +from deepspeed.compile.util import get_last_uses +from ..graph_param import DSGraphParamManager + + +def add_offload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int): + new_node = None + with gm.graph.inserting_after(node): + args = (node, ) + for a in [graph_id, ds_id]: # To add ds_id + args += (a, ) + new_node = gm.graph.create_node('call_function', + torch.ops.dc.offload_parameter.default, + args, {}, + name="offload_parameter") + + return new_node + + +def add_reload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int): + new_node = None + with gm.graph.inserting_after(node): + args = (node, ) + for a in [graph_id, ds_id]: # To add ds_id + args += (a, ) + new_node = gm.graph.create_node('call_function', + torch.ops.dc.reload_parameter.default, + args, {}, + name="reload_parameter") + return new_node + + +def get_ds_id(node: Node): + assert node.target == torch.ops.dc.allgather_param.default + return node.args[2] + + +def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> GraphModule: + node_to_last_use, user_to_last_uses = get_last_uses(gm.graph) + for node in gm.graph.nodes: + if (isinstance(node, Node) and node.target == torch.ops.dc.allgather_param.default): + add_reload_parameter(graph_id, gm, node.args[0], get_ds_id(node)) + add_offload_parameter(graph_id, gm, node_to_last_use[node], get_ds_id(node)) + gm.graph.lint() + return gm diff --git a/deepspeed/compile/passes/prefetch.py b/deepspeed/compile/passes/prefetch.py new file mode 100644 index 000000000000..29fd1ebadd74 --- /dev/null +++ b/deepspeed/compile/passes/prefetch.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple + +import torch +from torch.fx import Graph, Node, GraphModule + +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +from ..profilers.comm_profile import create_predictor +from ..graph_param import DSGraphParamManager + +NAME = "prefetch" + +FUSE_FACTOR = 0.8 +MARGIN = 0.1 +MAX_FUSE_SIZE = 1e9 +MAX_BUFFERED_SIZE = 4e9 + +run_prefetch_pass = False + + +def print_rank_0(message): + if dist.get_rank() == 0: + print(message) + + +def get_ds_id(node: Node): + assert node.target == torch.ops.dc.allgather_param.default + return node.args[2] + + +def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> GraphModule: + + max_mem = get_accelerator().total_memory() * (1 - MARGIN) + vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device())) + dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN) + max_mem = vals_to_bcast[0].item() + + mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem + op_time = profiling_results[graph_id].bwd_time if bwd else profiling_results[graph_id].fwd_time + tensor_sizes = profiling_results[graph_id].bwd_tensor_sizes if bwd else profiling_results[graph_id].fwd_tensor_sizes + + mem_dict = {name: (alloc_mem, peak) for name, alloc_mem, delta, peak in mem} + time_dict = {name: (device_time, wall_time) for name, device_time, wall_time in op_time} + tensor_size_dict = {name: size for name, size in tensor_sizes} + + graph = gm.graph + total_param_size = sum( + [tensor_size_dict[n.name] for n in graph.nodes if n.target == torch.ops.dc.allgather_param.default]) + + print_rank_0( + f"schedule_prefetch graph_id={graph_id} max_mem={max_mem} available_memory={get_accelerator().available_memory()} memory_allocated={get_accelerator().memory_allocated()} max_allocated={get_accelerator().max_memory_allocated()} total_param_size={total_param_size} margin={MARGIN}" + ) + + # Fill missing values + prev_mem = 0 + prev_peak = 0 + for node in graph.nodes: + if node.name in mem_dict: + prev_mem = mem_dict[node.name][0] + prev_peak = mem_dict[node.name][1] + else: + print_rank_0(f"node {node.name} not in mem_dict") + mem_dict[node.name] = (prev_mem, prev_peak) + + comm_predictor = create_predictor() + + order_rev = list(reversed(graph.nodes)) + new_order_rev = [] + prefetch_ags = [] + prefetch_ag_groups = [] + ag_tensor_size_sum = 0 + for i, node in enumerate(order_rev): + # print_rank_0( + # f"Checking node reverse order {node.name} {node.target} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}" + # ) + + if node.op != "placeholder": + assert i < len(order_rev) - 1 + assert node.name in mem_dict + next_node = order_rev[i + 1] + next_alloc_mem, next_peak = mem_dict[next_node.name] + + # Free up memory + while next_peak + ag_tensor_size_sum > max_mem or ag_tensor_size_sum > MAX_BUFFERED_SIZE: + if len(prefetch_ag_groups) > 0: + # launch prefetch + fused_ag_nodes = prefetch_ag_groups.pop(0) + total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in fused_ag_nodes]) + ag_tensor_size_sum -= total_ag_tensor_size + new_order_rev.append(fused_ag_nodes) + assert len(fused_ag_nodes) > 0 + # print_rank_0( + # f"Free up memory fused_ag_nodes={fused_ag_nodes} next_alloc_mem={next_alloc_mem} total_ag_tensor_size={total_ag_tensor_size} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}" + # ) + elif len(prefetch_ags) > 0: + prefetch_ag_groups.append(prefetch_ags) + prefetch_ags = [] + # print_rank_0( + # f"Free up memory prefetch_ags={prefetch_ag_groups} next_alloc_mem={next_alloc_mem} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}" + # ) + else: + break + + if node.target == torch.ops.dc.allgather_param.default: + + current_ag_size = sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags]) + pred_time_current = comm_predictor(current_ag_size) + pred_time_next = comm_predictor(tensor_size_dict[node.name]) + pred_time_fused = comm_predictor(current_ag_size + tensor_size_dict[node.name]) + + do_fuse = max(pred_time_current, pred_time_next) * 1.2 > pred_time_fused and ( + current_ag_size + tensor_size_dict[node.name]) < MAX_FUSE_SIZE + # print_rank_0( + # f"found allgather_param do_fuse={do_fuse} current_ag_size={current_ag_size} tensor_size_dict[node.name]={tensor_size_dict[node.name]} pred_time_current={pred_time_current} pred_time_next={pred_time_next} pred_time_fused={pred_time_fused}" + # ) + + if len(prefetch_ags) > 0 and not do_fuse: + # stop fusing here + prefetch_ag_groups.append(prefetch_ags) + prefetch_ags = [] + # print_rank_0( + # f"stop fusing prefetch_ags={prefetch_ag_groups} ag_tensor_size_sum={ag_tensor_size_sum}") + # else: + # print_rank_0( + # f"continue fusing ag_tensor_size_sum={ag_tensor_size_sum} ag_size={tensor_size_dict[node.name]} prefetch_ags={prefetch_ags} prefetch_ag_groups={prefetch_ag_groups}" + # ) + prefetch_ags.append(node) + ag_tensor_size_sum += tensor_size_dict[node.name] + + new_order_rev.append(node) + + if (node.op != "placeholder" + and node.target != torch.ops.dc.reload_parameter) and order_rev[i + 1].op == "placeholder": + for ag_group in prefetch_ag_groups: + assert len(ag_group) > 0 + new_order_rev.append(ag_group) + total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in ag_group]) + ag_tensor_size_sum -= total_ag_tensor_size + if len(prefetch_ags) > 0: + new_order_rev.append(prefetch_ags) + ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags]) + assert ag_tensor_size_sum == 0 + + # print_rank_0( + # f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}" + # ) + + assert ag_tensor_size_sum >= 0 + + new_graph = Graph() + env = {} + for node in reversed(new_order_rev): + if isinstance(node, Node): + #print(f"reconstruct {node.name} {node.target}") + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + else: + param_nodes = [ag_node.args[0] for ag_node in node] + param_nodes_copy = [env[param_node.name] for param_node in param_nodes] + + ds_ids = [get_ds_id(ag_node) for ag_node in node] + new_graph.call_function(torch.ops.dc.prefetch_params_fused.default, + args=(graph_id, param_nodes_copy, ds_ids)) + new_graph.lint() + gm.graph = new_graph + + return gm diff --git a/deepspeed/compile/passes/selective_gather.py b/deepspeed/compile/passes/selective_gather.py new file mode 100644 index 000000000000..0b262984ae05 --- /dev/null +++ b/deepspeed/compile/passes/selective_gather.py @@ -0,0 +1,212 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from collections import defaultdict +from typing import Dict, List, Tuple + +import torch +from torch.fx import GraphModule + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import log_dist + +from ..util import get_deepcompile_handle +from ..graph_param import DSGraphParamManager + +NAME = "selective_gather" + +max_alloc_mem = 0 +last_optimize_step = 0 +MEM_MARGIN = 0.1 + + +def print_rank_0(message): + log_dist(message, ranks=[0]) + + +def _compute_persistence_budget(all_graph_mem_records: List[List[Tuple[str, int, int, int]]], total_mem: int, + mem_margin: float) -> Dict[str, int]: + usable_mem = int(total_mem * (1 - mem_margin)) + non_empty_records = [mem_records for mem_records in all_graph_mem_records if mem_records] + + if not non_empty_records: + return { + "usable_mem": usable_mem, + "peak_resident_alloc": 0, + "transient_peak": 0, + "available_mem": 0, + "profiled_list_count": 0, + } + + # Persistent parameters add to live allocations that remain resident past an op boundary. + peak_resident_alloc = max(record[1] for mem_records in non_empty_records for record in mem_records) + transient_peak = max(record[3] for mem_records in non_empty_records for record in mem_records) + + return { + "usable_mem": usable_mem, + "peak_resident_alloc": peak_resident_alloc, + "transient_peak": transient_peak, + "available_mem": max(0, usable_mem - peak_resident_alloc), + "profiled_list_count": len(non_empty_records), + } + + +def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> GraphModule: + target_graph_id = graph_id + + if not bwd: + return gm + + last_backward_graph_id = None + for g_id, needs_bwd in graph_order: + if needs_bwd: + last_backward_graph_id = g_id + break + + # Run only on the last backward graph + if last_backward_graph_id is None or graph_id != last_backward_graph_id: + return gm + + all_graph_mem_records = [] + for profile_graph_id, prof in profiling_results.items(): + all_graph_mem_records.extend([prof.fwd_mem, prof.bwd_mem]) + + fwd_peak_resident = max((m[1] for m in prof.fwd_mem), default=0) + fwd_transient_peak = max((m[3] for m in prof.fwd_mem), default=0) + bwd_peak_resident = max((m[1] for m in prof.bwd_mem), default=0) + bwd_transient_peak = max((m[3] for m in prof.bwd_mem), default=0) + + print_rank_0(f"selective_gather graph_id={profile_graph_id} " + f"fwd_peak_resident={fwd_peak_resident} fwd_transient_peak={fwd_transient_peak} " + f"bwd_peak_resident={bwd_peak_resident} bwd_transient_peak={bwd_transient_peak}") + + persistent_ds_ids = set() + for param_graph_id, pm in param_manager.items(): + for name, ds_param in pm.params.items(): + if ds_param.param.ds_persist: + persistent_ds_ids.add(pm.ds_ids[name]) + + ds_id_to_size = {} + ds_id_to_time = defaultdict(float) + ds_id_to_prof_dtime = defaultdict(float) + ds_id_to_prof_wtime = defaultdict(float) + + for param_graph_id, pm in param_manager.items(): + params = pm.params + for param_name, param in params.items(): + ds_id = pm.ds_ids[param_name] + ds_id_to_size[ds_id] = param.numel * param.dtype.itemsize + + profile = profiling_results[param_graph_id] + for n in profile.fwd_graph.nodes: + if n.target == torch.ops.dc.allgather_param.default: + assert "tensor_size" in n.meta + ds_id_to_size[n.args[2]] = n.meta["tensor_size"] + assert "device_time" in n.meta + ds_id_to_time[n.args[2]] += n.meta["device_time"] + + ds_id_to_prof_dtime[n.args[2]] = n.meta["device_time"] + ds_id_to_prof_wtime[n.args[2]] = n.meta["wall_time"] + + if profile.bwd_graph is not None: + for n in profile.bwd_graph.nodes: + if n.target == torch.ops.dc.allgather_param.default: + assert "tensor_size" in n.meta + ds_id_to_size[n.args[2]] = n.meta["tensor_size"] + assert "device_time" in n.meta + ds_id_to_time[n.args[2]] += n.meta["device_time"] + + ds_ids = [ds_id for ds_id in ds_id_to_size if ds_id not in persistent_ds_ids] + ds_ids.sort(key=lambda ds_id: ds_id_to_time[ds_id] / ds_id_to_size[ds_id], reverse=True) + + # print(f"ds_id_to_size={ds_id_to_size}") + # print(f"ds_id_to_time={ds_id_to_time}") + + # if dist.get_rank() == 0: + # for ds_id in ds_ids: + # dtime_in_sec = ds_id_to_prof_dtime[ds_id] + # wtime_in_sec = ds_id_to_prof_wtime[ds_id] + # size_in_mb = ds_id_to_size[ds_id] / 1024 / 1024 + # print( + # f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s" + # ) + + accelerator = get_accelerator() + total_mem = accelerator.total_memory() + current_available_mem = accelerator.available_memory() + vals_to_bcast = torch.tensor([total_mem, current_available_mem], + device=torch.device(get_accelerator().current_device())) + dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN) + total_mem = vals_to_bcast[0].item() + current_available_mem = vals_to_bcast[1].item() + + budget = _compute_persistence_budget(all_graph_mem_records, total_mem, MEM_MARGIN) + available_mem = int(current_available_mem * (1 - MEM_MARGIN)) + + ds_id_to_param = {} + for g_id, g_pm in param_manager.items(): + for name, ds_param in g_pm.params.items(): + ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param + + candidate_bytes = sum(ds_id_to_size[ds_id] for ds_id in ds_ids) + persistent_bytes = sum(ds_id_to_size.get(ds_id, 0) for ds_id in persistent_ds_ids) + + print_rank_0( + f"selective_gather target_graph_id={target_graph_id} profiled_mem_lists={budget['profiled_list_count']} " + f"total_mem={total_mem} usable_mem={budget['usable_mem']} peak_resident_alloc={budget['peak_resident_alloc']} " + f"transient_peak={budget['transient_peak']} current_available_mem={current_available_mem} " + f"usable_available_mem={available_mem} " + f"persistent_count={len(persistent_ds_ids)} persistent_bytes={persistent_bytes} " + f"candidate_count={len(ds_ids)} candidate_bytes={candidate_bytes}") + + if budget["profiled_list_count"] == 0: + print_rank_0("selective_gather no profiling data; skipping persistence update") + return gm + + if len(ds_ids) == 0: + print_rank_0("selective_gather no candidates to persist") + return gm + + if available_mem == 0: + print_rank_0("selective_gather no currently available memory for new persistent params") + return gm + + persistent_mem = 0 + selected_count = 0 + nz3 = get_deepcompile_handle() + for ds_id in ds_ids: + size = ds_id_to_size[ds_id] + if persistent_mem + size > available_mem: + break + persistent_mem += size + selected_count += 1 + + param_obj = ds_id_to_param[ds_id] + + nz3.set_persistent(ds_id) + print_rank_0( + f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}") + + if selected_count == 0: + smallest_candidate = min(ds_id_to_size[ds_id] for ds_id in ds_ids) + print_rank_0(f"selective_gather selected no new params: available_mem={available_mem} " + f"smallest_candidate={smallest_candidate}") + else: + print_rank_0(f"selective_gather selected_count={selected_count} selected_bytes={persistent_mem}") + + return gm + + +# def make_selective_gather(z3_optimizer, nz3): + +# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, +# mem_budget: float, param_manager, bwd: bool) -> Graph: +# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd, +# z3_optimizer, nz3) + +# return selective_gather_wrapper diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py new file mode 100644 index 000000000000..ab2b3fb9fa33 --- /dev/null +++ b/deepspeed/compile/passes/sp_compile.py @@ -0,0 +1,280 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import operator +from typing import Optional, List, Callable + +import torch +import deepspeed.comm as dist +from torch._subclasses.fake_tensor import FakeTensorMode, maybe_get_fake_mode +from torch.fx import GraphModule, Node +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.experimental.symbolic_shapes import ShapeEnv + +from deepspeed.compile import constants + +from ..custom_ops import all_to_all, sp_dp_registry # noqa: F401 +from ..fx import find_node_by_name, get_node_shape_meta +from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes + + +def prepare_autosp_inputs(input_id: torch.Tensor, + label_id: torch.Tensor, + position_id: torch.Tensor = None, + attention_mask: torch.Tensor = None, + seq_dim: int = 1): + """ + Prepare inputs for AutoSP by marking dynamic dimensions and tagging tensors. + + Args: + input_id: Token IDs tensor (required) + label_id: Label IDs tensor (required) + position_id: Position IDs tensor (optional) + attention_mask: Attention mask tensor (optional) + seq_dim: Sequence dimension index to mark as dynamic (default: 1) + """ + + if input_id is None: + raise ValueError("input_id is required") + if label_id is None: + raise ValueError("label_id is required") + + if seq_dim < 0 or seq_dim >= input_id.ndim: + raise ValueError(f"seq_dim {seq_dim} must be a valid index for input_id with shape {input_id.shape}") + + if position_id is not None: + if seq_dim >= position_id.ndim: + raise ValueError(f"seq_dim {seq_dim} is out of bounds for position_id with shape {position_id.shape}") + + if attention_mask is not None: + if seq_dim >= attention_mask.ndim: + raise ValueError( + f"seq_dim {seq_dim} is out of bounds for attention_mask with shape {attention_mask.shape}") + + torch._dynamo.decorators.mark_dynamic(input_id, seq_dim) + torch._dynamo.decorators.mark_dynamic(label_id, seq_dim) + if position_id is not None: + torch._dynamo.decorators.mark_dynamic(position_id, seq_dim) + if attention_mask is not None: + torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim) + + input_id.tag = constants.AUTOSP_INPUT_ID_KEY + label_id.tag = constants.AUTOSP_LABEL_ID_KEY + if position_id is not None: + position_id.tag = constants.AUTOSP_POSITION_ID_KEY + + return input_id, label_id, position_id, attention_mask + + +def pass_shard_seq_dim(gm: GraphModule, example_inputs): + """ + Finds all direct and indirect consumers of the input sequence, label and position ids. + Shard the sequence dimension used by all such consumers. + """ + sp_size = sp_dp_registry.sp_size() + + input_ids_node = get_input_id_node(gm) + val = get_node_shape_meta(input_ids_node) + seq_symint = val.shape[1] + assert isinstance( + seq_symint, + torch.SymInt), f"expected sequence dimension to be of type {torch.SymInt!r} but found {type(seq_symint)!r}" + + sym_seq_dim_node = find_node_by_name(gm, str(seq_symint)) + if sym_seq_dim_node is None: + print(f"WARNING: Could not find the symbolic node for the sequence dimension") + return + + with gm.graph.inserting_after(sym_seq_dim_node): + sharded_node = gm.graph.call_function(operator.floordiv, args=(sym_seq_dim_node, sp_size)) + + sharded_input_nodes = set() + label_ids_node = get_label_id_node(gm) + position_ids_node = get_position_id_node(gm) + + if input_ids_node is not None: + sharded_input_nodes.add(input_ids_node) + if label_ids_node is not None: + sharded_input_nodes.add(label_ids_node) + if position_ids_node is not None: + sharded_input_nodes.add(position_ids_node) + + # find all consumers of the sharded inputs + consumer_nodes = set() + worklist = list(sharded_input_nodes) + visited = set() + + while worklist: + node = worklist.pop(0) + if node in visited: + continue + visited.add(node) + consumer_nodes.add(node) + + for user in node.users: + if user not in visited: + worklist.append(user) + + to_replace = [] + for node in consumer_nodes: + if sym_seq_dim_node in node.all_input_nodes: + to_replace.append(node) + + for user in to_replace: + user.replace_input_with(sym_seq_dim_node, sharded_node) + + +def pass_shard_input_ids(gm: GraphModule, example_inputs): + input_ids_node = get_input_id_node(gm) + shard_tensor_node(gm, input_ids_node) + + +def pass_shard_label_ids(gm: GraphModule, example_inputs): + label_ids_node = get_label_id_node(gm) + shard_tensor_node(gm, label_ids_node) + + +def pass_shard_position_ids(gm: GraphModule, example_inputs): + position_ids_node = get_position_id_node(gm) + if position_ids_node is None: + print("[WARNING] position id node not found. Skipping sharding of position ids.") + return + shard_tensor_node(gm, position_ids_node) + + +def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs): + + def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node: + with gm.graph.inserting_after(node): + a2a_node = gm.graph.call_function( + torch.ops.autosp.all_to_all.default, + args=(node, scatter_idx, gather_idx, name), + ) + a2a_node.name = f"a2a_{name}" + node.replace_all_uses_with(a2a_node) + a2a_node.update_arg(0, node) + return a2a_node + + attention_nodes = get_sdpa_nodes(gm) + if len(attention_nodes) == 0: + raise RuntimeError("AutoSP currently supports torch.nn.functional.scaled_dot_product_attention as the " + "attention backend. No SDPA attention operations were found in the compiled graph. " + "Please ensure your model uses torch.nn.functional.scaled_dot_product_attention " + "for AutoSP to work as expected.") + + for idx, attn_node in enumerate(attention_nodes): + q, k, v = attn_node.args[:3] + suffix = f"_{idx}" if len(attention_nodes) > 1 else "" + + # QKV: [B, N, S/P, H] -> [B, N/P, S, H] + insert_a2a(q, scatter_idx=1, gather_idx=2, name=f"q{suffix}") + insert_a2a(k, scatter_idx=1, gather_idx=2, name=f"k{suffix}") + insert_a2a(v, scatter_idx=1, gather_idx=2, name=f"v{suffix}") + + # O: [B, N/P, S, H] -> [B, N, S/P, H] + insert_a2a(attn_node, scatter_idx=2, gather_idx=1, name=f"o{suffix}") + + +def pass_canonicalize(gm: GraphModule, real_inputs): + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs): + fake_mode = None + for node in gm.graph.nodes: + # Reuse the graph's existing fake mode when metadata is already present. + # Its ShapeEnv owns the symbolic dims captured during tracing, so using a + # fresh mode here can desynchronize fake inputs from graph metadata. + if node.op == "placeholder" and "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_mode = maybe_get_fake_mode(fake_val) + elif fake_mode is None: + fake_val = node.meta.get("example_value", node.meta.get("val")) + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_mode = maybe_get_fake_mode(fake_val) + if fake_mode is not None: + break + + if fake_mode is None: + # Some graphs do not carry fake tensor metadata yet; create a fallback + # mode so FakeTensorProp can still run shape-only execution. + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + + fake_inputs = [] + for t in real_inputs: + if isinstance(t, torch.Tensor): + fake_inputs.append(fake_mode.from_tensor(t)) + else: + fake_inputs.append(t) + + # Torch 2.9 can fail fake propagation through SDPA's masked fake-CUDA path, + # even though this pass only needs output metadata. Temporarily clear + # attn_mask so shape propagation can proceed, then restore it immediately; + # SDPA output shapes are still determined by Q/K/V shapes, not mask values. + saved_sdpa_masks = [] + for attn_node in get_sdpa_nodes(gm): + attn_mask = attn_node.kwargs.get("attn_mask") + if attn_mask is not None: + saved_sdpa_masks.append((attn_node, attn_mask)) + attn_node.update_kwarg("attn_mask", None) + + try: + # fake_inputs are already created under fake_mode above, so run + # propagation without reconverting them into a different fake mode. + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs) + finally: + for attn_node, attn_mask in saved_sdpa_masks: + attn_node.update_kwarg("attn_mask", attn_mask) + + +def apply_autosp(gm: GraphModule, + real_inputs, + debug: bool = False, + passes: Optional[List[Callable]] = None, + sp_size: int = 2, + dp_size: int = 1): + """ + Apply AutoSP (Ulysses) transformation passes to the graph and setup either DP/SP (2D) or SP (1D) mesh. + + Args: + gm: GraphModule to transform + real_inputs: Example inputs for shape propagation + debug: If True, print graph before/after each pass + passes: Optional custom list of passes (default: DEFAULT_PASSES) + """ + assert sp_size * dp_size <= dist.get_world_size(), 'Insufficient device count for mesh size' + + sp_dp_registry.populate_registry(sp_size, dp_size) + + AUTOSP_PASSES = [ + pass_shard_seq_dim, + pass_shard_input_ids, + pass_shard_label_ids, + pass_shard_position_ids, + pass_insert_attention_all_to_all, + pass_propagate_shapes, + pass_canonicalize, + ] + + passes = passes or AUTOSP_PASSES + rank = dist.get_rank() + + for p in passes: + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" BEFORE: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) + + p(gm, real_inputs) + + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" AFTER: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) diff --git a/deepspeed/compile/passes/zero1_compile.py b/deepspeed/compile/passes/zero1_compile.py new file mode 100644 index 000000000000..c4da6ad82fa3 --- /dev/null +++ b/deepspeed/compile/passes/zero1_compile.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple + +import torch +from torch.fx import GraphModule + +from ..util import get_deepcompile_handle +from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward, replace_reduce_outputs_with_none + +NAME = "zero1_compile" + + +def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager, use_z2=False) -> GraphModule: + + dc = get_deepcompile_handle() + param_indices = profiling_results[graph_id].param_indices + # Need this before profiling + if use_z2: + dc.register_graph_z2(graph_id, [v[1] for v in param_indices]) + else: + dc.register_graph_z1(graph_id, [v[1] for v in param_indices]) + + return gm + + +def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule: + + graph = gm.graph + pm = param_manager[graph_id] + _, param_name_to_grad = pm.get_bwd_mapping(graph) + + for param_name in pm.param_names: + + grad_node = param_name_to_grad[param_name] + + assert param_name in pm.ds_ids, f"param_name={param_name} not in ds_ids" + ds_id = pm.ds_ids[param_name] + + new_node = add_postprocess(graph, + grad_node, + torch.ops.dc.reduce_grad.default, + extra_args=[graph_id, ds_id], + name=f"reduce_param_{param_name}", + meta=_make_node_meta(grad_node, param_name, True)) + new_node.meta["val"] = None + + gm.graph = move_primals_to_head(graph) + + add_end_backward(gm.graph, graph_id) + replace_reduce_outputs_with_none(gm.graph) + + return gm + + +def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule: + if bwd: + return add_z1_reduce_bw(gm, graph_id, param_manager) + return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=False) + + +def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule: + if bwd: + return add_z1_reduce_bw(gm, graph_id, param_manager) + return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=True) diff --git a/deepspeed/compile/passes/zero3_compile.py b/deepspeed/compile/passes/zero3_compile.py new file mode 100644 index 000000000000..f09a4dee2adf --- /dev/null +++ b/deepspeed/compile/passes/zero3_compile.py @@ -0,0 +1,234 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import gc +from typing import List, Dict, Tuple +import _operator + +import torch +from torch.fx import Graph, Node, GraphModule + +from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op +from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward, replace_reduce_outputs_with_none +from ..profilers.graph_profile import ProfilingInterpreter +from ..list_schedule import fast_free_schedule + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +NAME = "zero3_compile" + + +def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, dtype: torch.dtype): + new_ag_node = add_postprocess(graph, + node, + torch.ops.dc.allgather_param.default, + extra_args=[graph_id, ds_id], + extra_kwargs={"dtype": dtype}, + name=f"allgather_ds_param_{node.target}_{ds_id}", + meta=_make_node_meta(node, ds_id, True)) + new_ag_node.meta["val"] = node.meta["val"].to(dtype) + + # Set the previous node back to output + # We don't want to change the output node to allgather + output_node = get_output_node(graph) + output_node.replace_input_with(new_ag_node, node) + + # Add wait as well + new_wait_node = add_postprocess(graph, + new_ag_node, + torch.ops.dc.wait_allgather.default, + extra_args=[graph_id, ds_id], + name=f"wait_allgather_ds_param__{node.target}_{ds_id}", + meta=_make_node_meta(node, ds_id, False)) + new_wait_node.meta["val"] = new_ag_node.meta["val"] + + return new_ag_node + + +def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int): + new_node = add_postprocess(graph, + node, + torch.ops.dc.release_param.default, + extra_args=[graph_id, ds_id, n_users], + name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}", + meta=_make_node_meta(node, ds_id, False)) + new_node.meta["val"] = None + + +def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int): + new_node = add_postprocess(graph, + grad_node, + torch.ops.dc.reduce_grad.default, + extra_args=[graph_id, ds_id], + name=f"reduce_ds_param_{param_name}", + meta=_make_node_meta(grad_node, ds_id, True)) + new_node.meta["val"] = None + + +def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph: + + node_to_uses = get_real_uses(graph) + for pn in param_nodes: + if len(pn.users) == 0: + continue + + # If the only use of the parameter is a type-cast to a smaller type, fuse it with all-gather. + fuse_typecast = False + target_dtype = param_manager.params[pn.name].dtype + if len([user for user in pn.users if user.op != "output"]) == 1: + typecast_node = next(iter(pn.users)) + + is_cast, casted_dtype = is_cast_op(typecast_node) + if is_cast and casted_dtype.itemsize < target_dtype.itemsize: + fuse_typecast = True + target_dtype = casted_dtype + + add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name], target_dtype) + if fuse_typecast: + users = node_to_uses[typecast_node] + wait_node = typecast_node.args[0] + for user in list(typecast_node.users.keys()): + if user.op == "output": + wait_node.meta["original_output_name"] = typecast_node.name + user.replace_input_with(typecast_node, wait_node) + graph.erase_node(typecast_node) + else: + users = node_to_uses[pn] + + ds_id = param_manager.ds_ids[pn.name] + for user in users: + # release_param() only accepts tensors as its first argument. If + # `user` is a tuple, we should release the param after any of + # operator.getitem of that tuple. + # + # Since no torch op takes a tuple as an input, we simply walk + # through users of `user` and check if there is any call to + # operator.getitem. + for secondary_user in user.users: + if secondary_user.op == "call_function" and secondary_user.target == _operator.getitem: + add_release(graph_id, graph, secondary_user, pn, ds_id, len(users)) + break + else: + add_release(graph_id, graph, user, pn, ds_id, len(users)) + + return move_primals_to_head(graph) + + +def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node], + param_name_to_grad: Dict[str, Node]) -> Graph: + + add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw) + + for param_name in param_manager.param_names: + if param_name_to_grad[param_name] is None: + continue + add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name]) + + return move_primals_to_head(graph) + + +def add_z3_gather_release_fw(gm: GraphModule, + graph_id: int, + graph_order: List[Tuple[int, bool]], + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) -> GraphModule: + + nz3 = get_deepcompile_handle() + + real_inputs = create_inputs_fn() + param_indices = profiling_results[graph_id].param_indices + + gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id], + get_param_nodes(gm.graph, param_indices)) + + nz3.register_graph_z3(graph_id, [v[1] for v in param_indices]) # Need this before profiling + + profiler = ProfilingInterpreter(gm, debug_log=debug_log) + profiler.run(*real_inputs) + del profiler + gc.collect() + get_accelerator().empty_cache() + + rank = dist.get_rank() + graph_index = get_index_by_graph_id(graph_order, graph_id) + if rank == 0 and debug_log: + print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}") + + for n in gm.graph.nodes: + is_ds_param = n.name in param_manager[graph_id].ds_ids + if "val" in n.meta and is_ds_param: + # Used for Inductor's validation + n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device) + + gm.graph = fast_free_schedule( + gm.graph, + get_accelerator().available_memory(), + 0, # unused + debug_log=debug_log) + + if rank == 0 and debug_log: + print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}") + + return gm + + +def add_z3_gather_release_bw(gm: GraphModule, + graph_id: int, + graph_order: List[Tuple[int, bool]], + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) -> GraphModule: + + param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph) + gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad) + + input_nodes = get_input_nodes(gm.graph) + real_inputs = create_inputs_fn() + assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}" + + real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs) + + del real_outputs + gc.collect() + get_accelerator().empty_cache() + + rank = dist.get_rank() + graph_index = get_index_by_graph_id(graph_order, graph_id) + if rank == 0 and debug_log: + print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}") + + gm.graph = fast_free_schedule( + gm.graph, + get_accelerator().available_memory(), + 0, # unused + debug_log=debug_log) + + add_end_backward(gm.graph, graph_id) + replace_reduce_outputs_with_none(gm.graph) + + return gm + + +def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results, + create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule: + if bwd: + return add_z3_gather_release_bw(gm, + graph_id, + graph_order, + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) + return add_z3_gather_release_fw(gm, + graph_id, + graph_order, + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) diff --git a/deepspeed/compile/patch_compiled_func.py b/deepspeed/compile/patch_compiled_func.py new file mode 100644 index 000000000000..c77d529a64ac --- /dev/null +++ b/deepspeed/compile/patch_compiled_func.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.utils.torch import required_torch_version + +backward_inputs = [] + +enabled_patched_func = False +original_grad_fn = None +base_meta = type(torch.autograd.Function) + +if required_torch_version(min_version=2.7): + + class FunctionMeta(base_meta): + + def __new__(cls, name, bases, dct): + if name == "CompiledFunction": + original_backward_impl = dct.get("_backward_impl") + + def wrapped_backward_impl(ctx, all_args): + assert original_backward_impl is not None + + if enabled_patched_func: + backward_inputs.append(all_args) + wrapped_backward_impl.owner_class.compiled_bw = None + + return original_backward_impl(ctx, all_args) + + wrapped_backward_impl.owner_class = None + dct["_backward_impl"] = staticmethod(wrapped_backward_impl) + new_class = super().__new__(cls, name, bases, dct) + wrapped_backward_impl.owner_class = new_class + + return new_class + + return super().__new__(cls, name, bases, dct) + +elif required_torch_version(min_version=2.6): + + class FunctionMeta(base_meta): + + def __new__(cls, name, bases, dct): + if name == "CompiledFunction": + original_backward_prologue = dct.get("_backward_prologue") + + def wrapped_backward_prologue(ctx, *grad_outputs): + assert original_backward_prologue is not None + + all_args = original_backward_prologue(ctx, *grad_outputs) + if enabled_patched_func: + backward_inputs.append(all_args) + wrapped_backward_prologue.owner_class.compiled_bw = None + + return all_args + + wrapped_backward_prologue.owner_class = None + dct["_backward_prologue"] = staticmethod(wrapped_backward_prologue) + new_class = super().__new__(cls, name, bases, dct) + wrapped_backward_prologue.owner_class = new_class + + return new_class + + return super().__new__(cls, name, bases, dct) + + +def patch_compiled_func(): + + global enabled_patched_func + enabled_patched_func = True + + class PatchedFunction(torch.autograd.Function, metaclass=FunctionMeta): + pass + + global original_grad_fn + original_grad_fn = torch.autograd.Function + torch.autograd.Function = PatchedFunction + + return backward_inputs + + +def unpatch_compiled_func(): + global enabled_patched_func + enabled_patched_func = False + + global original_grad_fn + torch.autograd.Function = original_grad_fn + + +def get_backward_inputs(): + return backward_inputs diff --git a/deepspeed/compile/patch_fake_tensor.py b/deepspeed/compile/patch_fake_tensor.py new file mode 100644 index 000000000000..437a8008b373 --- /dev/null +++ b/deepspeed/compile/patch_fake_tensor.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +try: + from torch._subclasses import FakeTensorMode + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch._dynamo.variables.builder import wrap_to_fake_tensor_and_record +except ImportError: + # Unsupported torch version + pass + + +def wrap_if_ds_param(t): + if hasattr(t, 'ds_id'): + data = torch.rand(t.ds_shape, + dtype=t.dtype, + layout=t.layout, + device=t.device, + pin_memory=t.is_pinned(), + requires_grad=t.requires_grad) + if isinstance(t, torch.nn.Parameter): + t = torch.nn.Parameter(data, requires_grad=t.requires_grad) + else: + t = data + return t + + +def _get_guard_sizes_strides(t): + if hasattr(t, "ds_id"): + # ZeRO-3 may temporarily all-gather a parameter during tracing, but the + # stable module state used by TorchDynamo guards is the released + # partitioned form, where DeepSpeed resets param.data to empty(0). + released = torch.empty(0, dtype=t.dtype, device=t.device) + return released.size(), released.stride() + + return t.size(), t.stride() + + +def patch_fake_tensor(): + # dynamo tracer uses wrap_to_fake_tensor_and_record + # Wrapping FakeTensorMode.from_tensor is not sufficient as dynamo generates SymbolicContext before calling from_tensor + original_wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record + + def wrap_to_fake_tensor_and_record_wrapper(t, *args, **kwargs): + dummy_tensor = wrap_if_ds_param(t) + ret = original_wrap_to_fake_tensor_and_record(dummy_tensor, *args, **kwargs) + tx = kwargs.get("tx") if "tx" in kwargs else args[0] + source = kwargs.get("source") + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.tensor_to_context[t] = tracing_context.tensor_to_context.pop(dummy_tensor) + if source is not None: + # Keep the full ds_shape symbolic context from the dummy tensor, but + # use the stable released ZeRO-3 parameter representation for + # TorchDynamo's tensor-match guards. PyTorch 2.9 started enforcing + # those guards for parameters during build_guards(). + size, stride = _get_guard_sizes_strides(t) + tx.output.input_source_to_sizes_strides[source] = { + "size": size, + "stride": stride, + } + return ret + + torch._dynamo.variables.builder.wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record_wrapper + + # aot_module_simplified uses fake_mode.from_tensor to process inputs + original_from_tensor = FakeTensorMode.from_tensor + + def from_tensor_wrapper(self, t, *args, **kwargs): + with unset_fake_temporarily(): + return original_from_tensor(self, wrap_if_ds_param(t), *args, **kwargs) + + FakeTensorMode.from_tensor = from_tensor_wrapper diff --git a/deepspeed/compile/profilers/__init__.py b/deepspeed/compile/profilers/__init__.py new file mode 100644 index 000000000000..7adb54f11872 --- /dev/null +++ b/deepspeed/compile/profilers/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple +from dataclasses import dataclass, field + +from torch.fx import Graph + + +@dataclass +class ProfilingResult: + fwd_graph: Graph = None + bwd_graph: Graph = None + needs_backward: bool = False + fwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list) # name, current_alloc, delta, peak + bwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list) + fwd_time: List[Tuple[str, int, int]] = field(default_factory=list) # name, device_time, wall_time + bwd_time: List[Tuple[str, int, int]] = field(default_factory=list) + fwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list) # name, size + bwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list) + param_indices: List[Tuple[int, int, Tuple[int, ...]]] = field(default_factory=list) # index, ds_id, ds_shape diff --git a/deepspeed/compile/profilers/comm_profile.py b/deepspeed/compile/profilers/comm_profile.py new file mode 100644 index 000000000000..18bd517c1e8f --- /dev/null +++ b/deepspeed/compile/profilers/comm_profile.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch + +try: + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + # Unsupported torch version + pass + +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + + +def sync_all(): + get_accelerator().synchronize() + dist.barrier() + + +def get_bw(comm_op, size, duration): + n = dist.get_world_size() + tput = 0 + busbw = 0 + + if duration == 0: + raise ValueError("Error. Duration is 0.") + + if comm_op == "all_to_all": + tput = (size / duration) + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_gather": + size *= n + tput = (size / duration) + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_reduce": + tput = (size * 2 / duration) + busbw = (size / duration) * (2 * (n - 1) / n) + elif comm_op == "pt2pt" or comm_op == "broadcast": + tput = (size / duration) + busbw = tput + else: + raise ValueError("wrong comm_op specified") + + return tput, busbw + + +# Run all_gather and print metrics +def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op): + sync_all() + # Warmups, establish connections, etc. + for i in range(warmup): + dist.all_gather_into_tensor(output, input, async_op=async_op) + sync_all() + + # time the actual comm op trials times and average it + start_event.record() + for i in range(trials): + dist.all_gather_into_tensor(output, input, async_op=async_op) + end_event.record() + sync_all() + duration = start_event.elapsed_time(end_event) / 1000 + + # maintain and clean performance data + avg_duration = duration / trials + size = input.element_size() * input.nelement() * dist.get_world_size() + # tput, busbw = get_bw('all_gather', size, avg_duration) + + avg_duration_ten = torch.tensor([avg_duration], device=device) + if dist.get_world_size() > 1: + dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG) + + return size, avg_duration_ten.item() + + +def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False): + + # Prepare benchmark header + global_rank = dist.get_rank() + world_size = dist.get_world_size() + + start_event = get_accelerator().Event(enable_timing=True) + end_event = get_accelerator().Event(enable_timing=True) + + # Create list of message sizes + M_LIST = [] + for x in (2**p for p in range(1, maxsize)): + m = x // world_size + if m > 0: + M_LIST.append(m) + + results = [(0, 0)] + sync_all() + # loop over various tensor sizes + for M in M_LIST: + global_rank = dist.get_rank() + try: + mat = torch.ones(M, dtype=dtype, device=device) + sync_all() + input = ((mat.mul_(float(global_rank))).view(-1)) + # Delete original mat to avoid OOM + del mat + get_accelerator().empty_cache() + output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device) + except RuntimeError as e: + if 'out of memory' in str(e): + if dist.get_rank() == 0: + print('WARNING: Ran out of GPU memory. Exiting comm op.') + sync_all() + break + else: + raise e + sync_all() + results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op)) + + return results + + +profile_results = None + + +def create_predictor(): + global profile_results + if profile_results is None: + with unset_fake_temporarily(): + device = get_accelerator().current_device() + profile_results = run_all_gather(device, torch.bfloat16, 31) + if dist.get_rank() == 0: + for size, avg_duration in profile_results: + print(f"size: {size}, avg_duration: {avg_duration}") + + # Extract size and avg_duration from results + sizes = [result[0] for result in profile_results] + durations = [result[1] for result in profile_results] + + try: + from scipy.interpolate import interp1d + except ImportError: + raise RuntimeError("Please install scipy to use communication profiler in DeepCompile") + + predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate") + + def f(size): + if size == 0: + return 0 + return predictor(size) + + # Create an interpolation function + return f + + +if __name__ == "__main__": + local_rank = int(os.environ['LOCAL_RANK']) + get_accelerator().set_device(local_rank) + print(f"local_rank={local_rank}") + + deepspeed.init_distributed(dist_backend='nccl') + + # Create predictor function + predictor = create_predictor() + + # Predict time for a specific data size + example_size = 1e9 + predicted_time = predictor(example_size) + print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds") + + dist.destroy_process_group() diff --git a/deepspeed/compile/profilers/graph_profile.py b/deepspeed/compile/profilers/graph_profile.py new file mode 100644 index 000000000000..b2cac6859a9b --- /dev/null +++ b/deepspeed/compile/profilers/graph_profile.py @@ -0,0 +1,310 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +from typing import Any, Tuple, Dict +import statistics + +import torch +from torch.fx import GraphModule, Interpreter +from torch.fx.node import map_aggregate + +try: + from torch.utils._pytree import tree_all, tree_leaves + from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake +except ImportError: + # Unsupported torch version + pass + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from ..util import is_comm_op, is_release_node, get_deepcompile_handle + + +def _all_real_if_tensor(args): + return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args) + + +def _to(v, device): + if torch.is_tensor(v): + with unset_fake_temporarily(): + return v.to(device) + return v + + +def _args_to_key(v): + + def _tensor_to_key(v) -> str: + if torch.is_tensor(v): + if v.numel() == 1: + try: + return f"{v.dtype}{v.device}{v.item()}" + except Exception as e: + return f"{v.dtype}{v.device}ptr{v.data_ptr()}" + else: + return f"{v.dtype}{v.device}{v.shape}" + return str(v) + + return map_aggregate(v, _tensor_to_key) + + +def _node_size(out): + return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)]) + + +def _get_mem_usage_out_of_torch(): + + adjust = 0 + try: + import pynvml + pynvml.nvmlInit() + + current_dev_id = get_accelerator().current_device() + handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + torch_alloc = get_accelerator().memory_allocated() + adjust = info.used - torch_alloc + except Exception: + # pynvml not available + pass + + return adjust + + +# https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html +class ProfilingInterpreter(Interpreter): + + def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False): + super().__init__(gm) + + self.nz3 = get_deepcompile_handle() + + assert iteration > 0 + assert warmup >= 0 + self.iteration = iteration + self.warmup = warmup + self.device = torch.device(get_accelerator().current_device()) + self.cache: Dict[Tuple, Any] = {} + self.distributed = dist.is_initialized() + self.allgather_mem: Dict[int, int] = {} + self.debug_log = debug_log + self.mem_usage_out_of_torch = 0 + + def run(self, *args) -> Any: + """Run the graph with profiling enabled. + + args: inputs to the graph. Tensors in the inpusts must be real tensors, not fake tensors. args can contain ds parameters. + returns: The output of the graph. Tensor in the output is real tensors. + """ + return_val = None + try: + assert _all_real_if_tensor(args), "Inputs must be real tensors" + self.nz3.enable_profiling(True) + + with unset_fake_temporarily(): + with get_accelerator().random().fork_rng(devices=[self.device]): + self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() + return_val = super().run(*args) + except Exception as e: + msg = e.msg if "msg" in dir(e) else str(e) + print(f"Profiling error {msg}") + finally: + self.nz3.clear_all_gathered_params() + self.nz3.enable_profiling(False) + return return_val + + def run_node(self, n: torch.fx.Node) -> Any: + + if n.op in {"placeholder", "output"}: + n.meta["device_time"] = 0.0 + n.meta["wall_time"] = 0.0 + n.meta["alloc_mem"] = 0 + n.meta["max_mem"] = 0 + n.meta["tensor_size"] = _node_size(n) + return super().run_node(n) + + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + partitioned_params = {} + + def rebuild_param_if_necessary(v): + if hasattr(v, "ds_id"): + v.all_gather(param_list=[v]) + if hasattr(v, "ds_target_dtype"): + casted = v.to(v.ds_target_dtype) + partitioned_params[id(casted)] = v + return casted + return v + + args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x)) + + args = map_aggregate(args, lambda x: _to(x, self.device)) + kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device)) + + cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs)) + cache_hit = cache_key in self.cache + + cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int) + if self.distributed: + dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM) + cache_hit = cache_hit_flag.item() == 0 + + if cache_hit: + device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key] + n.meta["device_time"] = device_time + n.meta["wall_time"] = wall_time + n.meta["alloc_mem"] = alloc_mem + n.meta["max_mem"] = max_mem + n.meta["tensor_size"] = tensor_size + + is_release_op = is_release_node(n) + run_only_once = cache_hit or is_release_op + iteration = 1 if run_only_once else self.iteration + accelerator = get_accelerator() + start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)] + end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)] + + get_accelerator().reset_peak_memory_stats() + alloc_mem_start = get_accelerator().memory_allocated() + max_mem_start = get_accelerator().max_memory_allocated() + + if not run_only_once: + for i in range(self.warmup): + out = getattr(self, n.op)(n.target, args, kwargs) + + if is_comm_op(n): + assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used." + dist.barrier() + + start = time.time() + for i in range(iteration): + start_events[i].record() + out = getattr(self, n.op)(n.target, args, kwargs) + end_events[i].record() + accelerator.synchronize() + walltime_sum = time.time() - start + + if is_comm_op(n): + dist.barrier() + + alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch + max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch + tensor_size = _node_size(out) + + def partition_param_if_necessary(v): + if id(v) in partitioned_params: + v = partitioned_params[id(v)] + if hasattr(v, "ds_id") and not v.ds_persist: + v.partition(param_list=[v], has_been_updated=False) + return v + + args = map_aggregate(args, lambda x: partition_param_if_necessary(x)) + + if not cache_hit: + device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)]) + wall_time = walltime_sum / iteration * 1000 + + with unset_fake_temporarily(): + vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size], + device=self.device) + if self.distributed: + dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG) + n.meta["device_time"] = vals_to_bcast[0].item() + n.meta["wall_time"] = vals_to_bcast[1].item() + n.meta["alloc_mem"] = int(vals_to_bcast[2].item()) + n.meta["max_mem"] = int(vals_to_bcast[3].item()) + n.meta["tensor_size"] = int(vals_to_bcast[4].item()) + self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"], + n.meta["max_mem"], n.meta["tensor_size"]) + + if is_release_op: + n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0) + + if dist.get_rank() == 0 and self.debug_log: + print( + f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}" + ) + + if n.target == torch.ops.dc.allgather_param.default: + out = args[0] + assert hasattr(out, "ds_id") + if not out.ds_persist: + self.nz3.invalidate_gathered_param(args[2]) + if "dtype" in n.kwargs: + setattr(out, "ds_target_dtype", n.kwargs["dtype"]) + self.allgather_mem[out.ds_id] = n.meta["alloc_mem"] + + return out + + +class MemoryProfilingInterpreter(Interpreter): + + def __init__(self, gm: GraphModule, debug_log=False): + super().__init__(gm) + self.nz3 = get_deepcompile_handle() + self.device = torch.device(get_accelerator().current_device()) + self.mem_record = [] + self.last_alloc = get_accelerator().memory_allocated() + + self.node_counter = 0 + self.node_num = len(gm.graph.nodes) + self.debug_log = debug_log + + def run(self, *args) -> Any: + return_val = None + try: + assert _all_real_if_tensor(args), "Inputs must be real tensors" + self.nz3.enable_profiling(True) + self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() + + with unset_fake_temporarily(): + with get_accelerator().random().fork_rng(devices=[self.device]): + return_val = super().run(*args) + except Exception as e: + print(f"MemoryProfiling error {e}") + finally: + self.nz3.enable_profiling(False) + + return return_val + + def run_node(self, n: torch.fx.Node) -> Any: + get_accelerator().reset_peak_memory_stats() + + if n.op in {"placeholder", "output"}: + ret = super().run_node(n) + else: + args, kwargs = self.fetch_args_kwargs_from_env(n) + args = map_aggregate(args, lambda x: _to(x, self.device)) + kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device)) + ret = getattr(self, n.op)(n.target, args, kwargs) + + del args, kwargs + + current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch + max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch + vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device, dtype=torch.int64) + dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX) + current_alloc = vals_to_bcast[0].item() + max_alloc = vals_to_bcast[1].item() + + self.mem_record.append((n.name, current_alloc, current_alloc - self.last_alloc, max_alloc)) + + self.node_counter += 1 + if self.debug_log and dist.get_rank() == 0: + print( + f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {(current_alloc - self.last_alloc) / 1024 / 1024:.2f}MB" + ) + + self.last_alloc = current_alloc + + return ret + + def dump(self, path): + import pandas as pd + df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"]) + df.to_csv(path, index=False) diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py new file mode 100644 index 000000000000..0fd3b6b389db --- /dev/null +++ b/deepspeed/compile/util.py @@ -0,0 +1,614 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import functools +import operator +from typing import List, Tuple, Dict, Optional +from collections import defaultdict + +import torch +from torch.fx import Node, Graph, GraphModule +from torch.fx.node import map_aggregate, Argument, map_arg +import torch.nn.functional as F + +try: + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + # Unsupported torch version + pass + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version +from deepspeed.ops.op_builder.dc import DeepCompileBuilder +from deepspeed.compile import constants + +from .custom_ops import sp_dp_registry + + +def is_deepcompile_supported() -> bool: + return required_torch_version(min_version=2.6) and get_accelerator().device_name() == "cuda" + + +dc_handle = None + +if is_deepcompile_supported(): + sym_size_ops = { + operator.ge, + operator.le, + operator.eq, + operator.ne, + operator.gt, + operator.lt, + torch.ops.aten.sym_size.int, + operator.getitem, + } + + +def get_deepcompile_handle(): + global dc_handle + if dc_handle is None: + dc_handle = DeepCompileBuilder().load() + return dc_handle + + +def is_backend_inductor(backend): + return backend == "inductor" + + +backward_started = False +pre_backward_hooks = [] + + +def add_pre_backward_hook(hook): + pre_backward_hooks.append(hook) + + +def deepcompile_backward_prologue(is_gradient_accumulation_boundary): + + for hook in pre_backward_hooks: + hook() + + dc = get_deepcompile_handle() + dc.start_backward(is_gradient_accumulation_boundary) + + +def log_rank0(msg: str, enable: bool = False): + if dist.get_rank() == 0 and enable: + print(msg) + + +@functools.lru_cache +def get_no_copy_ops(): + # Need to compile custom ops + get_deepcompile_handle() + + no_copy_ops = {torch.ops.dc.wait_allgather.default} + + # All operations whose return value aliases any of their inputs are included + # in the returned list to ensure that the last user of a node is computed + # correctly. + # + # This can be overly conservative if not all input tensors are aliased in + # the output. While we can determine exactly which tensors are aliased, a + # finer-grained algorithm is required in get_last_uses() and get_real_uses() + # to utilize that information. This is left as future work when real needs + # arise. + warned = False + for op_name in torch.ops.aten: + packet = getattr(torch.ops.aten, op_name) + for overload_name in packet: + op = getattr(packet, overload_name) + try: + for return_info in op._schema.returns: + if isinstance(return_info.type, torch.TensorType) and return_info.alias_info is not None: + no_copy_ops.add(op) + break + except AttributeError: + # In case no schema is available, conservatively assume the op + # may reuse tensor storage and print a one-time warning on its + # potential performance impact. + if not warned: + log_rank0( + f"WARNING: Schema is missing for some torch.ops.aten ops (e.g. {op_name}.{overload_name})." + "We assume those ops may reuse tensor storage. This may impact performance of compiled models.", + enable=True, + ) + warned = True + no_copy_ops.add(op) + + return no_copy_ops + + +def get_input_nodes(graph: Graph) -> List[Node]: + return [n for n in graph.nodes if n.op == "placeholder"] + + +def get_param_nodes(graph: Graph, index_to_ds_ids: List[Tuple[int, int]]) -> List[Node]: + all_input_nodes = get_input_nodes(graph) + return [all_input_nodes[i] for i, _, _ in index_to_ds_ids] + + +def is_comm_op(node: Node) -> bool: + return "comm" in node.meta and node.meta["comm"] + + +def is_cast_op(node: Node) -> Tuple[bool, Optional[torch.dtype]]: + if node.op == "call_function": + if node.target == torch.ops.prims.convert_element_type.default: + return (True, node.args[1]) + elif node.target == torch.ops.aten._to_copy.default and set(node.kwargs.keys()) == {"dtype"}: + return (True, node.kwargs["dtype"]) + return (False, None) + + +def exclude_from_act_offload(node: Node) -> bool: + return node.target in sym_size_ops + + +def dtype_to_elem_size(dtype: torch.dtype) -> int: + if dtype == torch.float32: + elem_size = 4 + elif dtype == torch.float64: + elem_size = 8 + elif dtype == torch.float16: + elem_size = 2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + return elem_size + + +def tensor_meta_size(tensor_meta) -> int: + numel = 1 if len(tensor_meta.shape) == 0 else functools.reduce(operator.mul, tensor_meta.shape) + + dtype = tensor_meta.dtype + if dtype == torch.float32: + elem_size = 4 + elif dtype == torch.float64 or dtype == torch.int64: + elem_size = 8 + elif dtype == torch.float16 or dtype == torch.bfloat16: + elem_size = 2 + elif dtype == torch.bool: + elem_size = 1 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + return numel * elem_size + + +class NodeValueOffloadHelper: + + def __init__(self, device): + self.device = device + self.env_values: Dict[str, Argument] = {} + self.original_device: Dict[torch.Tensor, torch.device] = {} + + def _to_cpu(self, v): + if torch.is_tensor(v): + with unset_fake_temporarily(): + device = v.device + offloaded = v.to('cpu').detach() + self.original_device[offloaded] = device + return offloaded + return v + + def _from_cpu(self, v): + if torch.is_tensor(v) and v in self.original_device: + return v.to(self.original_device[v]) + return v + + def save(self, name: str, v: Argument, offload) -> None: + self.env_values[name] = map_aggregate(v, lambda x: self._to_cpu(x) if offload else x) + + def load(self, name: str) -> Argument: + return map_aggregate(self.env_values[name], lambda x: self._from_cpu(x)) + + def get_offloaded_value(self, name: str) -> Argument: + return self.env_values[name] + + def has_value(self, name: str) -> bool: + return name in self.env_values + + def clear(self) -> None: + self.env_values.clear() + self.original_device.clear() + + +def materialize_fake(v, device=None): + from torch._subclasses.fake_tensor import is_fake + + def convert(t): + if is_fake(t): + with unset_fake_temporarily(): + if t.is_floating_point(): + return torch.randn(t.shape, + dtype=t.dtype, + device=t.device if device is None else device, + layout=t.layout, + requires_grad=t.requires_grad, + pin_memory=t.is_pinned()) + else: + return torch.zeros(t.shape, + dtype=t.dtype, + device=t.device if device is None else device, + requires_grad=t.requires_grad) + + return t + + return map_aggregate(v, lambda x: convert(x)) + + +def get_last_uses(graph: Graph): + position = {node: i for i, node in enumerate(graph.nodes)} + + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + no_copy_ops = get_no_copy_ops() + + def register_last_uses(n: Node, user: Node): + update = False + known_last_use = None + + if user.target in no_copy_ops and n in node_to_last_use: + last_user = node_to_last_use[user] + last_use_position = position[last_user] + + known_last_use = node_to_last_use[n] + known_last_use_position = position[known_last_use] + update = last_use_position > known_last_use_position + + if n not in node_to_last_use or update: + if user.target in no_copy_ops: + user = node_to_last_use[user] + + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + if known_last_use: + user_to_last_uses[known_last_use].remove(n) + + for node in reversed(graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + return node_to_last_use, user_to_last_uses + + +def get_real_uses(graph: Graph): + node_to_uses: Dict[Node, List[Node]] = defaultdict(list) + no_copy_ops = get_no_copy_ops() + + def register_last_uses(n: Node, user: Node): + if user.target == "output": + return + + if user.target in no_copy_ops: + users = node_to_uses[user] + node_to_uses[n].extend(users) + else: + node_to_uses[n].append(user) + + for node in reversed(graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + return node_to_uses + + +def count_inflight_values(graph: Graph, file_path: str): + position = {node: i for i, node in enumerate(graph.nodes)} + + node_to_last_use, user_to_last_uses = get_last_uses(graph) + + max_inflight_size = 0 + inflight_values = set() + + # Output csv. + csv_filename = file_path + csv_data = [] + header = [ + 'Node', 'tensor_size', 'inflight_size', 'inflight_size_in_output', 'args', 'users', 'node_to_last_use', + 'lifetime', 'user_to_last_uses', 'inflight_values' + ] + csv_data.append(header) + + from .fx import get_output_node + output_node = get_output_node(graph) + values_in_output = set([n for n in output_node.args[0] if isinstance(n, Node)]) + + for node in graph.nodes: + inflight_values.add(node) + if node in user_to_last_uses: + for to_delete in user_to_last_uses[node]: + inflight_values.remove(to_delete) + + assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size" + inflight_size = sum(n.meta["tensor_size"] for n in inflight_values) + inflight_size_in_output = sum(n.meta["tensor_size"] for n in inflight_values if n in values_in_output) + + lifetime = position[node_to_last_use[node]] - position[node] if node in node_to_last_use else 0 + + row = [ + node.name, node.meta["tensor_size"], inflight_size, inflight_size_in_output, + [a.name for a in node.args if isinstance(a, Node)], + list(node.users.keys()), node_to_last_use[node] if node in node_to_last_use else 'NA', lifetime, + user_to_last_uses[node] if node in user_to_last_uses else 'NA', + list(inflight_values) + ] + csv_data.append(row) + + # print( + # f"Node: {node.name} users: {list(node.users.keys())} node_to_last_use: {node_to_last_use[node] if node in node_to_last_use else 'NA'} user_to_last_uses: {user_to_last_uses[node] if node in user_to_last_uses else 'NA'} inflight_values: {inflight_values} inflight_size: {inflight_size}" + # ) + max_inflight_size = max(max_inflight_size, inflight_size) + + import csv + with open(csv_filename, mode='w', newline='') as file: + writer = csv.writer(file) + writer.writerows(csv_data) + + print(f"Max inflight size: {max_inflight_size}") + print(f"Data successfully written to {csv_filename}") + + +def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], fwd_output_names: List[str]): + + input_nodes = get_input_nodes(graph) + param_node_names = set([n.name for n in param_nodes_bw]) + + activation_node_names = [] + for in_node in input_nodes: + if in_node.name in fwd_output_names: + if in_node.name not in param_node_names: + activation_node_names.append(in_node.name) + + return activation_node_names + + +class TensorOffloadHelper(): + + def __init__(self): + self.devices = {} + self.base_tensors = {} + self.views = {} + self.arg_list = [] + self.offloaded = {} + self.non_tensor = {} + + def offload(self, argument): + + def is_base_tensor(tensor): + return torch.is_tensor(a) and not a._is_view() and not hasattr(tensor, "ds_id") + + base_tensor_ids = set() + for a in argument: + if is_base_tensor(a): + base_tensor_ids.add(id(a)) + + for a in argument: + a_id = id(a) + + if is_base_tensor(a): + # Base tensor + self.devices[a_id] = a.device + self.base_tensors[a_id] = a + # elif torch.is_tensor(a) and not hasattr(a, "ds_id") and id(a._base) in base_tensor_ids: + # # View + # self.views[a_id] = { + # "base_id": id(a._base), + # "size": a.size(), + # "stride": a.stride(), + # "offset": a.storage_offset(), + # } + else: + # other types or ds tensor + self.non_tensor[a_id] = a + + self.arg_list.append(a_id) + + for a in argument: + if is_base_tensor(a): + a.data = a.data.to("cpu") + + def reload(self, in_place): + + loaded_base_tensors = {} + for a_id in self.arg_list: + if a_id in self.base_tensors: + device = self.devices[a_id] + + if in_place: + self.base_tensors[a_id].data = self.base_tensors[a_id].to(device) + loaded_base_tensors[a_id] = self.base_tensors[a_id] + else: + loaded_base_tensors[a_id] = self.base_tensors[a_id].to(device) + + results = [] + for a_id in self.arg_list: + if a_id in self.base_tensors: + results.append(loaded_base_tensors[a_id]) + + # elif a_id in self.views: + # view_info = self.views[a_id] + # # print(f"load_args loading view {a_id} base_id={view_info['base_id']} size={view_info['size']} stride={view_info['stride']} offset={view_info['offset']}") + # base_tensor = loaded_base_tensors[view_info["base_id"]] + # view_tensor = base_tensor.as_strided( + # view_info["size"], view_info["stride"], view_info["offset"] + # ) + # results.append(view_tensor) + + elif a_id in self.non_tensor: + results.append(self.non_tensor[a_id]) + + return results + + +def add_mem_profile_nodes(graph: Graph, prefix: str): + + def show_memory(label: str): + if dist.get_rank() == 0: + print( + f"{prefix} {label} alloc_mem={get_accelerator().memory_allocated()} max_mem={get_accelerator().max_memory_allocated()}" + ) + + nodes = list(graph.nodes) + for node in nodes: + if node.op == "output": + continue + + with graph.inserting_after(node): + msg = f"Mem {node.name}" + name = f"show_memory_{node.name}" + graph.create_node('call_function', show_memory, (msg, ), {}, name=name) + + +def is_release_node(n: Node) -> bool: + return n.target == torch.ops.dc.release_param.default + + +def get_index_by_graph_id(graph_order, target_graph_id): + for index, (graph_id, _) in enumerate(graph_order): + if graph_id == target_graph_id: + return index + return -1 + + +def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor]: + """ + specs = [ + (input_ids, 1, pad_token_id), # Example: Pad the right side with + (attention_mask, 1, 0), # Example: Pad the right side with 0 + ... + ] + + - Share the "maximum length of the dim dimension" across ranks for all specs + - Pad the right side for the missing parts and return + - Communication (`all_reduce`) happens only once + """ + assert len(specs) > 0, "specs is empty" + + device = specs[0][0].device + # Vectorize local lengths + local_sizes = torch.tensor( + [tensor.size(dim) for tensor, dim, _ in specs], + dtype=torch.long, + device=device, + ) + + # Element-wise MAX across ranks + dist.all_reduce(local_sizes, op=dist.ReduceOp.MAX) + max_sizes = local_sizes.tolist() + + # Pad each tensor as needed + padded: List[torch.Tensor] = [] + + # Don't use F.pad here: + # If you don't need to pad only on a certain rank, it will lead to different strides across ranks. + # This will cause recompilation on only some ranks and get the communication collective stuck. + for (tensor, dim, pad_val), max_len in zip(specs, max_sizes): + cur_len = tensor.size(dim) + + # --- (1) Always allocate a new buffer with 'row-major, contiguous memory' ------------- + out_shape = list(tensor.shape) + out_shape[dim] = max_len + out = torch.full(out_shape, pad_val, dtype=tensor.dtype, device=tensor.device) + + # --- (2) Copy original data using slicing ------------------------------ + slc = [slice(None)] * tensor.dim() + slc[dim] = slice(0, cur_len) + out[tuple(slc)] = tensor + + # out is always row-major: for example, if shape is (..., 1, L), then + # stride = (..., L, 1) + padded.append(out) + + return padded + + +def create_shard_offsets(gm: GraphModule, s0_node: Node) -> Tuple[Node, Node]: + sp_size: int = sp_dp_registry.sp_size() + sp_rank: int = dist.get_rank() % sp_dp_registry.sp_size() + with gm.graph.inserting_after(s0_node): + chunk_size_node = gm.graph.call_function(operator.floordiv, args=(s0_node, sp_size)) + with gm.graph.inserting_after(chunk_size_node): + start_node = gm.graph.call_function(operator.mul, args=(sp_rank, chunk_size_node)) + with gm.graph.inserting_after(start_node): + end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node)) + + return start_node, end_node + + +def get_sdpa_nodes(gm: GraphModule) -> List[Node]: + return list(gm.graph.find_nodes( + op="call_function", + target=F.scaled_dot_product_attention, + )) + + +def get_input_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_INPUT_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the input sequence.") + return node + + +def get_label_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_LABEL_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the label.") + return node + + +def get_position_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_POSITION_ID_KEY) + return node + + +def create_symbolic_slice_indices( + gm: GraphModule, + sym_seq_dim_node: Node, +) -> Tuple[Node, Node]: + start_node, end_node = create_shard_offsets(gm, sym_seq_dim_node) + + with gm.graph.inserting_after(end_node): + slice_all = gm.graph.call_function(slice, args=(None, None, None)) + with gm.graph.inserting_after(slice_all): + slice_range = gm.graph.call_function(slice, args=(start_node, end_node, None)) + + return slice_all, slice_range + + +def shard_tensor_node(gm: GraphModule, tensor_node: Node): + from .fx import find_node_by_name, get_node_shape_meta, replace_node_users + val = get_node_shape_meta(tensor_node) + assert val is not None, f"Node {tensor_node.name} has no shape metadata" + + seq_len = val.shape[1] + + assert isinstance( + seq_len, + torch.SymInt), (f"Expected sequence dimension to be {torch.SymInt!r} but instead found {type(seq_len)!r}") + + symb_seq_int_node = find_node_by_name(gm, str(seq_len)) + assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}" + + slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node) + indices = (slice_all, slice_range) + + positions = {node: i for i, node in enumerate(gm.graph.nodes)} + # Insert after the later dependency so the new getitem does not appear + # before the symbolic slice nodes in graph order. Torch 2.9 bf16 can place + # the SymInt placeholder after the tensor placeholder. + anchor_node = slice_range if positions[slice_range] > positions[tensor_node] else tensor_node + with gm.graph.inserting_after(anchor_node): + sliced_node = gm.graph.call_function( + operator.getitem, + args=(tensor_node, indices), + ) + + replace_node_users(tensor_node, sliced_node, exclude=[sliced_node]) diff --git a/deepspeed/compression/basic_layer.py b/deepspeed/compression/basic_layer.py index a5b872fa3a65..bc2b54951bbe 100644 --- a/deepspeed/compression/basic_layer.py +++ b/deepspeed/compression/basic_layer.py @@ -16,7 +16,7 @@ class QuantAct(nn.Module): """ - Class to quantize given activations. Note that when using this function, the input acttivation quantization range will be fixed for all + Class to quantize given activations. Note that when using this function, the input activation quantization range will be fixed for all tokens/images for inference. This generally will affect some accuracy but achieve better latency performance. Parameters: ---------- @@ -170,7 +170,7 @@ def enable_row_pruning(self, ratio, method): if method == 'l1': # compute the l1 norm of each column - weight_norm = torch.norm(self.weight.data, p=1, dim=1) + weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=1) mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False) mask = mask.view(-1, 1) mask = mask.to(self.weight.device) @@ -465,7 +465,7 @@ def enable_channel_pruning(self, ratio, method): if method == 'l1': # compute the l1 norm of each conv2d kernel (the last three dimension) - weight_norm = torch.norm(self.weight.data, p=1, dim=[1, 2, 3]) + weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=[1, 2, 3]) mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False) mask = mask.view(-1, 1, 1, 1) mask = mask.to(self.weight.device) @@ -618,7 +618,7 @@ def fix_channel_pruning_helper(self, mask, dim_reduction=True): def _reduce(input_): - """All-reduce the the input tensor across model parallel group.""" + """All-reduce the input tensor across model parallel group.""" group = g_mpu.get_model_parallel_group() # Bypass the function if we are using only 1 GPU. @@ -673,7 +673,7 @@ def _split(input_): def _gather(input_): - """Gather tensors and concatinate along the last dimension.""" + """Gather tensors and concatenate along the last dimension.""" group = g_mpu.get_model_parallel_group() # Bypass the function if we are using only 1 GPU. @@ -708,7 +708,7 @@ def backward(ctx, grad_output): class _ReduceFromModelParallelRegion(torch.autograd.Function): - """All-redcue the input from the model parallel region.""" + """All-reduce the input from the model parallel region.""" @staticmethod def forward(ctx, input_): @@ -732,7 +732,7 @@ def backward(ctx, grad_output): class _GatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from model parallel region and concatinate.""" + """Gather the input from model parallel region and concatenate.""" @staticmethod def forward(ctx, input_): diff --git a/deepspeed/compression/compress.py b/deepspeed/compression/compress.py index 37d98d9496fd..2f0e88beee21 100644 --- a/deepspeed/compression/compress.py +++ b/deepspeed/compression/compress.py @@ -11,6 +11,11 @@ import os import json +try: + import neural_compressor as nc +except ImportError as e: + nc = None + def check_deepspeed_config(config): if isinstance(config, dict): @@ -117,6 +122,26 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None): layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu) compression_preparation(c_model, layer_added_compress_methods, mpu) + # For sparse pruning snip_momentum method + shared_parameters = compress_methods[SPARSE_PRUNING][SHARED_PARAMETERS] + if shared_parameters[SPARSE_PRUNING_ENABLED] and \ + shared_parameters[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM: + + assert nc is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning" + + from .helper import generate_pruners, register_on_step_begin + from nc import WeightPruningConfig + + config = WeightPruningConfig(target_sparsity=1 - shared_parameters[SPARSE_PRUNING_DENSE_RATIO], + pattern=shared_parameters[SPARSE_PRUNING_BLOCK_PATTERN], + pruning_frequency=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE], + start_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET], + end_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_END], + excluded_op_names=shared_parameters[SPARSE_PRUNING_EXCLUDED_MODULES]) + pruners = generate_pruners(config, c_model) + c_model.pruners = pruners + register_on_step_begin(c_model) + return model @@ -187,17 +212,17 @@ def student_initialization(student_model, teacher_model, deepspeed_config): The prefix name before the layer #. Example 1: bert.encoder.layer, for BERT_base model's prefix name Example 2: transformer.h, for GPT-2 hugging face prefix name - teacher_layer (`list of intergers`) - The layer of teacher will be used for student's reinitializedion + teacher_layer (`list of integers`) + The layer of teacher will be used for student's reinitialization Example 1: [1,3,5,7,9], means we want to matches the 2nd/4th/6th/8th/10th layer of teacher to the first 5 layers of student student_layer (`list` or None) - The layer of student need to be re-intiialized + The layer of student need to be re-initialized Example 1: None, means we want to reinitialize all the layers Example 1: [0,1,2,3,4], means we want to reinitialize the first 5 layers other_module_name (`list of string`) - The modules will be used for student's reinitializedion + The modules will be used for student's reinitialization Example 1: ['bert.pooler', 'bert.embeddings', 'classifier'], means we want to apply the weight in teacher's embedding/pooler/classier module to the student - Example 2: ['transformer.w', 'transformer.ln_f', 'lm_head'], means we want to apply the weight in teacher's embeddingn layers module to the student + Example 2: ['transformer.w', 'transformer.ln_f', 'lm_head'], means we want to apply the weight in teacher's embedding layers module to the student Note that teacher_layer should matches student layer ''' assert len(student_layer) == len(teacher_layer) diff --git a/deepspeed/compression/config.py b/deepspeed/compression/config.py index d6e241bd0f80..0fab1032fc87 100644 --- a/deepspeed/compression/config.py +++ b/deepspeed/compression/config.py @@ -5,7 +5,7 @@ from .constants import * import copy -from ..runtime.config_utils import get_scalar_param +from ..runtime.config_utils import get_scalar_param, get_list_param def get_compression_config(param_dict): @@ -221,15 +221,17 @@ def get_sparse_pruning(param_dict): # shared parameters output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict) # each sub-groups - if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]: + if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED] and output[SHARED_PARAMETERS][ + SPARSE_PRUNING_METHOD] != SPARSE_PRUNING_METHOD_SNIP_MOMENTUM: assert DIFFERENT_GROUPS in sub_param_dict.keys( - ), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified" + ), f"Sparse Pruning is enabled and not snip_momentum method, {DIFFERENT_GROUPS} must be specified" output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict) return output def get_sparse_pruning_shared_parameters(param_dict): output = {} + if SHARED_PARAMETERS in param_dict.keys(): sub_param_dict = param_dict[SHARED_PARAMETERS] output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED, @@ -237,10 +239,26 @@ def get_sparse_pruning_shared_parameters(param_dict): output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD, SPARSE_PRUNING_METHOD_DEFAULT) assert output[SPARSE_PRUNING_METHOD] in [ - SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK - ], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]" + SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK, SPARSE_PRUNING_METHOD_SNIP_MOMENTUM + ], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}, {SPARSE_PRUNING_METHOD_SNIP_MOMENTUM}]" output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET, SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT) + if output[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM: + output[SPARSE_PRUNING_BLOCK_PATTERN] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_BLOCK_PATTERN, + SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT) + output[SPARSE_PRUNING_DENSE_RATIO] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_DENSE_RATIO, + SPARSE_PRUNING_DENSE_RATIO_DEFAULT) + assert output[SPARSE_PRUNING_DENSE_RATIO] > 0 and output[ + SPARSE_PRUNING_DENSE_RATIO] < 1, "Invalid dense_ratio value. Must be less than 1" + output[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE] = get_scalar_param( + sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT) + output[SPARSE_PRUNING_EXCLUDED_MODULES] = get_list_param(sub_param_dict, SPARSE_PRUNING_EXCLUDED_MODULES, + SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT) + output[SPARSE_PRUNING_SCHEDULE_OFFSET_END] = get_scalar_param(sub_param_dict, + SPARSE_PRUNING_SCHEDULE_OFFSET_END, + output[SPARSE_PRUNING_SCHEDULE_OFFSET]) + assert output[SPARSE_PRUNING_SCHEDULE_OFFSET] <= output[ + SPARSE_PRUNING_SCHEDULE_OFFSET_END], "Invalid schedule_offset and schedule_offset_end values" else: output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT diff --git a/deepspeed/compression/constants.py b/deepspeed/compression/constants.py index 18268e3bbae4..67375d510a4b 100644 --- a/deepspeed/compression/constants.py +++ b/deepspeed/compression/constants.py @@ -12,6 +12,7 @@ DIFFERENT_GROUPS = "different_groups" TECHNIQUE_ENABLED = "enabled" TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset" +TECHNIQUE_SCHEDULE_OFFSET_END = "schedule_offset_end" DIFFERENT_GROUPS_PARAMETERS = "params" DIFFERENT_GROUPS_MODULE_SCOPE = "modules" DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*" @@ -32,7 +33,7 @@ OTHER_MODULE_NAME = "other_module_name" #### -# Weight Quantzation +# Weight Quantization #### WEIGHT_QUANTIZATION = "weight_quantization" @@ -111,11 +112,25 @@ SPARSE_PRUNING_METHOD_DEFAULT = "l1" SPARSE_PRUNING_METHOD_L1 = "l1" SPARSE_PRUNING_METHOD_TOPK = "topk" +SPARSE_PRUNING_METHOD_SNIP_MOMENTUM = "snip_momentum" + +SPARSE_PRUNING_BLOCK_PATTERN = "block_pattern" +SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT = "4x1" + +SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE = "schedule_offset_stride" +SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT = 1 SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000 +SPARSE_PRUNING_SCHEDULE_OFFSET_END = TECHNIQUE_SCHEDULE_OFFSET_END +SPARSE_PRUNING_SCHEDULE_OFFSET_END_DEFAULT = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT + SPARSE_PRUNING_DENSE_RATIO = "dense_ratio" +SPARSE_PRUNING_DENSE_RATIO_DEFAULT = 0.1 + +SPARSE_PRUNING_EXCLUDED_MODULES = "excluded_modules" +SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT = [] ### # Row Pruning ### diff --git a/deepspeed/compression/helper.py b/deepspeed/compression/helper.py index fdca916e9f15..4cefb7d3be00 100644 --- a/deepspeed/compression/helper.py +++ b/deepspeed/compression/helper.py @@ -6,6 +6,12 @@ import torch from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress from .constants import * +from deepspeed.utils import logger + +try: + from neural_compressor.compression import pruner as nc_pruner +except ImportError as e: + nc_pruner = None def recursive_getattr(model, module_name): @@ -131,7 +137,7 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None) else: new_module = None - if compression_technique is not None: + if compression_technique is not None and new_module is not None: for k, v in compression_technique.items(): if k == SPARSE_PRUNING: if v[SPARSE_PRUNING_ENABLED]: @@ -176,13 +182,13 @@ def is_module_compressible(module, mpu=None): return ret -def compression_preparation(model, compression_techinique_list, mpu): +def compression_preparation(model, compression_technique_list, mpu): """ Prepare the compression techniques of a model. Args: model (`torch.nn.Module`) The model to prepare the compression techniques of. - compression_techinique_list (`list`) + compression_technique_list (`list`) The list of compression techniques to prepare the model to. list[] """ @@ -190,7 +196,7 @@ def compression_preparation(model, compression_techinique_list, mpu): for module_name, module in model.named_modules(): if is_module_compressible(module, mpu): module_replacement(model, module_name, mpu=mpu) - for module_name_lists, _, compression_technique in compression_techinique_list: + for module_name_lists, _, compression_technique in compression_technique_list: for mnl in module_name_lists: for module_name in mnl: module_replacement(model, module_name, compression_technique) @@ -246,3 +252,71 @@ def convert_conv1d_to_linear(model, convert_type): recursive_setattr(c_model, name, new_module) return model + + +def generate_pruners(config, model): + """Generate pruners. + Args: + config (`neural_compressor.WeightPruningConfig`) + The object to the class WeightPruningConfig. + model (`torch.nn.module`) + The torch module object to be pruned. + """ + assert nc_pruner is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning" + from nc_pruner.utils import process_config, parse_to_prune + from nc_pruner.pruners import get_pruner + assert isinstance(model, torch.nn.Module) + pruners_info = process_config(config) + pruners = [] + for info in pruners_info: + modules = parse_to_prune(info, model) + if modules == {}: + logger.warning("one pruner hooks no layers, please have a check") + + pruners.append(get_pruner(info, modules)) + info['modules'] = [key for key in modules.keys()] + info['len_of_modules'] = len(info['modules']) + logger.info(info) + return pruners + + +def register_on_step_begin(model): + """Mount on_step_begin to the model. + Args: + model (`torch.nn.module`) + The torch module object to be pruned. + """ + + def hook(module, input): + for pruner in module.pruners: + pruner.on_step_begin(0) + + hook_handle = model.register_forward_pre_hook(hook) + return hook_handle + + +def rewrite_optimizer_step(opt: torch.optim.Optimizer): + """Mount on_before/after_optimizer_step to the optimizer. + Args: + model (`torch.opt.Optimizer`) + The torch optimizer object to be hooked. + """ + + def new_step(self, closure=None): + if hasattr(self, "pruners"): + for pruner in self.pruners: + pruner.on_before_optimizer_step() + + if closure is not None: + res = self.orig_step(closure) + else: + res = self.orig_step() + if hasattr(self, "pruners"): + for pruner in self.pruners: + pruner.on_after_optimizer_step() + return res + + opt.orig_step = opt.step + import types + opt.step = types.MethodType(new_step, opt) + return opt diff --git a/deepspeed/compression/scheduler.py b/deepspeed/compression/scheduler.py index 582ecd8f6f5e..85fdb67f642f 100644 --- a/deepspeed/compression/scheduler.py +++ b/deepspeed/compression/scheduler.py @@ -100,7 +100,8 @@ def check_sparse_pruning(self): return else: shared_parameters = sp[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: + if shared_parameters[TECHNIQUE_SCHEDULE_OFFSET] <= self.training_steps <= shared_parameters[ + TECHNIQUE_SCHEDULE_OFFSET_END]: for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]: for module_name in module_name_list: module = recursive_getattr(self.model, module_name) diff --git a/deepspeed/compression/utils.py b/deepspeed/compression/utils.py index 3534f994cd78..481e833bdf8c 100644 --- a/deepspeed/compression/utils.py +++ b/deepspeed/compression/utils.py @@ -72,7 +72,7 @@ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1): The input which needs to be quantized num_bits (int, >=4) Number of bits to use for quantization - min_value/max_vlue (torch.FloatTensor) + min_value/max_value (torch.FloatTensor) Used for static activation quantization num_groups (int) How many groups to partition the quantization into @@ -114,7 +114,7 @@ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1): The input which needs to be quantized num_bits (int, >=4) Number of bits to use for quantization - min_value/max_vlue (torch.FloatTensor) + min_value/max_value (torch.FloatTensor) Used for static activation quantization num_groups (int) How many groups to partition the quantization into @@ -158,7 +158,7 @@ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1): The input which needs to be quantized num_bits (int) Dummy variable - min_value/max_vlue (torch.FloatTensor) + min_value/max_value (torch.FloatTensor) Used for static activation quantization; for now they are dummy variable num_groups (int) How many groups to partition the quantization into @@ -199,7 +199,7 @@ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1): The input which needs to be quantized num_bits (int) Dummy variable - min_value/max_vlue (torch.FloatTensor) + min_value/max_value (torch.FloatTensor) Used for static activation quantization; for now they are dummy variable num_groups (int) How many groups to partition the quantization into diff --git a/deepspeed/constants.py b/deepspeed/constants.py index 7ebc8f9983a5..8378e8f2264f 100644 --- a/deepspeed/constants.py +++ b/deepspeed/constants.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import os from datetime import timedelta ############################################# @@ -15,6 +16,10 @@ # (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1). # To make an attempt at backwards compatibility with THD, we use an # extraordinarily high default timeout, given that THD did not have timeouts. -default_pg_timeout = timedelta(minutes=30) +default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30))) INFERENCE_GENERIC_MODE = 'generic' INFERENCE_SPECIALIZED_MODE = 'specialized' + +CROSS_RANK = "CROSS_RANK" +CROSS_SIZE = "CROSS_SIZE" +LOCAL_RANK = 'LOCAL_RANK' diff --git a/deepspeed/datastates/README.md b/deepspeed/datastates/README.md new file mode 100644 index 000000000000..d946bd6cf875 --- /dev/null +++ b/deepspeed/datastates/README.md @@ -0,0 +1,12 @@ +# DataStates-LLM checkpointing engine. + +This feature is not enabled by default. To enable, set the following options in ds_config.json and download the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md). + +``` +{ + ... other deepspeed config options, + "datastates_ckpt": { + "host_cache_size": 16 + } +} +``` diff --git a/deepspeed/datastates/__init__.py b/deepspeed/datastates/__init__.py new file mode 100644 index 000000000000..0e3e65c9cce4 --- /dev/null +++ b/deepspeed/datastates/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory. + +# DeepSpeed Team diff --git a/deepspeed/datastates/config.py b/deepspeed/datastates/config.py new file mode 100644 index 000000000000..7ec0d3c957c2 --- /dev/null +++ b/deepspeed/datastates/config.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory. + +# DeepSpeed Team + +from deepspeed.runtime.config_utils import DeepSpeedConfigObject +import copy + +DATASTATES_CHECKPOINTING = "datastates_ckpt" +DATASTATES_CHECKPOINTING_ENABLED = False + + +class DeepSpeedDataStatesConfig(DeepSpeedConfigObject): + + def __init__(self, param_dict): + super(DeepSpeedDataStatesConfig, self).__init__() + + self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False + self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None)) diff --git a/deepspeed/elasticity/config.py b/deepspeed/elasticity/config.py index 9c574d3537c8..7c6bd42cdfd9 100644 --- a/deepspeed/elasticity/config.py +++ b/deepspeed/elasticity/config.py @@ -84,7 +84,7 @@ def __init__(self, param_dict): raise ElasticityConfigError("Elasticity min_gpus cannot be greater than max_gpus, " f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}") - self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE, MODEL_PARLLEL_SIZE_DEFAULT) + self.model_parallel_size = param_dict.get(MODEL_PARALLEL_SIZE, MODEL_PARALLEL_SIZE_DEFAULT) if self.model_parallel_size < 1: raise ElasticityConfigError("Model-Parallel size cannot be less than 1, " f"given model-parallel size: {self.model_parallel_size}") diff --git a/deepspeed/elasticity/constants.py b/deepspeed/elasticity/constants.py index 85a38c749b2a..b3134c54b4d6 100644 --- a/deepspeed/elasticity/constants.py +++ b/deepspeed/elasticity/constants.py @@ -50,8 +50,8 @@ NUM_GPUS_PER_NODE = 'num_gpus_per_node' NUM_GPUS_PER_NODE_DEFAULT = 1 -MODEL_PARLLEL_SIZE = "model_parallel_size" -MODEL_PARLLEL_SIZE_DEFAULT = 1 +MODEL_PARALLEL_SIZE = "model_parallel_size" +MODEL_PARALLEL_SIZE_DEFAULT = 1 # Minimum running time (minutes) before the scheduler will scale us, 0 implies it's unknown MIN_TIME = "min_time" diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py index 91a217cb9971..8fd4293d312c 100644 --- a/deepspeed/elasticity/elastic_agent.py +++ b/deepspeed/elasticity/elastic_agent.py @@ -6,7 +6,7 @@ from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent from typing import Any, Dict, Optional, Tuple from datetime import datetime -from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port +from torch.distributed.elastic.utils.distributed import get_free_port from torch.distributed.elastic.metrics import put_metric from torch.distributed.elastic.agent.server.api import ( RunResult, @@ -24,6 +24,10 @@ from contextlib import closing import subprocess +from torch.distributed.elastic.utils.logging import get_logger + +log = get_logger(__name__) + class DSElasticAgent(LocalElasticAgent): @@ -39,15 +43,20 @@ def __init__( self.ds_env = env @staticmethod - def _set_master_addr_port(store: Store, master_addr: Optional[str], master_port: Optional[int]): + def _set_master_addr_port(store: Store, + master_addr: Optional[str], + master_port: Optional[int], + local_addr: Optional[str] = None): if master_port is None: - sock = _get_socket_with_port() + sock = get_free_port() with closing(sock): master_port = sock.getsockname()[1] if master_addr is None: # master_addr = _get_fq_hostname() - result = subprocess.check_output("hostname -I", shell=True) + import shlex + safe_cmd = shlex.split("hostname -I") + result = subprocess.check_output(safe_cmd) master_addr = result.decode('utf-8').split()[0] store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8")) @@ -151,8 +160,8 @@ def _invoke_run(self, role: str = "default") -> RunResult: f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.") self._exit_barrier() return run_result - elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED - } or len(participants) > len(rdzv_handler._state_holder.state.participants): + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len( + rdzv_handler._state_holder.state.participants): if self._remaining_restarts > 0: log.info(f"[{role}] Worker group {state.name}. " f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" diff --git a/deepspeed/elasticity/elasticity.py b/deepspeed/elasticity/elasticity.py index 730f3acdc6a5..f26f13c98dfa 100644 --- a/deepspeed/elasticity/elasticity.py +++ b/deepspeed/elasticity/elasticity.py @@ -44,7 +44,7 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus): if batch_size % micro_batch == 0: max_gpus = batch_size // micro_batch - if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus: + if min_valid_gpus <= max_gpus <= max_valid_gpus: valid_gpus.append(max_gpus) # find all factors less than max_gpus / 2 @@ -148,7 +148,7 @@ def get_microbatch(final_batch_size): for micro_batch in micro_batches: if final_batch_size // current_num_gpus % micro_batch == 0: - if candidate_microbatch == None: + if candidate_microbatch is None: candidate_microbatch = micro_batch if prefer_larger and candidate_microbatch < micro_batch: candidate_microbatch = micro_batch diff --git a/deepspeed/elasticity/utils.py b/deepspeed/elasticity/utils.py index 43f2cdb2918c..78ae0352cf6e 100644 --- a/deepspeed/elasticity/utils.py +++ b/deepspeed/elasticity/utils.py @@ -3,7 +3,7 @@ # DeepSpeed Team -import torch +from deepspeed.utils.torch import required_torch_version def is_torch_elastic_compatible(): @@ -11,9 +11,4 @@ def is_torch_elastic_compatible(): Helper to lookup torch version. Elastic training is introduced in 1.11.x ''' - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - if TORCH_MAJOR == 1 and TORCH_MINOR >= 11: - return True - else: - return False + return required_torch_version(min_version=1.11) diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py index a42abe0ac874..cd6dd1cd898a 100644 --- a/deepspeed/env_report.py +++ b/deepspeed/env_report.py @@ -3,12 +3,13 @@ # DeepSpeed Team +import os import torch import deepspeed import subprocess import argparse from .ops.op_builder.all_ops import ALL_OPS -from .git_version_info import installed_ops, torch_info +from .git_version_info import installed_ops, torch_info, accelerator_name from deepspeed.accelerator import get_accelerator GREEN = '\033[92m' @@ -27,6 +28,7 @@ def op_report(verbose=True): + from torch.utils.cpp_extension import is_ninja_available max_dots = 23 max_dots2 = 11 h = ["op name", "installed", "compatible"] @@ -40,7 +42,7 @@ def op_report(verbose=True): print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) print("JIT compiled ops requires ninja") - ninja_status = OKAY if ninja_installed() else FAIL + ninja_status = OKAY if is_ninja_available() else FAIL print('ninja', "." * (max_dots - 5), ninja_status) print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) print(h[0], "." * (max_dots - len(h[0])), h[1], "." * (max_dots2 - len(h[1])), h[2]) @@ -50,20 +52,13 @@ def op_report(verbose=True): for op_name, builder in ALL_OPS.items(): dots = "." * (max_dots - len(op_name)) is_compatible = OKAY if builder.is_compatible(verbose) else no - is_installed = installed if installed_ops[op_name] else no + is_installed = installed if installed_ops.get(op_name, + False) and accelerator_name == get_accelerator()._name else no dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len)) print(op_name, dots, is_installed, dots2, is_compatible) print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) -def ninja_installed(): - try: - import ninja # noqa: F401 - except ImportError: - return False - return True - - def nvcc_version(): import torch.utils.cpp_extension cuda_home = torch.utils.cpp_extension.CUDA_HOME @@ -79,6 +74,61 @@ def nvcc_version(): return ".".join(release) +def installed_cann_path(): + if "ASCEND_HOME_PATH" in os.environ or os.path.exists(os.environ["ASCEND_HOME_PATH"]): + return os.environ["ASCEND_HOME_PATH"] + return None + + +def installed_cann_version(): + import re + ascend_path = installed_cann_path() + if ascend_path is None: + return "CANN_HOME does not exist, unable to compile NPU op(s)" + cann_version = "" + for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)): + if cann_version: + break + install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)] + if install_files: + filepath = os.path.join(dirpath, install_files[0]) + with open(filepath, "r") as f: + for line in f: + if line.find("version") != -1: + cann_version = line.strip().split("=")[-1] + break + return cann_version + + +def get_shm_size(): + try: + shm_stats = os.statvfs('/dev/shm') + except (OSError, FileNotFoundError, ValueError, AttributeError): + return "UNKNOWN", None + + shm_size = shm_stats.f_frsize * shm_stats.f_blocks + shm_hbytes = human_readable_size(shm_size) + warn = [] + if shm_size < 512 * 1024**2: + warn.append( + f" {YELLOW} [WARNING] /dev/shm size might be too small, if running in docker increase to at least --shm-size='1gb' {END}" + ) + if get_accelerator().communication_backend_name() == "nccl": + warn.append( + f" {YELLOW} [WARNING] see more details about NCCL requirements: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#sharing-data {END}" + ) + return shm_hbytes, warn + + +def human_readable_size(size): + units = ['B', 'KB', 'MB', 'GB', 'TB'] + i = 0 + while size >= 1024 and i < len(units) - 1: + size /= 1024 + i += 1 + return f'{size:.2f} {units[i]}' + + def debug_report(): max_dots = 33 @@ -92,12 +142,25 @@ def debug_report(): ("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " + (f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}")) ]) + elif get_accelerator().device_name() == 'npu': + import torch_npu + report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']}"), + ("torch_npu install path", torch_npu.__path__), ("torch_npu version", torch_npu.__version__), + ("ascend_cann version", installed_cann_version())]) else: report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")]) + report.append(("shared memory (/dev/shm) size", get_shm_size())) + print("DeepSpeed general environment info:") for name, value in report: + warns = [] + if isinstance(value, tuple): + value, warns = value print(name, "." * (max_dots - len(name)), value) + if warns: + for warn in warns: + print(warn) def parse_arguments(): diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py index 6ca0996bb592..70c536d2f78e 100644 --- a/deepspeed/git_version_info.py +++ b/deepspeed/git_version_info.py @@ -5,7 +5,7 @@ try: # This is populated by setup.py - from .git_version_info_installed import * # noqa: F401 + from .git_version_info_installed import * # noqa: F401 # type: ignore except ModuleNotFoundError: import os if os.path.isfile('version.txt'): @@ -18,5 +18,14 @@ from .ops.op_builder.all_ops import ALL_OPS installed_ops = dict.fromkeys(ALL_OPS.keys(), False) - compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) + accelerator_name = "" torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"} + +# compatible_ops list is recreated for each launch +from .ops.op_builder.all_ops import ALL_OPS + +compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + compatible_ops[op_name] = op_compatible + compatible_ops["deepspeed_not_implemented"] = False diff --git a/deepspeed/inference/__init__.py b/deepspeed/inference/__init__.py index 0fc748f4e167..cdd00fec935b 100644 --- a/deepspeed/inference/__init__.py +++ b/deepspeed/inference/__init__.py @@ -2,5 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - -from .engine import InferenceEngine +from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig +from .v2.engine_v2 import InferenceEngineV2 +from .v2 import build_hf_engine, build_engine_from_ds_checkpoint diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index 70a67c062ad2..6df61f7c8841 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -4,41 +4,26 @@ # DeepSpeed Team import torch +import deepspeed +from pydantic import Field, field_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from pydantic import Field -from pydantic import validator -from typing import Dict, Union +from typing import Dict, Union, Optional from enum import Enum class DtypeEnum(Enum): - # The torch dtype must always be the first value (so we return torch.dtype) - fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" - fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" - int8 = torch.int8, "torch.int8", "int8" - - # bf16 not supported - # bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16" - - # Copied from https://stackoverflow.com/a/43210118 - # Allows us to use multiple values for each Enum index and returns first - # listed value when Enum is called - def __new__(cls, *values): - obj = object.__new__(cls) - # first value is canonical value - obj._value_ = values[0] - for other_value in values[1:]: - cls._value2member_map_[other_value] = obj - obj._all_values = values - return obj - - def __repr__(self): - return "<%s.%s: %s>" % ( - self.__class__.__name__, - self._name_, - ", ".join([repr(v) for v in self._all_values]), - ) + fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half") + fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float") + bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat") + int8 = (torch.int8, "torch.int8", "int8") + + @classmethod + def from_str(cls, value: str): + for dtype in cls: + if value in dtype.value: + return dtype + raise ValueError(f"'{value}' is not a valid DtypeEnum") class MoETypeEnum(str, Enum): @@ -55,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel): tp_size: int = 1 """ Number of devices to split the model across using tensor parallelism. """ + tp_grain_size: int = 64 + "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size." + mpu: object = None """ A model parallelism unit object that implements @@ -93,22 +81,24 @@ class QuantTypeEnum(str, Enum): class BaseQuantConfig(DeepSpeedConfigModel): - enabled = True - num_bits = 8 + enabled: bool = True + num_bits: int = 8 q_type: QuantTypeEnum = QuantTypeEnum.sym q_groups: int = 1 class WeightQuantConfig(BaseQuantConfig): - enabled = True + enabled: bool = True + quantized_initialization: Dict = {} + post_init_quant: Dict = {} class ActivationQuantConfig(BaseQuantConfig): - enabled = True + enabled: bool = True class QKVQuantConfig(DeepSpeedConfigModel): - enabled = True + enabled: bool = True class QuantizationConfig(DeepSpeedConfigModel): @@ -120,9 +110,9 @@ class QuantizationConfig(DeepSpeedConfigModel): # todo: brainstorm on how to do ckpt loading for DS inference class InferenceCheckpointConfig(DeepSpeedConfigModel): - checkpoint_dir: str = None - save_mp_checkpoint_path: str = None - base_dir: str = None + checkpoint_dir: Optional[str] = None + save_mp_checkpoint_path: Optional[str] = None + base_dir: Optional[str] = None class DeepSpeedInferenceConfig(DeepSpeedConfigModel): @@ -136,7 +126,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): `(attention_output projection, transformer output projection)` """ - dtype: DtypeEnum = torch.float16 + dtype: torch.dtype = torch.float16 """ Desired model data type, will convert model to this type. Supported target types: `torch.half`, `torch.int8`, `torch.float` @@ -154,6 +144,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): can run faster using the graph replay method. """ + use_triton: bool = False + """ + Use this flag to use triton kernels for inference ops. + """ + + triton_autotune: bool = False + """ + Use this flag to enable triton autotuning. + Turning it on is better for performance but increase the 1st runtime for + autotuning. + """ + zero: DeepSpeedZeroConfig = {} """ ZeRO configuration to use with the Inference Engine. Expects a dictionary @@ -172,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): values for :any:`DeepSpeedMoEConfig`. """ + keep_module_on_host: bool = False + """ + When loading checkpoints to model parameters, they are moved to the device. In very large models + this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on + host and not move them directly to the device (giving an option to quantize checkpoint data before + moving it to the device for example). + Set only for models with injection policies and auto TP. + """ + quant: QuantizationConfig = {} """ NOTE: only works for int8 dtype. @@ -186,12 +197,12 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): """ #todo: refactor the following 3 into the new checkpoint_config - checkpoint: str = None + checkpoint: Optional[Union[str, Dict]] = None """ Path to deepspeed compatible checkpoint or path to JSON with load policy. """ - base_dir: str = None + base_dir: str = "" """ This shows the root directory under which all the checkpoint files exists. This can be passed through the json config too. @@ -202,7 +213,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): specifying whether the inference-module is created with empty or real Tensor """ - save_mp_checkpoint_path: str = None + save_mp_checkpoint_path: Optional[str] = None """ The path for which we want to save the loaded model with a checkpoint. This feature is used for adjusting the parallelism degree to help alleviate the @@ -231,19 +242,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): replace_method: str = Field( "auto", - deprecated=True, - deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference") + json_schema_extra={ + "deprecated": True, + "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference" + }) - injection_policy: Dict = Field(None, alias="injection_dict") + injection_policy: Optional[Dict] = Field(None, alias="injection_dict") """ Dictionary mapping a client nn.Module to its corresponding injection policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}` """ - injection_policy_tuple: tuple = None + injection_policy_tuple: Optional[tuple] = None """ TODO: Add docs """ - config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor + config: Optional[Dict] = Field(None, alias="args") # todo: really no need for this field if we can refactor max_out_tokens: int = Field(1024, alias="max_tokens") """ @@ -262,25 +275,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): transposed_mode: bool = Field(False, alias="transposed_mode") - mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size") + mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"}) """ Desired model parallel size, default is 1 meaning no model parallelism. Deprecated, please use the ``tensor_parallel` config to control model parallelism. """ - mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu") - ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size") - ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group") - ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group") - moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts") - moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type") - - @validator("moe") + mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"}) + ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"}) + ep_group: object = Field(None, + alias="expert_group", + json_schema_extra={ + "deprecated": True, + "new_param": "moe.ep_group" + }) + ep_mp_group: object = Field(None, + alias="expert_mp_group", + json_schema_extra={ + "deprecated": True, + "new_param": "moe.ep_mp_group" + }) + moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"}) + moe_type: MoETypeEnum = Field(MoETypeEnum.standard, + json_schema_extra={ + "deprecated": True, + "new_param": "moe.type" + }) + + @field_validator("dtype", mode="before") + def validate_dtype(cls, field_value, values): + if isinstance(field_value, str): + return DtypeEnum.from_str(field_value).value[0] + if isinstance(field_value, torch.dtype): + return field_value + raise TypeError(f"Invalid type for dtype: {type(field_value)}") + + @field_validator("moe") def moe_backward_compat(cls, field_value, values): if isinstance(field_value, bool): return DeepSpeedMoEConfig(moe=field_value) return field_value - class Config: - # Get the str representation of the datatype for serialization - json_encoders = {torch.dtype: lambda x: str(x)} + @field_validator("use_triton") + def has_triton(cls, field_value, values): + if field_value and not deepspeed.HAS_TRITON: + raise ValueError('Triton needs to be installed to use deepspeed with triton kernels') + return field_value diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index de7ca5a71197..7e78a6b060fb 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -6,7 +6,7 @@ import torch import time import os - +import deepspeed from deepspeed import comm as dist from deepspeed.utils.logging import log_dist @@ -14,7 +14,7 @@ from packaging import version as pkg_version from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from deepspeed.utils.timer import SynchronizedWallClockTimer - +from deepspeed.runtime.compiler import is_compile_supported from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization from ..module_inject import replace_transformer_layer, generic_injection @@ -27,6 +27,9 @@ from ..module_inject.auto_tp import AutoTP from ..module_inject.replace_policy import generic_policies +from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask +from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention +from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference DS_INFERENCE_ENABLED = False from torch import nn @@ -34,58 +37,6 @@ INFERENCE_MODEL_TIMER = "model-forward-inference" -def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: - """ - Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it - relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value - `softmax(l+a) = softmax(l)`. Based on - https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. - - Args: - Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) - attention_mask (`torch.Tensor`): - Token-wise attention mask, this should be of shape (batch_size, max_seq_len). - num_heads (`int`, *required*): - number of heads - dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): - dtype of the output tensor - """ - import math - batch_size, seq_length = attention_mask.shape - closest_power_of_2 = 2**math.floor(math.log2(num_heads)) - base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) - powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != num_heads: - extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention - # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcasted correctly - # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 - arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] - alibi = slopes[..., None] * arange_tensor - if dist.is_initialized(): - num_heads_per_rank = int(num_heads / dist.get_world_size()) - offset = dist.get_rank() * num_heads_per_rank - alibi = alibi.view(batch_size, num_heads, 1, seq_length) - alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] - return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) - else: - return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) - - class InferenceEngine(Module): inference_mp_group = None inference_ep_group = None @@ -101,6 +52,8 @@ def __init__(self, model, config): DS_INFERENCE_ENABLED = True super().__init__() + if DeepSpeedTransformerInference.workspace is not None: + self.destroy() self.module = model self._config = config @@ -114,6 +67,10 @@ def __init__(self, model, config): if hasattr(self.module, "config"): TransformerPolicy.hf_model_config = self.module.config + if config.dtype not in get_accelerator().supported_dtypes(): + raise ValueError( + f"Data type {config.dtype} is not supported by {get_accelerator().device_name()} accelerator") + # todo: keep this self.injection_dict because we don't use to change config.injection_policy API # todo: this will get changed when Molly's PR on auto injection dict is merged self.injection_dict = config.injection_policy @@ -122,7 +79,6 @@ def __init__(self, model, config): self.mp_group = config.tensor_parallel.tp_group self.mpu = config.tensor_parallel.mpu - #self._validate_args(self.mpu, config.replace_with_kernel_inject) self.quantize_merge_count = 1 self.quantization_scales = None @@ -146,14 +102,12 @@ def __init__(self, model, config): # This is a hack to redefine the alibi func due to TP if config.tensor_parallel.tp_size > 1: self.build_alibi_tensor() + self.build_attn_bias() if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph: assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ "If you want to use cuda graph, please upgrade torch to at least v1.10" - if config.checkpoint and not config.replace_with_kernel_inject: - self._load_checkpoint(config.checkpoint) - # convert model to intended dtype if config.dtype: self._convert_to_dtype(config) @@ -171,28 +125,34 @@ def __init__(self, model, config): moe = False if moe and dist.get_world_size() > 1: - self._create_ep_parallel_group(config.moe.moe_experts) - - # retain this from the old conditional argument being passed to apply_injection_policy() - if not config.replace_with_kernel_inject: - config.checkpoint = None + self._create_ep_parallel_group(config.moe.ep_size) - # We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism. + # We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism if tp_size > 1. if self.injection_dict: # 1. User specified Tensor Parallelism assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection" for client_module, injection_policy in self.injection_dict.items(): + + assert issubclass(client_module, + torch.nn.Module), f"{client_module} is not a subclass of torch.nn.Module" + # construct the tuple and pass that instead of a string or dict. if isinstance(injection_policy, str): config.injection_policy_tuple = (injection_policy, ) else: config.injection_policy_tuple = injection_policy + + layer_names = [name for name, _ in self.module.named_modules()] + for policy in config.injection_policy_tuple: + if not any(name.endswith(policy) for name in layer_names): + raise ValueError(f"Injection policy layer'{policy}' not valid.") + self._apply_injection_policy(config, client_module) else: if config.replace_with_kernel_inject: # 2. DeepSpeed Kernel Injection self._apply_injection_policy(config) - else: + elif config.tensor_parallel.tp_size > 1: # 3. Automatic Tensor Parallelism parser_dict = AutoTP.tp_parser(model) print("AutoTP: ", parser_dict) @@ -204,7 +164,12 @@ def __init__(self, model, config): self._apply_injection_policy(config, client_module) device = get_accelerator().current_device_name() - self.module.to(device) + # NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type + is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta' + if is_meta_device: + self.module.to_empty(device=device) + elif not config.keep_module_on_host: + self.module.to(device) if config.tensor_parallel.tp_size > 1: _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name()) @@ -216,6 +181,14 @@ def __init__(self, model, config): # Check if local CUDA graphs can be created in replacement modules self.local_cuda_graph = self._local_cuda_graph_used(self.module) + self._is_compiled = False + + def destroy(self): + DeepSpeedTransformerInference.layer_id = 0 + DeepSpeedSelfAttention.num_layers = 0 + if DeepSpeedTransformerInference.workspace.is_allocated(): + DeepSpeedTransformerInference.workspace.release_workspace() + DeepSpeedTransformerInference.workspace = None def profile_model_time(self, use_cuda_events=True): if not self.model_profile_enabled and not self._config.enable_cuda_graph: @@ -240,6 +213,19 @@ def build_alibi_tensor(self): if hasattr(self.module, 'transformer'): if hasattr(self.module.transformer, 'build_alibi_tensor'): self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor + if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'): + self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor + self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor + if hasattr(self.module, 'model'): + if hasattr(self.module.model, 'get_alibi_mask'): + self.module.model.get_alibi_mask_orig = self.module.model.get_alibi_mask + self.module.model.__class__.get_alibi_mask = get_alibi_mask + + def build_attn_bias(self): + if hasattr(self.module, 'transformer'): + if hasattr(self.module.transformer, '_attn_bias'): + self.module.transformer._attn_bias_orig = self.module.transformer._attn_bias + self.module.transformer.__class__._attn_bias = build_mpt_atten_bias_tensor def _pre_forward_hook(self, module, *inputs, **kwargs): if self.use_cuda_events: @@ -255,7 +241,7 @@ def _post_forward_hook(self, module, input, output): else: get_accelerator().synchronize() self._end = time.time() - elapsed_time = self._end - self._start + elapsed_time = (self._end - self._start) * 1e3 # convert seconds to ms self._model_times.append(elapsed_time) def _create_model_parallel_group(self, config): @@ -281,6 +267,8 @@ def _create_ep_parallel_group(self, moe_experts): self.expert_mp_group.update({e: None}) for moe_ep_size in self.ep_group.keys(): num_ep_groups = dist.get_world_size() // moe_ep_size + if num_ep_groups == 0: + raise ValueError(f"Invalid ep_size={moe_ep_size} for world_size={dist.get_world_size()}") for i in range(num_ep_groups): ep_cnt = i * moe_ep_size size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size @@ -312,29 +300,6 @@ def _init_quantization_setting(self, quantization_setting): f"mlp_extra_grouping = {self.mlp_extra_grouping}, " f"quantize_groups = {self.quantize_groups}", [0]) - # TODO: remove this function and add this functionality to pydantic config checking - def _validate_args(self, mpu, replace_with_kernel_inject): - # TODO: to support SD pipeline we need to avoid this check for now - if replace_with_kernel_inject and not isinstance(self.module, Module): - raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") - if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1: - raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}") - - if mpu: - methods = ["get_model_parallel_group", "get_data_parallel_group"] - for method in methods: - if not hasattr(mpu, method): - raise ValueError(f"mpu is missing {method}") - if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)): - raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}") - - supported_dtypes = [None, torch.half, torch.int8, torch.float] - if self._config.dtype not in supported_dtypes: - raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}") - - if self.injection_dict is not None and not isinstance(self.injection_dict, dict): - raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}") - def load_model_with_checkpoint(self, r_module): self.mp_replace = ReplaceWithTensorSlicing( mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1) @@ -343,16 +308,38 @@ def load_model_with_checkpoint(self, r_module): def load(module, state_dict, prefix): args = (state_dict, prefix, {}, True, [], [], error_msgs) if hasattr(module, 'weight'): + if module.weight.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, + device="cpu"), + requires_grad=module.weight.data.requires_grad) if 'query_key_value' in prefix: - module.weight = self.mp_replace.qkv_copy(module.weight.data, state_dict[prefix + 'weight']) + module.weight = self.mp_replace.strided_copy(module.weight.data, + state_dict[prefix + 'weight'], + num_splits=3) else: module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight']) else: + if module.norm.weight.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.norm.weight = torch.nn.parameter.Parameter( + data=torch.empty_like(module.norm.weight.data, device="cpu"), + requires_grad=module.norm.weight.data.requires_grad) module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight']) if prefix + 'bias' in self.key_list: if hasattr(module, 'norm'): + if module.norm.bias.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.norm.bias = torch.nn.parameter.Parameter( + data=torch.empty_like(module.norm.bias.data, device="cpu"), + requires_grad=module.norm.bias.data.requires_grad) module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias']) else: + if module.bias.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, + device="cpu"), + requires_grad=module.bias.data.requires_grad) data = state_dict[prefix + 'bias'] data = data.to(get_accelerator().current_device_name()) module.bias = self.mp_replace.copy(module.bias, data) @@ -381,15 +368,22 @@ def load_module_recursive(module, prefix='', level=0): load_module_recursive(r_module) + embedding_weight = None + + for n, p in r_module.named_parameters(): + if "word_embeddings." in n or "embed_tokens." in n or "wte." in n: + embedding_weight = p + if embedding_weight is not None and hasattr(r_module, "lm_head") and hasattr( + r_module.lm_head, "weight") and r_module.lm_head.weight.is_meta: + r_module.lm_head.weight = embedding_weight + def _apply_injection_policy(self, config, client_module=None): # client_module is only passed when using the injection_dict method. checkpoint_dir = config.checkpoint checkpoint = SDLoaderFactory.get_sd_loader_json(checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - generic_injection(self.module, - fp16=(config.dtype == torch.half) or (config.dtype == torch.int8), - enable_cuda_graph=config.enable_cuda_graph) + generic_injection(self.module, dtype=config.dtype, enable_cuda_graph=config.enable_cuda_graph) if isinstance(self.module, torch.nn.Module): # config is our DeepSpeedInferenceConfig and self.config is the HF model config @@ -432,16 +426,18 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): else: sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine) - if type(sd_loader) is list: - self.sd = torch.load(sd_loader[0], map_location='cpu') + checkpoint = sd_loader['checkpoints'] + + if type(checkpoint) is list: + self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False) self.key_list = list(self.sd.keys()) self.load_model_with_checkpoint(self.module) - for i in range(1, len(sd_loader)): + for i in range(1, len(checkpoint)): if not dist.is_initialized() or dist.get_rank() == 0: print(f"loading checkpoint ({i})") - self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name()) + self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False) self.key_list = list(self.sd.keys()) self.load_model_with_checkpoint(self.module) else: @@ -507,11 +503,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): get_accelerator().current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self.module(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True @@ -523,7 +519,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def model_times(self): @@ -575,12 +571,13 @@ def forward(self, *inputs, **kwargs): else: self._create_cuda_graph(*inputs, **kwargs) outputs = self._graph_replay(*inputs, **kwargs) + else: outputs = self.module(*inputs, **kwargs) if self.model_profile_enabled and self._config.enable_cuda_graph: get_accelerator().synchronize() - duration = time.time() - start + duration = (time.time() - start) * 1e3 # convert seconds to ms self._model_times.append(duration) return outputs @@ -598,6 +595,33 @@ def _generate(self, *inputs, **kwargs): if num_beams > 1: raise NotImplementedError("DeepSpeed does not support `num_beams` > 1, if this is important to you please " - "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506") + "add your request to: https://github.com/deepspeedai/DeepSpeed/issues/2506") + + if ("input_ids" in kwargs) and (kwargs["input_ids"].dim() == 2): + for input_tensor in kwargs["input_ids"]: + tensor_length = input_tensor.shape[-1] + if tensor_length > self._config.max_out_tokens: + raise RuntimeError( + f"Input with size {tensor_length} exceeds maximum length of {self._config.max_out_tokens}. Please increase max_tokens in the DeepSpeed Inference Config." + ) return self.module.generate(*inputs, **kwargs) + + def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None: + """ + Compile the module using the specified backend and kwargs. + """ + if not is_compile_supported(): + raise RuntimeError("compile is not supported in your version of PyTorch.") + + if self._is_compiled: + return + + # Avoid graph breaks + deepspeed.utils.nvtx.enable_nvtx = False + self.module.compile(backend=backend, **compile_kwargs) + self._is_compiled = True + + @property + def is_compiled(self) -> bool: + return self._is_compiled diff --git a/deepspeed/inference/quantization/__init__.py b/deepspeed/inference/quantization/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/inference/quantization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/inference/quantization/layers.py b/deepspeed/inference/quantization/layers.py new file mode 100644 index 000000000000..e9a7e5629f1b --- /dev/null +++ b/deepspeed/inference/quantization/layers.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from torch import nn +from torch import Tensor +from torch.nn import functional as F +from .utils import Quantizer, DeQuantizer, concat_to_compat_param +from typing import Tuple, Callable, Dict +from deepspeed.runtime.zero import register_external_parameter + +quantized_weight_registry = {} +is_zero3_enabled = False + + +# deal with weight sharing +def get_quantized_weight_wrapper(model, pre_quant_weight: nn.Parameter, quantize_weight_fn: Callable) -> nn.Parameter: + if id(pre_quant_weight) in quantized_weight_registry: + compat_tensor = quantized_weight_registry[id(pre_quant_weight)] + if is_zero3_enabled: + register_external_parameter(model, compat_tensor) + + return quantized_weight_registry[id(pre_quant_weight)] + else: + quantized_weights, quant_scale, quant_min = quantize_weight_fn() + quantized_weight_registry[id(pre_quant_weight)] = concat_to_compat_param(quantized_weights, quant_scale, + quant_min) + return quantized_weight_registry[id(pre_quant_weight)] + + +def get_quantize_weight_fn(quantizer: Quantizer, pre_quant_weight: nn.Parameter) -> Callable: + + def func() -> Tuple[nn.Parameter, Tensor, Tensor]: + quantized_weights, quant_scale, quant_min = quantizer.quantize(pre_quant_weight.data) + # A temporary hack as zero Zero3 assume all model weights has the same type. in all_gather_coalesced.get_only_unique_item + quantized_weights = quantized_weights.view(pre_quant_weight.dtype) + quant_scale = quant_scale.type(pre_quant_weight.dtype) + quant_min = quant_min.type(pre_quant_weight.dtype) + return quantized_weights, quant_scale, quant_min + + return func + + +class QuantizedLinear(nn.Linear): + + def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None: + super(QuantizedLinear, self).__init__(in_features=pre_quant_layer.in_features, + out_features=pre_quant_layer.out_features, + bias=pre_quant_layer.bias is not None, + device=pre_quant_layer.weight.device, + dtype=pre_quant_layer.weight.dtype) + self.config = config + + self.quantizer = Quantizer(config=config) + self.bias = pre_quant_layer.bias + self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight, + get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight)) + + self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype) + + def forward(self, input: Tensor) -> Tensor: + quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight) + temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale, + quant_min) + + # !!! Do not use torch.functional.linear(input, temp_dequantized_weight, self.bias) here as in zero3 torch.functional.linear is + # replaced by LinearFunctionForZeroStage3. Which assume weight is non-temporary. + # If weight is temp buffer there will be memory leak. + return torch._C._nn.linear(input, temp_dequantized_weight, self.bias) + + +class QuantizedEmbedding(nn.Embedding): + + def __init__(self, config: Dict, pre_quant_layer: nn.Embedding) -> None: + super(QuantizedEmbedding, self).__init__(num_embeddings=pre_quant_layer.num_embeddings, + embedding_dim=pre_quant_layer.embedding_dim, + padding_idx=pre_quant_layer.padding_idx, + max_norm=pre_quant_layer.max_norm, + norm_type=pre_quant_layer.norm_type, + scale_grad_by_freq=pre_quant_layer.scale_grad_by_freq, + sparse=pre_quant_layer.sparse, + _weight=pre_quant_layer.weight, + device=pre_quant_layer.weight.device, + dtype=pre_quant_layer.weight.dtype) + + assert pre_quant_layer.max_norm is None, 'Not supported' + assert pre_quant_layer.norm_type == 2, 'Not supported' + assert pre_quant_layer.scale_grad_by_freq == False, 'Not supported' + assert pre_quant_layer.sparse == False, 'Not supported' + + self.config = config + quantizer = Quantizer(config=config) + + self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight, + get_quantize_weight_fn(quantizer, pre_quant_layer.weight)) + + self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype) + + def forward(self, input: Tensor) -> Tensor: + quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight) + temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale, + quant_min) + + return F.embedding(input, temp_dequantized_weight, self.padding_idx, self.max_norm, self.norm_type, + self.scale_grad_by_freq, self.sparse) + + +QUANTIZATION_LAYER_MAPPINGS = { + nn.Linear: QuantizedLinear, + nn.Embedding: QuantizedEmbedding, +} diff --git a/deepspeed/inference/quantization/quantization.py b/deepspeed/inference/quantization/quantization.py new file mode 100644 index 000000000000..9ae39e8d5688 --- /dev/null +++ b/deepspeed/inference/quantization/quantization.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from torch import nn +from typing import Dict +import gc +from deepspeed.inference.quantization import layers +from .layers import QUANTIZATION_LAYER_MAPPINGS +from .utils import get_AsyncPartitionedParameterSwapper, recursive_setattr +from deepspeed.utils.logging import logger +from collections import deque +from transformers.utils.generic import ContextManagers +from .quantization_context import QuantizationContext +import contextlib + + +def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> nn.Module: + """[Experimental] Apply group-wise weight quantization to model. Replace layers module according to config_list + + Args: + model (nn.Module): A nn.Module + ds_config (Dict, optional): The ds_config dictionary. use None for non-deepspeed managed model. + + Returns: + nn.Module: Quantized nn.Module + """ + + # global quantized_weight_registry + + matched_module_list_by_key = {} + matched_module_count = 0 + + assert 'weight_quantization' in ds_config, 'Please provide quantization config in ds_config' + quantization_config = ds_config['weight_quantization']['post_init_quant'] + + # Return nvme swapper if exists, else return None. + # For nvme offloading we must use the same swapper here as model initialized. + nvme_swapper = get_AsyncPartitionedParameterSwapper(model) + is_zero3_enabled = 'zero_optimization' in ds_config and \ + 'stage' in ds_config['zero_optimization'] and \ + ds_config['zero_optimization']['stage'] == 3 + is_offloading_enabled = 'zero_optimization' in ds_config and \ + 'offload_param' in ds_config['zero_optimization'] + + layers.is_zero3_enabled = is_zero3_enabled + + context_mgr = ContextManagers([QuantizationContext(config_dict_or_path=ds_config, param_swapper=nvme_swapper)]) \ + if is_zero3_enabled else contextlib.suppress() + with context_mgr: + module_list = list( + filter(lambda named_module: type(named_module[1]) in QUANTIZATION_LAYER_MAPPINGS, model.named_modules())) + + # Quantize small weight first then large. + if not is_offloading_enabled: + module_list.sort(key=lambda named_module: named_module[1].weight.ds_tensor.numel() + if is_zero3_enabled else named_module[1].weight.numel()) + module_list = deque(module_list) + + while len(module_list) > 0: + # Use popleft to timely release module's memory of replaced module after each loop iteration + module_name, module = module_list.popleft() + + matched_key = None + matched_quantization_config = None + + for key, config in quantization_config.items(): + if key in module_name: + assert matched_key is None, f'{module_name} matched multiple quantization key word {matched_key} and {key}' + matched_key = key + matched_quantization_config = config + + if matched_key is None: + continue + + if is_zero3_enabled: + module.weight.all_gather() + + assert module.weight.dtype == torch.float16, 'Model weight is expected in half.' + + new_module = QUANTIZATION_LAYER_MAPPINGS[type(module)](matched_quantization_config, module) + + if is_zero3_enabled: + module.weight.partition() + + recursive_setattr(model, module_name, new_module) + + if matched_key not in matched_module_list_by_key: + matched_module_list_by_key[matched_key] = [] + matched_module_list_by_key[matched_key].append(module_name) + matched_module_count += 1 + + # Timely recycle memory to prevent OOM on large models + gc.collect() + + # Clear registry after model construction. + layers.quantized_weight_registry.clear() + + logger.info( + f'Group-wise weight quantization summary: convert {matched_module_count} node(s) to quantized implementation') + summary_str = '\n' + + for key, module_list in matched_module_list_by_key.items(): + summary_str += f'Key: {key}, matched modules:\n' + for module_name in module_list: + summary_str += f'\t{module_name}\n' + logger.info(summary_str) + + return model diff --git a/deepspeed/inference/quantization/quantization_context.py b/deepspeed/inference/quantization/quantization_context.py new file mode 100644 index 000000000000..d3333da05058 --- /dev/null +++ b/deepspeed/inference/quantization/quantization_context.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.zero import partition_parameters +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper + + +class QuantizationContext(partition_parameters.Init): + + def __init__(self, config_dict_or_path, param_swapper: AsyncPartitionedParameterSwapper = None) -> None: + super().__init__(config_dict_or_path=config_dict_or_path, param_swapper=param_swapper) diff --git a/deepspeed/inference/quantization/utils.py b/deepspeed/inference/quantization/utils.py new file mode 100644 index 000000000000..a5e8f28bdec9 --- /dev/null +++ b/deepspeed/inference/quantization/utils.py @@ -0,0 +1,288 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +from torch import Tensor +from typing import Tuple +import torch.nn as nn +from typing import Dict, Callable, Union +from deepspeed.accelerator import get_accelerator +import functools + +device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu' + +quantizer_module = None + + +def get_quantizer_module(): + global quantizer_module + if quantizer_module is None: + quantizer_module = deepspeed.ops.op_builder.QuantizerBuilder().load() + return quantizer_module + + +def tensor_clamp(tensor: Tensor, min, max) -> Tensor: + if tensor.device.type == 'cpu' and tensor.dtype == torch.float16: + # CPU does not support FP16 clamp + return tensor.to(dtype=torch.float32).clamp_(min, max).to(dtype=torch.float16) + else: + return tensor.clamp_(min, max) + + +def tensor_round(tensor: Tensor) -> Tensor: + if tensor.device.type == 'cpu' and tensor.dtype == torch.float16: + # CPU does not support FP16 round + return tensor.to(dtype=torch.float32).round_().to(dtype=torch.float16) + else: + return tensor.round_() + + +class Quantizer: + + def __init__(self, config: Dict) -> None: + self.config = config + assert self.config['num_bits'] == 4 or self.config[ + 'num_bits'] == 8, 'Only INT4 and INT8 quantization is supported.' + assert self.config['symmetric'] == False, 'Only asymmetric quantization is supported at this moment.' + + def quantize(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + assert tensor.shape[self.config['group_dim']] % self.config['group_size'] == 0 \ + , f'Tensor shape: {tensor.shape} quantization config {self.config}' + + tensor = torch.clone(tensor) + + shape = tensor.shape + num_groups = shape[self.config['group_dim']] // self.config['group_size'] + new_shape = (shape[:self.config['group_dim']] + (num_groups, self.config['group_size']) + + shape[self.config['group_dim'] + 1:]) + tensor = tensor.view(new_shape) + + quantized_tensor, scale, min_value = self._quantize_int8(tensor) + quantized_tensor = quantized_tensor.view(shape) + + if self.config['num_bits'] == 4: + return self._compress_uint8_to_uint4(quantized_tensor), scale, min_value + if self.config['num_bits'] == 8: + return quantized_tensor, scale, min_value + + assert False, 'Unsupported quantization bits {}'.format(self.config['num_bits']) + + def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + q_range = 2**self.config['num_bits'] - 1 + min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True) + max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True) + + scale = q_range / (max_value - min_value) + + tensor = tensor.sub_(min_value).mul_(scale) + tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8) + return tensor, scale, min_value + + def _compress_uint8_to_uint4(self, tensor: Tensor) -> Tensor: + assert tensor.shape[-1] % 2 == 0 + + new_data_shape = list(tensor.shape) + new_data_shape[-1] = new_data_shape[-1] // 2 + + data = torch.empty(new_data_shape, dtype=torch.uint8, device=tensor.device) + data = torch.bitwise_or(tensor[..., 0::2].bitwise_left_shift(4), tensor[..., 1::2]) + + return data + + +class DeQuantizer: + + def __init__(self, config: Dict, dtype: torch.dtype) -> None: + self.config = config + self.dtype = dtype + assert self.config['num_bits'] == 4 or self.config[ + 'num_bits'] == 8, 'Only INT4 and INT8 quantization is supported.' + assert self.config['symmetric'] == False, 'Only asymmetric quantization is supported at this moment.' + + def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor: + # Use customized CUDA quantization kernel if possible. + if self.config['group_size'] % 8 == 0 and \ + (self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \ + self.config['group_dim'] == len(tensor.shape) - 1 and \ + self.dtype == torch.float16 and device == get_accelerator().device_name(): + + last_dimension_size = self.config['group_size'] + if self.config['num_bits'] == 4: + last_dimension_size = last_dimension_size // 2 + quantized_tensor = get_quantizer_module().dequantize_int4_to_half_experimental( + tensor.reshape(-1, last_dimension_size), quant_scale, quant_min, + tensor.numel() // last_dimension_size, self.config['group_size']) + shape = list(tensor.shape) + shape[-1] = shape[-1] * 2 + elif self.config['num_bits'] == 8: + # last_dimension_size = last_dimension_size // 2 + quantized_tensor = get_quantizer_module().dequantize_int8_to_half_experimental( + tensor.reshape(-1, last_dimension_size), quant_scale, quant_min, + tensor.numel() // last_dimension_size, self.config['group_size']) + shape = list(tensor.shape) + + return quantized_tensor.reshape(shape) + + if self.config['num_bits'] == 4: + tensor = self._decompress_uint4_to_uint8(tensor) + elif self.config['num_bits'] != 8: + assert False, 'Unsupported quantization bits {}'.format(self.config['num_bits']) + + shape = tensor.shape + num_groups = shape[self.config['group_dim']] // self.config['group_size'] + new_shape = (shape[:self.config['group_dim']] + (num_groups, self.config['group_size']) + + shape[self.config['group_dim'] + 1:]) + tensor = tensor.view(new_shape) + + dequantized_tensor = self._dequantize_int8(tensor, quant_scale, quant_min).view(shape) + return dequantized_tensor + + def _dequantize_int8(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor: + assert tensor.dtype == torch.uint8 + data = torch.zeros_like(tensor, dtype=self.dtype, device=tensor.device) + data = data.copy_(tensor) + data = data.div_(quant_scale).add_(quant_min) + + return data + + def _decompress_uint4_to_uint8(self, tensor: Tensor) -> Tensor: + new_data_shape = list(tensor.shape) + new_data_shape[-1] = new_data_shape[-1] * 2 + data = torch.empty(new_data_shape, dtype=torch.uint8, device=tensor.device) + data[..., 0::2] = tensor.bitwise_right_shift(4) + data[..., 1::2] = tensor.bitwise_and(0xF) + + return data + + +def get_AsyncPartitionedParameterSwapper(model: nn.Module): + for param_name, param in model.named_parameters(): + if hasattr(param, 'nvme_swapper') and param.nvme_swapper is not None: + return param.nvme_swapper + return None + + +def recursive_setattr(model, module_name, module): + """ + Recursively set the attribute of a module. + Args: + model (`torch.nn.Module`) + The model to set the attribute in. + module_name (`str`) + The name of the module to set the attribute in. + module (`torch.nn.Module`) + The module to set the attribute to. + """ + split_list = module_name.split('.') + output = model + for name in split_list[:-1]: + output = getattr(output, name) + output.__setattr__(split_list[-1], module) + + +def concat_to_compat_param(quantized_weight: Tensor, + quant_scale: Tensor, + quant_min: Tensor, + return_param: bool = True) -> Union[nn.Parameter, Tensor]: + shape_wieght = quantized_weight.shape + shape_scale = quant_scale.shape + shape_min = quant_min.shape + + quantized_weight = torch.flatten(quantized_weight) + quant_scale = torch.flatten(quant_scale) + quant_min = torch.flatten(quant_min) + + def deconcat_individual_tensors(shape_wieght: torch.Size, shape_scale: torch.Size, + shape_min: torch.Size) -> Callable: + + def fn(compat_tensor: nn.Parameter) -> Tuple[Tensor, Tensor, Tensor]: + weight = torch.narrow(compat_tensor, 0, 0, shape_wieght.numel()).view(shape_wieght) + scale = torch.narrow(compat_tensor, 0, shape_wieght.numel(), shape_scale.numel()).view(shape_scale) + min_val = torch.narrow(compat_tensor, 0, + shape_wieght.numel() + shape_scale.numel(), shape_min.numel()).view(shape_min) + + return weight, scale, min_val + + return fn + + compat_tensor = torch.concat([quantized_weight, quant_scale, quant_min]) + if return_param: + compat_tensor = nn.Parameter(compat_tensor, requires_grad=False) + compat_tensor.deconcat = deconcat_individual_tensors(shape_wieght, shape_scale, shape_min) + + return compat_tensor + + +def _quantize_param(param: nn.Parameter, quant_config: Dict): + assert not hasattr(param, 'weight_quantized'), 'Parameter has already been quantized.' + quantizer = Quantizer(quant_config) + dequantizer = DeQuantizer(quant_config, param.dtype) + + quantized_weight, quant_scale, quant_min = quantizer.quantize(param.data) + + quantized_weight = quantized_weight.view(param.dtype) + quant_scale = quant_scale.view(param.dtype) + quant_min = quant_min.view(param.dtype) + + quantized_compat_tensor = concat_to_compat_param(quantized_weight, quant_scale, quant_min) + param.data = quantized_compat_tensor + param.deconcat = quantized_compat_tensor.deconcat + + param.quantizer = quantizer + param.dequantizer = dequantizer + setattr(param, 'weight_quantized', True) + + +def wrap_quantized_functional(f): + + @functools.wraps(f) + def wrapper(input: Tensor, weight: nn.Parameter, *args, **kwargs) -> Tensor: + if hasattr(weight, 'weight_quantized') and getattr(weight, 'weight_quantized'): + quantized_weight, quant_scale, quant_min = weight.deconcat(weight) + temp_dequantized_weight = weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale, + quant_min) + return f(input, temp_dequantized_weight, *args, **kwargs) + else: + return f(input, weight, *args, **kwargs) + + return wrapper + + +def wrap_load_from_state_dict(f): + + @functools.wraps(f) + def wrapper(model, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + replaced_old_value = None + key = None + # We may have nested wrappers if we launch multiple initialization context. + # Use state_dict_quantized flag to quantize state_dict only once + if hasattr(model.weight, 'weight_quantized') and getattr( + model.weight, 'weight_quantized') and not hasattr(model.weight, 'state_dict_quantized'): + setattr(model.weight, 'state_dict_quantized', True) + key = prefix + 'weight' + if key in state_dict: + quantized_weight, quant_scale, quant_min = model.weight.quantizer.quantize(state_dict[key]) + quantized_weight = quantized_weight.view(model.weight.dtype) + quant_scale = quant_scale.view(model.weight.dtype) + quant_min = quant_min.view(model.weight.dtype) + + replaced_old_value = state_dict[key] + + state_dict[key] = concat_to_compat_param(quantized_weight, quant_scale, quant_min) + + f(model, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + if replaced_old_value is not None: + state_dict[key] = replaced_old_value + delattr(model.weight, 'state_dict_quantized') + + return wrapper + + +WEIGHT_QUANTIZATION_LAYERS = ( + nn.Linear, + nn.Embedding, +) diff --git a/deepspeed/inference/v2/__init__.py b/deepspeed/inference/v2/__init__.py new file mode 100644 index 000000000000..ac8a42da8ab3 --- /dev/null +++ b/deepspeed/inference/v2/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig +from .engine_v2 import InferenceEngineV2 +from .engine_factory import build_hf_engine, build_engine_from_ds_checkpoint diff --git a/deepspeed/inference/v2/allocator.py b/deepspeed/inference/v2/allocator.py new file mode 100644 index 000000000000..fcc0d94c0f82 --- /dev/null +++ b/deepspeed/inference/v2/allocator.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from functools import reduce +from typing import Iterable +from collections import defaultdict +import torch + +from deepspeed.accelerator import get_accelerator + + +class Allocator: + cache = defaultdict(dict) + + def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor: + try: + return Allocator.cache[tensor][shape] + except KeyError: + shape_size = reduce(lambda x, y: x * y, shape) + if shape_size == 0: + raise ValueError("Cannot create empty tensor with size 0") + Allocator.cache[tensor][shape] = tensor.flatten()[:shape_size].view(shape) + return Allocator.cache[tensor][shape] + + +empty_from = Allocator.empty_from + + +def on_device(method) -> torch.Tensor: + """ + Wraps a method to ensure the returned tensor is on the current device. + """ + + def wrapped(self, *args, **kwargs): + tensor = method(self, *args, **kwargs) + if isinstance(tensor, torch.Tensor): + return tensor.to(get_accelerator().current_device()) + return tensor + + return wrapped diff --git a/deepspeed/inference/v2/checkpoint/__init__.py b/deepspeed/inference/v2/checkpoint/__init__.py new file mode 100644 index 000000000000..45e523ab62b9 --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .base_engine import CheckpointEngineBase +from .in_memory_engine import InMemoryModelEngine +from .huggingface_engine import HuggingFaceCheckpointEngine diff --git a/deepspeed/inference/v2/checkpoint/base_engine.py b/deepspeed/inference/v2/checkpoint/base_engine.py new file mode 100644 index 000000000000..26fc467d4d86 --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/base_engine.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod +from typing import Iterable, Tuple + +import torch + +#from .huggingface_engine import HuggingFaceCheckpointEngine + +MEGATRON = 'megatron' +HUGGINGFACE = 'huggingface' + + +class CheckpointEngineBase(ABC): + """ + Abstract interface for checkpoint engines to implement. + + There is no ``__init__`` method here by design, since the creation of the checkpoint + engine will happen outside the policy/engine code. The tradeoff being made here is + that we will write different frontends for different checkpoint engines, but these + frontends can be tailored to the specific checkpoint engine/model source needs. + """ + + @abstractmethod + def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: + """ + This method should create a generator of tuples of the form (name, parameter) for + all parameters in the model. The name should be the fully qualified name of the + parameter, and the parameter should be a torch.Tensor. + + The expected use of a checkpoint engine is the following: + ```python + for name, parameter in checkpoint_engine.parameters(): + container_map.map_param(name, parameter) + ``` + For a concrete use example, see ``InferenceV2Policy``. + """ + ... diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py new file mode 100644 index 000000000000..b17bb886838f --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import json +import torch +from .base_engine import CheckpointEngineBase +from typing import Iterable, Tuple +from functools import partial + +from ..logging import inference_logger + + +class HuggingFaceCheckpointEngine(CheckpointEngineBase): + + def __init__(self, model_name_or_path: str, auth_token: str = None, **hf_kwargs) -> None: + super().__init__() + from transformers import AutoConfig, GenerationConfig + + self.model_name_or_path = model_name_or_path + self.auth_token = auth_token + self.model_config = AutoConfig.from_pretrained(self.model_name_or_path, **hf_kwargs) + # Define this property here so we can use it in the model implementation + if not hasattr(self.model_config, "max_seq_length"): + if hasattr(self.model_config, "max_position_embeddings"): + self.model_config.max_seq_length = self.model_config.max_position_embeddings + else: + generation_config = GenerationConfig.from_pretrained(self.model_name_or_path) + self.model_config.max_seq_length = generation_config.max_length + self._local_checkpoint_dir = None + self._all_ckpt_paths = self._fetch_checkpoint_files() + + def _fetch_checkpoint_files(self): + """ + Fetch the checkpoint files from the HuggingFace Hub. + """ + # TODO(jeff): for models like llama-2 the user will have to provide an auth `token`, + # currently coming from the ckpt engine init but maybe a catch all kwargs for other + # snapshot download parameters would be more flexible. + + from huggingface_hub import snapshot_download, list_repo_tree + + def model_has_safetensors(model_name_or_path: str) -> bool: + if os.path.isdir(model_name_or_path): + file_list = os.listdir(model_name_or_path) + else: + file_list = [rf.path for rf in list_repo_tree(model_name_or_path)] + for f in file_list: + if f.endswith(".safetensors"): + return True + return False + + if os.path.isdir(self.model_name_or_path): + self._local_checkpoint_dir = self.model_name_or_path + else: + # We need to download the checkpoint files from HF + if model_has_safetensors(self.model_name_or_path): + # Prioritize downloading safetensors if they are available + allow_patterns = ["*.safetensors", "*.json"] + else: + # Fallback to bin files when safetensors are not present + allow_patterns = ["*.bin", "*.json", "*.pt"] + self._local_checkpoint_dir = snapshot_download(self.model_name_or_path, + allow_patterns=allow_patterns, + revision=None, + token=self.auth_token) + + assert os.path.isdir( + self._local_checkpoint_dir + ), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint." + + # Set the appropriate file names based on whether we have safetensors or not + if model_has_safetensors(self._local_checkpoint_dir): + from safetensors.torch import load_file + model_param_json_fname = "model.safetensors.index.json" + model_file_fname = "model.safetensors" + self._checkpoint_load_fn = load_file + else: + model_param_json_fname = "pytorch_model.bin.index.json" + model_file_fname = "pytorch_model.bin" + self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False) + + model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname) + + if not os.path.isfile(model_param_json): + # We don't need any json as all such HF models will have pytorch_model.bin + all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, model_file_fname)] + else: + param_map = json.load(open(model_param_json, "r")) + + # weight_map -> { "lm_head.weight": "pytorch_model-00002-of-00002.bin", ... } + weight_map = param_map["weight_map"] + + # unique set of all checkpoint files + all_checkpoint_files = set(weight_map.values()) + + # get absolute path of all unique checkpoint files + all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, f) for f in all_checkpoint_files] + + return all_checkpoint_files + + def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: + """ + Generator of model parameters (satisfies the CheckpointEngineBase interface). + """ + for checkpoint in self._all_ckpt_paths: + inference_logger().info(f"Loading checkpoint: {checkpoint}") + checkpoint_sd = self._checkpoint_load_fn(checkpoint) + + # If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embeddings weights + if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings: + if self.model_config.model_type == "qwen2": + checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"] + + param_keys = list(checkpoint_sd.keys()) + for param_name in param_keys: + param = checkpoint_sd[param_name] + yield param_name, param + + del checkpoint_sd + + +if __name__ == "__main__": + # To test, add your auth_token here and run `python huggingface_engine.py` + engine = HuggingFaceCheckpointEngine(model_name_or_path="meta-llama/Llama-2-7b-hf", + auth_token="hf_xxxxxxxxxxxxxxxxx") + for name, param in engine.parameters(): + print(name, param.shape) diff --git a/deepspeed/inference/v2/checkpoint/in_memory_engine.py b/deepspeed/inference/v2/checkpoint/in_memory_engine.py new file mode 100644 index 000000000000..13ec7b288f5f --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/in_memory_engine.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Tuple +import torch + +from .base_engine import CheckpointEngineBase + + +class InMemoryModelEngine(CheckpointEngineBase): + """ + This "checkpoint" engine uses the existing interface to enable loading parameters into an + inference model from a model already instantiated in memory. In general, this is not the + recommended way to use the inference engine, and should only be used when absolutely necessary. + + The primary limitation of this approach is that the model must be fully instantiated in memory. + In a tensor parallel scenario, this means that the model is either replicated many times in host + memory. Currently, it is also recommended to only use this approach for models held in host memory. + + In order to free the memory held by this copy of the model, we delete the model in the first call + to `parameters`, so it is not safe to make this call twice. + """ + + def __init__(self, model: torch.nn.Module) -> None: + """ + Create virtual checkpoint engine for the provided module. + + Args: + model (torch.nn.Module): Model to load parameters from. + """ + super().__init__() + self.model = model + + def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: + for name, parameter in self.model.named_parameters(): + yield name, parameter + + del self.model diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py new file mode 100644 index 000000000000..325b57d8f56a --- /dev/null +++ b/deepspeed/inference/v2/config_v2.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from pydantic import Field +from typing import Optional + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .ragged import DSStateManagerConfig + + +class DeepSpeedTPConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + tp_size: int = 1 + """ Number of devices to split the model across using tensor parallelism. """ + + +class QuantizationConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + quantization_mode: Optional[str] = None + """ The quantization mode in string format. The supported modes are as follows: + - 'wf6af16', weight-only quantization with FP6 weight and FP16 activation. + """ + # TODO: may reuse the constants in deepspeed/compression/constants.py + + +class RaggedInferenceEngineConfig(DeepSpeedConfigModel): + """ Sets parameters for DeepSpeed Inference Engine. """ + + tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp") + """ + Configuration for tensor parallelism used to split the model across several + GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`. + """ + + state_manager: DSStateManagerConfig = Field({}, alias="manager") + """ + Configuration for managing persistent state + """ + + quantization: QuantizationConfig = {} diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py new file mode 100644 index 000000000000..ebbb4a9767d7 --- /dev/null +++ b/deepspeed/inference/v2/engine_factory.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import json +import logging +import os +import pickle +from packaging import version + +from .engine_v2 import InferenceEngineV2 +from .config_v2 import RaggedInferenceEngineConfig +from .checkpoint import HuggingFaceCheckpointEngine +from .logging import inference_logger +from .model_implementations import ( + Exaone4Policy, + OPTPolicy, + Llama2Policy, + MistralPolicy, + MixtralPolicy, + FalconPolicy, + PhiPolicy, + Phi3Policy, + QwenPolicy, + Qwen2Policy, + Qwen2MoePolicy, +) +from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy +from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata + + +def build_engine_from_ds_checkpoint(path: str, + engine_config: RaggedInferenceEngineConfig, + debug_level: int = logging.INFO) -> InferenceEngineV2: + """ + Creates an engine from a checkpoint saved by ``InferenceEngineV2``. + + Arguments: + path: Path to the checkpoint. This does not need to point to any files in particular, + just the directory containing the checkpoint. + engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details. + debug_level: Logging level to use. Unless you are actively seeing issues, the recommended + value is ``logging.INFO``. + + Returns: + Fully initialized inference engine ready to serve queries. + """ + + inference_logger(level=debug_level) + # Load metadata, for grabbing the policy name we'll have all ranks just check for + # rank 0. + metadata_filename = make_metadata_filename(path, 0, engine_config.tensor_parallel.tp_size) + metadata = json.load(open(metadata_filename, "r")) + metadata = ModelMetadata.parse_raw(metadata) + + # Get the policy + try: + policy_cls: InferenceV2Policy = POLICIES[metadata.policy] + except KeyError: + raise ValueError(f"Unknown policy {metadata.policy} for model {path}") + + # Load the model config + model_config = pickle.load(open(os.path.join(path, "ds_model_config.pkl"), "rb")) + policy = policy_cls(model_config, inf_checkpoint_path=path) + + return InferenceEngineV2(policy, engine_config) + + +def build_hf_engine(path: str, + engine_config: RaggedInferenceEngineConfig, + debug_level: int = logging.INFO) -> InferenceEngineV2: + """ + Build an InferenceV2 engine for HuggingFace models. This can accept both a HuggingFace + model name or a path to an Inference-V2 checkpoint. + + Arguments: + path: Path to the checkpoint. This does not need to point to any files in particular, + just the directory containing the checkpoint. + engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details. + debug_level: Logging level to use. Unless you are actively seeing issues, the recommended + value is ``logging.INFO``. + + Returns: + Fully initialized inference engine ready to serve queries. + """ + + if os.path.exists(os.path.join(path, "ds_model_config.pkl")): + return build_engine_from_ds_checkpoint(path, engine_config, debug_level=debug_level) + else: + # Set up logging + inference_logger(level=debug_level) + # get HF checkpoint engine + checkpoint_engine = HuggingFaceCheckpointEngine(path) + + # get model config from HF AutoConfig + model_config = checkpoint_engine.model_config + + # get the policy + # TODO: generalize this to other models + if model_config.model_type == "opt": + if not model_config.do_layer_norm_before: + raise ValueError( + "Detected OPT-350m model. This model is not currently supported. If this is not the 350m model, please open an issue: https://github.com/deepspeedai/DeepSpeed-MII/issues" + ) + policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "llama": + policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "mistral": + # Ensure we're using the correct version of transformers for mistral + import transformers + assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \ + f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}" + policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "mixtral": + # Ensure we're using the correct version of transformers for mistral + import transformers + assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \ + f"Mistral requires transformers >= 4.36.1, you have version {transformers.__version__}" + policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "falcon": + policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "phi": + policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "phi3": + policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "qwen": + policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "qwen2": + policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "qwen2_moe": + policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "exaone4": + import transformers + assert version.parse(transformers.__version__) >= version.parse("4.54.0"), \ + f"EXAONE 4.0 requires transformers >= 4.54.0, you have version {transformers.__version__}" + policy = Exaone4Policy(model_config, checkpoint_engine=checkpoint_engine) + else: + raise ValueError(f"Unsupported model type {model_config.model_type}") + + return InferenceEngineV2(policy, engine_config) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py new file mode 100644 index 000000000000..4a358310377f --- /dev/null +++ b/deepspeed/inference/v2/engine_v2.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import json +import pickle +from typing import Iterable, Tuple + +import torch + +import deepspeed.comm as dist + +from deepspeed.accelerator import get_accelerator +from deepspeed.comm.comm import init_distributed + +from .model_implementations import InferenceV2Policy +from .logging import inference_logger +from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor +from .scheduling_utils import SchedulingError, SchedulingResult +from .model_implementations.flat_model_helpers import make_param_filename, make_metadata_filename +from .model_implementations.inference_model_base import DSInferenceModelBase + +from .config_v2 import RaggedInferenceEngineConfig + +INFERENCE_MODEL_TIMER = "model-forward-inference" + + +class InferenceEngineV2: + + _config: RaggedInferenceEngineConfig + """ + Configuration of the inference engine. + """ + + _model: DSInferenceModelBase + """ + Inference model supporting ragged inference. + """ + + _state_manager: DSStateManager + """ + Persistent state manager for sequences and KV-cache. + """ + + @property + def free_blocks(self) -> torch.Tensor: + """ + Number of free KV blocks. This is a tensor of shape [n_kv_cache_groups] where each + element is the number of free blocks in the corresponding KV cache group. + """ + return self._state_manager.free_blocks + + @property + def n_kv_cache_groups(self) -> int: + """ + Number of KV cache groups. + """ + return self._state_manager.n_kv_cache_groups + + def model(self) -> DSInferenceModelBase: + """ + The model implementation. + """ + return self._model + + def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None: + """ + Create the Inference V2 engine. + + Arguments: + policy (InferenceV2Policy): Policy for the model implementation. This policy object + will be used to build the model and load the checkpoint associated with it. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + """ + self._config = engine_config + self._policy = policy + self._base_mp_group = self._initialize_tp_group() + + # Build model from policy + inference_logger().info("Building model...") + self._model = self._policy.build_model(self._config, self._base_mp_group) + inference_logger().info("Model built.") + + # Create state manager + self._batch = RaggedBatchWrapper(self._config.state_manager) + self._state_manager = DSStateManager(self._config.state_manager, + self._model.kv_cache_config(), + base_mp_group=self._base_mp_group) + self._model.set_state_manager(self._state_manager) + + def _initialize_tp_group(self): + """ + Implementation of our TP group initialization. + """ + init_distributed() + local_rank = int(os.getenv("LOCAL_RANK", 0)) + get_accelerator().set_device(local_rank) + + if local_rank >= self._config.tensor_parallel.tp_size: + raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.") + + ranks = list(range(self._config.tensor_parallel.tp_size)) + return dist.new_group(ranks=ranks) + + def put(self, + batch_uids: Iterable[int], + batch_tokens: Iterable[torch.Tensor], + do_checks: bool = True) -> torch.Tensor: + """ + Put a ragged batch onto the inference engine. This will perform one forward and return + a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens + are not calculated. + + Arguments: + batch_uids: Iterable of uids for the batch on the host + batch_tokens: Iterable of token tensors for the batch on the host + do_checks: Check schedulability when it is set to True. You can skip this check for better performance when it has already been completed. + """ + + if do_checks: + token_lens = [len(tokens) for tokens in batch_tokens] + schedule_check = self.can_schedule(batch_uids, token_lens) + if schedule_check != SchedulingResult.Success: + raise SchedulingError(schedule_check) + + self._batch.clear() + for uid, tokens in zip(batch_uids, batch_tokens): + + host_seq_desc = self._state_manager.get_or_create_sequence(uid) + self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) + host_seq_desc.pre_forward(tokens.numel()) + + # We can disable checks since we already validated schedulability. + self._batch.insert_sequence(host_seq_desc, tokens, do_checks=do_checks) + + # Send all metadata to the device + self._batch.finalize() + + # Prep all data structures for the actual forward (in anticipation of CG in the future) + # and also to amortize some of the costs in a more straightforward way. + self._model.prepare_batch(self._batch) + + # Model implementation will pick up in the forward. + logits = self._model.forward(self._batch) + + # We return one set of logits per sequence in the batch (saves cost on unembedding) + assert logits.shape[0] == self._batch.current_sequences + + for uid in batch_uids: + host_seq_desc = self._state_manager.get_sequence(uid) + host_seq_desc.post_forward() # Updates sequence metadata. + self._model.maybe_free_kv(host_seq_desc) + + return logits + + def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, torch.Tensor]: + """ + Determine the number of tokens and KV blocks to reserve for a given request. Given a UID + (this UID may not be recognized by the model yet), this will return the number of tokens + and blocks to reserve for the request. + + Arguments: + uid (int): The UID of the sequence (as tracked by the scheduling entity). If + this is a new sequence (with a UID unknown to the inference engine), then + an empty placeholder is created to pass to the occupancy logic. + n_tokens (int): The number of tokens to hypothetically send. + + Returns: + Tuple[int, Optional[int]]: Tuple of free kv blocks and the number of blocks + required to schedule the sequence. + """ + seq_desc = self._state_manager.get_sequence(uid) + if seq_desc is None: + if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences): + return (0, 0) + seq_desc = PlaceholderSequenceDescriptor() + + req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks) + + return (req_tokens, req_blocks) + + def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult: + """ + Dry run a batch to determine if it can be scheduled. Placeholder sequences will be + created for any UIDs that are unknown to the inference engine. + + Arguments: + uids (Iterable[int]): Iterable of UIDs for the batch + lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. This lengths + corresponds to the number of tokens to send in the hypothetical forward; history + tokens will be determined via UID lookup and future tokens are disregarded. + + Returns: + bool: True if the batch can be scheduled, False otherwise. + """ + + cur_seqs = self._state_manager.n_tracked_sequences + free_blocks = self._state_manager.free_blocks + req_blocks = 0 + batch_len = 0 + + if len(uids) > self._config.state_manager.max_ragged_sequence_count: + # Can only compose a batch from a limited number of sequences + return SchedulingResult.BatchSequenceLimitExceeded + + for uid, length in zip(uids, lengths): + seq_desc = self._state_manager.get_sequence(uid) + if seq_desc is None: + cur_seqs += 1 + seq_desc = PlaceholderSequenceDescriptor() + + sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks) + + if sched_len != length: + # We ran out of KV cache + return SchedulingResult.KVCacheLimitExceeded + + batch_len += length + free_blocks -= sched_blocks + + if cur_seqs > self._config.state_manager.max_tracked_sequences: + # Would run out of tracking metadata + return SchedulingResult.EngineSequenceLimitExceeded + + if batch_len > self._config.state_manager.max_ragged_batch_size: + # Would exceed the maximum batch size + return SchedulingResult.BatchTokenLimitExceeded + + return SchedulingResult.Success + + def get_remaining_block_capacity(self, uid: int) -> int: + """ + Get the remaining capacity of the last block already allocated. + """ + seq_desc = self._state_manager.get_sequence(uid) + if seq_desc is None: + return 0 + return self._model.get_remaining_block_capacity(seq_desc) + + def flush(self, uid: int) -> None: + """ + Remove all state associated with a sequence from the inference engine. + + Arguments: + uid (int): The UID of the sequence to flush. + """ + self._state_manager.flush_sequence(uid) + + def serialize(self, save_path: str) -> None: + """ + Serialize the model to a file. + + Arguments: + path (str): Path to the file to serialize to. + """ + param_file_name = make_param_filename(save_path, self._model.tp_rank, self._model.tp_size) + metadata_file_name = make_metadata_filename(save_path, self._model.tp_rank, self._model.tp_size) + + # Save the flattened parameters + + torch.save(self._model.flattened_params, param_file_name) + + json.dump(self._model.flattened_param_metadata.json(), open(metadata_file_name, "w")) + + if self._model.tp_rank == 0: + pickle.dump(self._model._config, open(os.path.join(save_path, "ds_model_config.pkl"), "wb")) diff --git a/deepspeed/inference/v2/inference_parameter.py b/deepspeed/inference/v2/inference_parameter.py new file mode 100644 index 000000000000..8ec5b458a9b5 --- /dev/null +++ b/deepspeed/inference/v2/inference_parameter.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Dict + +import torch + +CORE_PARAM = "_ds_core_param_key" + +STR_TO_DTYPE = { + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.float16": torch.float16, + "torch.bfloat16": torch.bfloat16, + "torch.int64": torch.int64, + "torch.int32": torch.int32, + "torch.int16": torch.int16, + "torch.int8": torch.int8, + "torch.uint8": torch.uint8, + "torch.bool": torch.bool, +} + + +class InferenceParameter(torch.Tensor): + """ + An extension of the torch.Tensor class to support our inference focused features. One important + thing to note here is that an InferenceParam can be used a torch.Tensor, but outputs of + torch.Tensor operations will not be InferenceParams. + """ + + @staticmethod + def __new__(cls, tensor, *args, **kwargs): + new_tensor = super().__new__(cls, tensor, *args, **kwargs) + if hasattr(tensor, "_aux_attrs"): + setattr(new_tensor, "_aux_attrs", tensor.aux_attrs) + return new_tensor + + def to(self, *args, **kwargs): + new_tensor = super().to(*args, **kwargs) + if hasattr(self, "_aux_attrs"): + setattr(new_tensor, "_aux_attrs", self.aux_attrs) + try: + _ = torch.device(args[0]) + for name, attr in new_tensor.aux_attrs.items(): + new_attr = attr.to(*args, **kwargs) + setattr(new_tensor, name, new_attr) + new_tensor.aux_attrs[name] = new_attr + except Exception: + pass + + return new_tensor + + @classmethod + def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter': + """ + Create the inference parameter. + """ + param = InferenceParameter(core_param) + setattr(param, "_aux_attrs", kwargs) + + for attr_name, attr in kwargs.items(): + if hasattr(param, attr_name): + raise ValueError(f"Attribute {attr_name} already exists on param.") + + if not isinstance(attr, torch.Tensor): + raise ValueError(f"Attribute {attr_name} must be a tensor.") + + setattr(param, attr_name, attr) + + return param + + @classmethod + def initialize_raw(self, **kwargs) -> 'InferenceParameter': + """ + All kwargs must be torch.Tensors and must include the core parameter. + """ + if CORE_PARAM not in kwargs: + raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.") + + return InferenceParameter.initialize(kwargs[CORE_PARAM], **kwargs) + + @property + def aux_attrs(self) -> Dict[str, torch.Tensor]: + """ + Dictionary of auxiliary attributes. + """ + return self._aux_attrs diff --git a/deepspeed/inference/v2/inference_utils.py b/deepspeed/inference/v2/inference_utils.py new file mode 100644 index 000000000000..7b2dd4237353 --- /dev/null +++ b/deepspeed/inference/v2/inference_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Dict + +import torch + +from enum import Enum, IntEnum + + +class NormTypeEnum(Enum): + LayerNorm: str = "layer_norm" + RMSNorm: str = "rms_norm" + + +class DtypeEnum(Enum): + # The torch dtype must always be the first value (so we return torch.dtype) + fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" + fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" + bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat" + int8 = torch.int8, "torch.int8", "int8" + + # Copied from https://stackoverflow.com/a/43210118 + # Allows us to use multiple values for each Enum index and returns first + # listed value when Enum is called + def __new__(cls, *values): + obj = object.__new__(cls) + # first value is canonical value + obj._value_ = values[0] + for other_value in values[1:]: + cls._value2member_map_[other_value] = obj + obj._all_values = values + return obj + + def __repr__(self): + return "<%s.%s: %s>" % ( + self.__class__.__name__, + self._name_, + ", ".join([repr(v) for v in self._all_values]), + ) + + +ELEM_SIZES: Dict[torch.dtype, int] = { + torch.float16: 2, + torch.bfloat16: 2, + torch.float32: 4, + torch.float64: 8, + torch.int8: 1, + torch.uint8: 1, + torch.int16: 2, + torch.int32: 4, + torch.int64: 8, + torch.bool: 1, +} + + +class ActivationType(IntEnum): + """ + Types of activations supported by DS-Inference + """ + + GELU = 0 + + RELU = 1 + + SILU = 2 + + GEGLU = 3 + + ReGLU = 4 + + SiGLU = 5 + + IDENTITY = 6 + + InvalidType = -1 + + +def is_gated(act_fn: ActivationType) -> bool: + """ + Return True if the given activation function is gated. + """ + if not isinstance(act_fn, ActivationType): + act_fn = ActivationType(act_fn) + + return act_fn in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU] + + +def elem_size(dtype: torch.dtype) -> int: + """ + Return size in bytes of the given dtype. + """ + try: + return ELEM_SIZES[dtype] + except KeyError: + raise ValueError("Unknown dtype size for {}".format(dtype)) + + +def ceil_div(a: int, b: int) -> int: + """ + Return ceil(a / b). + """ + return -(-a // b) diff --git a/deepspeed/inference/v2/kernels/__init__.py b/deepspeed/inference/v2/kernels/__init__.py new file mode 100644 index 000000000000..01b7b0580073 --- /dev/null +++ b/deepspeed/inference/v2/kernels/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .ds_kernel import DSKernelBase diff --git a/deepspeed/inference/v2/kernels/core_ops/__init__.py b/deepspeed/inference/v2/kernels/core_ops/__init__.py new file mode 100644 index 000000000000..1d16b484a560 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .bias_activations import * +from .blas_kernels import * +from .cuda_layer_norm import * +from .cuda_rms_norm import * +from .gated_activations import * +from .cuda_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py b/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py new file mode 100644 index 000000000000..ea7f8a7d1996 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .bias_activation import * diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp new file mode 100644 index 000000000000..4f0cc9cbd77c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "bias_activation.h" +#include +#include "ds_kernel_utils.h" + +#ifdef BF16_AVAILABLE +#define DTYPE_SWITCH(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using scalar_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#else +#define DTYPE_SWITCH(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#endif + +/* +In-place bias and activation fusion kernel. +*/ +void bias_activation(torch::Tensor& activation, + c10::optional& bias, + const int32_t act_type) +{ + const ActivationType atype = static_cast(act_type); + const int32_t rows = activation.size(0); + const int32_t cols = activation.size(1); + + TORCH_CHECK(atype == ActivationType::GELU || atype == ActivationType::RELU || + atype == ActivationType::SILU || atype == ActivationType::IDENTITY, + "Unsupported activation type for BiasActivation"); + TORCH_CHECK(activation.dim() == 2, "BiasActivation only supports 2D activation tensors"); + + DTYPE_SWITCH(activation.scalar_type(), [&] { + scalar_t* activation_ptr = reinterpret_cast(activation.data_ptr()); + + const scalar_t* bias_ptr; + if (bias.has_value()) { + TORCH_CHECK(activation.scalar_type() == bias.value().scalar_type(), + "BiasActivation activation and bias must have same dtype"); + bias_ptr = reinterpret_cast(bias.value().data_ptr()); + } else { + bias_ptr = nullptr; + } + + if (atype == ActivationType::IDENTITY && bias_ptr == nullptr) { return; } + + launch_bias_activation( + activation_ptr, bias_ptr, rows, cols, atype, c10::cuda::getCurrentCUDAStream()); + }); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h new file mode 100644 index 000000000000..db6174633a09 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "activation_type.h" + +template +void launch_bias_activation(T* activation, + const T* bias, + const int32_t n_rows, + const int32_t n_cols, + const ActivationType activation_type, + cudaStream_t stream); + +void bias_activation(torch::Tensor& activation, + c10::optional& bias, + const int32_t activation_type); diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py new file mode 100644 index 000000000000..436d7f8805d5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ... import DSKernelBase + + +class CUDABiasActivation(DSKernelBase): + """ + CUDA implementation of bias activation kernel. This kernel should be deprecated once + we are fusing the bias activation into the linear kernel in all scenarios. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.IDENTITY, ActivationType.GELU, ActivationType.RELU, ActivationType.SILU] + + def __init__(self, channels: int, dtype: DtypeEnum, act_fn: ActivationType) -> None: + """ + Compile and validate for the fused bias-activation kernel. + + Parameters: + channels (int): Number of channels to expect in the activation. + dtype (torch.dtype): Data type for the input/output. Supported values + are DtypeEnum.fp16 and DtypeEnum.bf16. + act_fn (ActivationType): Activation function to use. Only IDENTITY, GELU, RELU, and SILU are supported. + """ + + if channels % 8 != 0: + raise ValueError("channels must be divisible by 8") + + if DtypeEnum(dtype) not in CUDABiasActivation.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, CUDABiasActivation.supported_dtypes)) + + act_fn = ActivationType(act_fn) + if act_fn not in CUDABiasActivation.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, CUDABiasActivation.supported_act_fns)) + + inf_module = InferenceCoreBuilder().load() + self.kernel = inf_module.bias_activation + self.act_fn = act_fn + + def __call__(self, activation: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Add an optional bias and perform the non-linear activation function. + + Parameters: + activation (torch.Tensor): Input tensor of shape [tokens, channels] + bias (torch.Tensor): Optional bias tensor of shape [channels] + + Returns: + activation that has been updated in-place + """ + self.kernel(activation, bias, self.act_fn.value) diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation_cuda.cu new file mode 100644 index 000000000000..66bca0c175c3 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation_cuda.cu @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "activation_type.h" +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +// Default activation function will error out +template +DS_D_INLINE float act_fn(float val); + +template <> +DS_D_INLINE float act_fn(float val) +{ + return val; +} + +template <> +DS_D_INLINE float act_fn(float val) +{ + return val > 0.0f ? val : 0.0f; +} + +template <> +DS_D_INLINE float act_fn(float val) +{ + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715f; + return val * 0.5f * (1.0f + tanhf(sqrt_param * (val + mul_param * val * val * val))); +} + +template <> +DS_D_INLINE float act_fn(float val) +{ + return val / (1.0f + expf(-val)); +} + +namespace bias_act { + +constexpr int access_size = 16; +constexpr int threads = 512; +constexpr int unroll = 4; + +} // namespace bias_act + +template +__global__ void bias_activation_kernel(T* activation, + const T* bias, + const int32_t rows, + const int32_t cols) +{ + constexpr int vector_T = bias_act::access_size / sizeof(T); + + const int32_t thread_offset = threadIdx.x * vector_T; + const int32_t block_offset = blockIdx.x * vector_T * bias_act::unroll * bias_act::threads; + const int32_t base_offset = block_offset + thread_offset; + + const int32_t thread_stride = bias_act::threads * vector_T; + +#pragma unroll + for (int i = 0; i < bias_act::unroll; i++) { + const int32_t iter_offset = base_offset + i * thread_stride; + + const int32_t row = iter_offset / cols; + + T buffer[vector_T]; + T bias_buffer[vector_T]; + + if (row < rows) { + const int32_t col = iter_offset % cols; + + mem_access::load_global(buffer, activation + iter_offset); + mem_access::load_global( + bias_buffer, bias + col, bias != nullptr); + +#pragma unroll + for (int j = 0; j < vector_T; j++) { + float val = + conversion::to(buffer[j]) + conversion::to(bias_buffer[j]); + buffer[j] = conversion::to(act_fn(val)); + } + + mem_access::store_global(activation + iter_offset, buffer); + } + } +} + +#define ACT_TYPE_SWITCH(ACT_TYPE, ...) \ + if (ACT_TYPE == ActivationType::IDENTITY) { \ + constexpr ActivationType act_fn_t = ActivationType::IDENTITY; \ + return __VA_ARGS__(); \ + } else if (ACT_TYPE == ActivationType::RELU) { \ + constexpr ActivationType act_fn_t = ActivationType::RELU; \ + return __VA_ARGS__(); \ + } else if (ACT_TYPE == ActivationType::GELU) { \ + constexpr ActivationType act_fn_t = ActivationType::GELU; \ + return __VA_ARGS__(); \ + } else if (ACT_TYPE == ActivationType::SILU) { \ + constexpr ActivationType act_fn_t = ActivationType::SILU; \ + return __VA_ARGS__(); \ + } else { \ + assert(false); \ + } + +template +void launch_bias_activation(T* activation, + const T* bias, + const int32_t n_rows, + const int32_t n_cols, + const ActivationType activation_type, + cudaStream_t stream) +{ + constexpr int32_t elems_per_block = + bias_act::threads * bias_act::unroll * bias_act::access_size / sizeof(T); + const int32_t total_elems = n_rows * n_cols; + + const int32_t blocks = (total_elems + elems_per_block - 1) / elems_per_block; + + const dim3 grid(blocks); + const dim3 block(bias_act::threads); + + ACT_TYPE_SWITCH(activation_type, [&] { + bias_activation_kernel + <<>>(activation, bias, n_rows, n_cols); + }); +} + +#define INSTANTIATE_FOR_T(T) \ + template void launch_bias_activation( \ + T*, const T*, const int32_t, const int32_t, const ActivationType, cudaStream_t); + +INSTANTIATE_FOR_T(__half); + +#ifdef BF16_AVAILABLE +INSTANTIATE_FOR_T(__nv_bfloat16); +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py new file mode 100644 index 000000000000..4af5a579ca1b --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blas_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h new file mode 100644 index 000000000000..1854e40a227d --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include "blas_utils.h" + +#define DISPATCH_BLAS_MATMUL(T_TYPE, C_TYPE) \ + if (output.options().dtype() == torch::T_TYPE) { \ + blas_gemm_ex(output.data_ptr(), \ + (const void*)weights.data_ptr(), \ + (const void*)hidden_states.data_ptr(), \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc, \ + trans_a, \ + trans_b, \ + &alpha, \ + &beta, \ + C_TYPE); \ + } + +void blas_linear(at::Tensor& output, at::Tensor& hidden_states, at::Tensor& weights) +{ + /* + Expected shape: output([total_tokens_across_dims], out_neurons) + hidden_states([total_tokens_across_dims], in_neurons) + weights(out_neurons, in_neurons) + + We are going to assume contiguous for the above shapes. + + The shapes are going to get messed with a little internally to handle column-major + GEMMs. + */ + + // Number of tokens is N (since the GEMM output is column-major but our Tensor + // is row-major, we need to transpose the shapes) + const int n = output.numel() / output.size(-1); + const int k = weights.size(1); + const int m = weights.size(0); + + // A strides + const bool trans_a = weights.stride(1) == 1; + const int lda = (trans_a) ? weights.stride(0) : weights.stride(1); + + // B strides + const bool trans_b = hidden_states.stride(-1) != 1; + const int ldb = (trans_b) ? hidden_states.stride(-1) : hidden_states.stride(-2); + + // C strides + const int ldc = output.stride(-2); + + const float alpha = 1.0f; + const float beta = 0.0f; + + TORCH_CHECK(output.scalar_type() == hidden_states.scalar_type(), + "Output and hidden states must have the same scalar type"); + TORCH_CHECK(output.scalar_type() == weights.scalar_type(), + "Output and weights must have the same scalar type"); + + // Dispatch the datatypes + DISPATCH_BLAS_MATMUL(kFloat, BlasType::FP32); + DISPATCH_BLAS_MATMUL(kHalf, BlasType::FP16); +#ifdef BF16_AVAILABLE + DISPATCH_BLAS_MATMUL(kBFloat16, BlasType::BF16); +#endif +} + +#define DISPATCH_4D_BLAS(T_TYPE, C_TYPE) \ + if (C.options().dtype() == torch::T_TYPE) { \ + blas_strided_batched_gemm(C.data_ptr(), \ + (const void*)A.data_ptr(), \ + (const void*)B.data_ptr(), \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc, \ + trans_a, \ + trans_b, \ + &alpha, \ + &beta, \ + stride_a, \ + stride_b, \ + stride_c, \ + batch, \ + C_TYPE); \ + } + +void blas_4d_matmul(at::Tensor& C, at::Tensor& B, at::Tensor& A) +{ + /* + C shape: (batch_size, N, M) + A shape: (batch_size, N, K) + B shape: (batch_size, K, M) + */ + + const int n = C.size(-2); + const int k = C.size(-1); + const int m = B.size(-1); + + // A strides + const bool trans_a = A.stride(-1) == 1; + const int lda = (trans_a) ? A.stride(-2) : A.stride(-1); + const int stride_a = A.stride(-3); + + // B strides + const bool trans_b = B.stride(-1) != 1; + const int ldb = (trans_b) ? B.stride(-1) : B.stride(-2); + const int stride_b = B.stride(-3); + + // C strides + const int ldc = C.stride(-2); + const int stride_c = C.stride(-3); + + const float alpha = 1.0f; + const float beta = 0.0f; + + const int batch = C.numel() / (n * m); + + // Dispatch the datatypes + DISPATCH_4D_BLAS(kFloat, BlasType::FP32); + DISPATCH_4D_BLAS(kHalf, BlasType::FP16); +#ifdef BF16_AVAILABLE + DISPATCH_4D_BLAS(kBFloat16, BlasType::BF16); +#endif +} + +void create_handle() { BlasContext::getInstance().get_handle(); } diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py new file mode 100644 index 000000000000..9a151ce36dc4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ... import DSKernelBase + + +class BlasLibLinear(DSKernelBase): + """ + Wrapper around the BLAS matmul kernel for FP16/BF16/FP32 for CUDA/RoCM. + + Performs z = x @ y + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32] + + def __init__(self, fp_dtype: DtypeEnum): + """ + Parameters: + fp_dtype (torch.dtype): Data type for the input/output. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + fp_dtype = DtypeEnum(fp_dtype) + if fp_dtype not in BlasLibLinear.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, BlasLibLinear.supported_dtypes)) + + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.kernel = self.inf_module.blas_linear + + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """ + Matmul kernel as implemented by platform BLAS library. The input must be 2D or larger. If + n-dimensional, the leading dimensions are folded into each other: + 2D: m = x.size(0) + 3D: m = x.size(0) * x.size(1) + 4D: m = x.size(0) * x.size(1) * x.size(2) (etc...) + All inputs should be contiguous. + + Parameters: + output (torch.Tensor): Output tensor. Shape is of [*, out_features] + hidden_states (torch.Tensor): Input tensor. Shape is of [*, in_features] + weights (torch.Tensor): Input tensor. Shape is of [out_features, in_features] + + Returns: + z (torch.Tensor): Output tensor. Shape is of [m, n] + """ + self.kernel(output, hidden_states, weights) + return output diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h new file mode 100644 index 000000000000..294db7528699 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#ifdef BF16_AVAILABLE +#include +#endif +#include +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif +#include +#include +#include + +class BlasContext { + /* + Slim wrapper for managing the lifetime of the platform's BLAS handle. This should + be hipified for ROCm. + */ +public: + BlasContext() + { + if (cublasCreate(&_handle) != CUBLAS_STATUS_SUCCESS) { + auto message = std::string("Fail to create cublas handle."); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } +#ifndef __HIP_PLATFORM_AMD__ + cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH); +#endif + } + + virtual ~BlasContext() { cublasDestroy(_handle); } + + static BlasContext& getInstance() + { + // Should always access the singleton through this function. + static BlasContext _instance; + return _instance; + } + + cublasHandle_t get_handle() const { return _handle; } + +private: + cublasHandle_t _handle; +}; + +enum class BlasType { FP32, FP16, BF16 }; + +// TODO HIP: Remove backward compatibility for torch<=2.0 in future +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) +rocblas_operation get_trans_op(bool do_trans) +{ + return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none; +} + +rocblas_datatype get_datatype(BlasType type) +{ + switch (type) { + case BlasType::FP32: return rocblas_datatype_f32_r; + case BlasType::FP16: return rocblas_datatype_f16_r; + case BlasType::BF16: return rocblas_datatype_bf16_r; + default: throw std::runtime_error("Unsupported BlasType"); + } +} +#else +cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T : CUBLAS_OP_N; } + +cublasDataType_t get_datatype(BlasType type) +{ + switch (type) { +#ifdef __HIP_PLATFORM_AMD__ + case BlasType::FP32: return HIPBLAS_R_32F; + case BlasType::FP16: return HIPBLAS_R_16F; + case BlasType::BF16: return HIPBLAS_R_16B; +#else + case BlasType::FP32: return CUDA_R_32F; + case BlasType::FP16: return CUDA_R_16F; + case BlasType::BF16: return CUDA_R_16BF; +#endif + default: throw std::runtime_error("Unsupported BlasType"); + } +} +#endif + +int blas_gemm_ex(void* C, + const void* A, + const void* B, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + bool transa, + bool transb, + const float* alpha, + const float* beta, + BlasType type) +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_operation_t transa_op = get_trans_op(transa); + rocblas_operation_t transb_op = get_trans_op(transb); + + rocblas_datatype_t abc_type = get_datatype(type); + + rocblas_status status = rocblas_gemm_ex(BlasContext::getInstance().get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + B, + abc_type, + ldb, + (const void*)beta, + C, + abc_type, + ldc, + C, + abc_type, + ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); +#else + cublasOperation_t transa_op = get_trans_op(transa); + cublasOperation_t transb_op = get_trans_op(transb); + + cublasDataType_t abc_type = get_datatype(type); + cublasStatus_t status = cublasGemmEx(BlasContext::getInstance().get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + B, + abc_type, + ldb, + (const void*)beta, + C, + abc_type, + ldc, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int blas_strided_batched_gemm(void* C, + const void* A, + const void* B, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + bool transa, + bool transb, + const float* alpha, + const float* beta, + int stride_A, + int stride_B, + int stride_C, + int batch, + BlasType type) +{ +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + rocblas_operation_t transa_op = get_trans_op(transa); + rocblas_operation_t transb_op = get_trans_op(transb); + + rocblas_datatype_t abc_type = get_datatype(type); + + rocblas_status status = + rocblas_gemm_strided_batched_ex(BlasContext::getInstance()::get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + stride_A, + B, + abc_type, + ldb, + stride_B, + (const void*)beta, + C, + abc_type, + ldc, + stride_C, + C, + abc_type, + ldc, + stride_C, + batch, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); +#else + cublasOperation_t transa_op = get_trans_op(transa); + cublasOperation_t transb_op = get_trans_op(transb); + + cublasDataType_t abc_type = get_datatype(type); + + cublasStatus_t status = cublasGemmStridedBatchedEx(BlasContext::getInstance().get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + stride_A, + B, + abc_type, + ldb, + stride_B, + (const void*)beta, + C, + abc_type, + ldc, + stride_C, + batch, +#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2) + HIPBLAS_COMPUTE_32F, +#elif defined(__HIP_PLATFORM_AMD__) + HIPBLAS_R_32F, +#else + CUDA_R_32F, +#endif + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + +#if defined(__HIP_PLATFORM_AMD__) && \ + ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0)) + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp new file mode 100644 index 000000000000..3f36a6bf01cb --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include + +#include "bias_activation.h" +#include "blas.h" +#include "gated_activation_kernels.h" +#include "layer_norm.h" +#include "linear_kernels.h" +#include "rms_norm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // bias_activation.h + m.def("bias_activation", &bias_activation, "DeepSpeed bias activation in CUDA"); + + // layer_norm.h + m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm in CUDA"); + m.def("pre_layer_norm", &ds_pre_layer_norm, "DeepSpeed pre layer norm in CUDA"); + m.def("post_layer_norm", &ds_post_layer_norm, "DeepSpeed pre layer norm in CUDA"); + + // blas.h + m.def("blas_linear", &blas_linear, "Linear implemented by vendor BLAS"); + m.def("blas_4d_matmul", &blas_4d_matmul, "4D matmul implemented by vendor BLAS"); + m.def("create_handle", &create_handle, "Create a handle for vendor BLAS"); + + // gated_activation_kernels.h + m.def("gated_activation", &ds_gated_activation, "DeepSpeed gated activation in CUDA"); + + // rms_norm.h + m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA"); + m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA"); + + // linear_kernels.h + m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA"); + m.def( + "preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit"); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py new file mode 100644 index 000000000000..bed7688b15d2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_ln import * +from .cuda_post_ln import * +from .cuda_pre_ln import * diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py new file mode 100644 index 000000000000..3c2aa5cb5eb4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import elem_size +from deepspeed.ops.op_builder import InferenceCoreBuilder + + +class CUDAFPLNBase(DSKernelBase): + """ + Base class for CUDA LN kernels. They all same the same validation logic, + so we can share it here. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5): + """ + Parameters: + channels (int): Number of channels in the input tensor. Must be divisible to align + to 16 bytes. + fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + if fp_dtype not in CUDAFPLNBase.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, CUDAFPLNBase.supported_dtypes)) + + if elem_size(fp_dtype) * channels % 16 != 0: + raise ValueError("channels must be divisible by 16 bytes") + + self.inf_module = InferenceCoreBuilder().load() + self.epsilon = epsilon diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py new file mode 100644 index 000000000000..583736fb8bbc --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .cuda_fp_ln_base import CUDAFPLNBase + + +class CUDAFPLN(CUDAFPLNBase): + """ + Floating point layer norm kernel for CUDA/RoCM. + + Performs: z = ln(x) + """ + + def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor) -> torch.Tensor: + """ + output_z may alias input_x directly. All Tensors should have the same shape. + + Parameters: + output_z (torch.Tensor): Output tensor. + input_x (torch.Tensor): Input tensor. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + """ + self.inf_module.layer_norm(output_z, input_x, gamma, beta, self.epsilon) + return output_z diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py new file mode 100644 index 000000000000..0ced1ecf207e --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .cuda_fp_ln_base import CUDAFPLNBase + + +class CUDAFPPostLN(CUDAFPLNBase): + """ + Floating point post-LayerNorm kernel for CUDA/RoCM. + + Performs: z = ln(x + y) + """ + + def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, input_y: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor) -> torch.Tensor: + """ + Either input_x or input_y can alias output_z. + + Parameters: + output_z (torch.Tensor): Output tensor. + input_x (torch.Tensor): Input tensor. + input_y (torch.Tensor): Input tensor. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + + Returns: + output (torch.Tensor): Output tensor. + """ + self.inf_module.post_layer_norm(output_z, input_x, input_y, gamma, beta, self.epsilon) + return output_z diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py new file mode 100644 index 000000000000..74b2d9cf5880 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch + +from .cuda_fp_ln_base import CUDAFPLNBase + + +class CUDAFPPreLN(CUDAFPLNBase): + """ + Floating point pre-LayerNorm kernel for CUDA/RoCM. + + Performs: z_res = x_res + y_hid + z_hid = ln(z_hid) + """ + + def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor, + gamma: torch.Tensor, beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + z_res can alias x_res. All non-parameter input/output tensors + must have the same shape. z_hid can alias y_hid. + + Parameters: + z_res (torch.Tensor): Output residual. + z_hid (torch.Tensor): Output hidden states. + x_res (torch.Tensor): Input residual. + y_hid (torch.Tensor): Input hidden states. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + + Returns: + output (torch.Tensor): Output tensor. + """ + self.inf_module.pre_layer_norm(z_res, z_hid, x_res, y_hid, gamma, beta, self.epsilon) + return z_res, z_hid diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp new file mode 100644 index 000000000000..b2c95d410a1f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "layer_norm.h" + +#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + at::cuda::getCurrentCUDAStream()); \ + } + +void ds_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + bool ragged_input = input.dim() == 2; + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_LAYER_NORM(kFloat, float); + DISPATCH_LAYER_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16); +#endif +} + +#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_post_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + at::cuda::getCurrentCUDAStream()); \ + } + +void ds_post_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + bool ragged_input = input.dim() == 2; + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif +} + +#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_pre_ln((C_TYPE*)norm_output.data_ptr(), \ + (C_TYPE*)res_output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + at::cuda::getCurrentCUDAStream()); \ + } + +void ds_pre_layer_norm(at::Tensor& res_output, + at::Tensor& norm_output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + bool ragged_input = input.dim() == 2; + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h new file mode 100644 index 000000000000..9ea3a8c42524 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +/* +Kernel launch methods for layer norm variants. +*/ + +template +void launch_fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +template +void launch_fused_post_ln(T* output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); +template +void launch_fused_pre_ln(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +void ds_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon); + +void ds_post_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon); + +void ds_pre_layer_norm(at::Tensor& res_output, + at::Tensor& norm_output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu new file mode 100644 index 000000000000..fb6dd0578f1d --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu @@ -0,0 +1,489 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace ln { +constexpr int granularity = 16; +} // namespace ln + +/* +Regular layer norm implementation. Assumes elems_per_row % 8 +is equal to 0. + +Args: + output: buffer for output data + vals: buffer for input data + gamma: gain for normalization + beta: bias for normalization + epsilon: numeric stability + elems_per_row: number of elements each block will normalize +*/ +template +__global__ void fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = ln::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[unRoll * T_per_load]; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + sum = reduce::element(sum, vals_up_cast); + } + } + + reduce::partitioned_block(tb, warp, sum); + const float mean = sum / elems_per_row; + + float mean_diff = reduce::init(); + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); + } + } + } + + reduce::partitioned_block(tb, warp, mean_diff); + const float variance = mean_diff / elems_per_row; + const float denom = __frsqrt_rn(variance + epsilon); + + T* block_output = output + block_offset; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; + + T gamma_local[T_per_load], beta_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global(beta_local, beta + iter_idx, do_loads); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \ + fused_ln \ + <<>>(output, vals, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_LN(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_FUSED_LN(T) \ + template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_FUSED_LN(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_FUSED_LN(__nv_bfloat16); +#endif +INSTANTIATE_FUSED_LN(float); + +/* +Fused resiual + bias + layer norm implementation. Assumes elems_per_row % 8 +is equal to 0. + +TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual +need to be fused into compute-bound producer operations. + +Args: + output: buffer for output data + res_output: output of residual addition + vals: buffer for input data + residual: residual data + bias: bias of of input data + gamma: gain for normalization + beta: bias for normalization + epsilon: numeric stability + elems_per_row: number of elements each block will normalize +Template arg: + StoreResidual: controls whether the residual calculation is stored + or not. When set to false, the input `res_output` is unused. +*/ +template +__global__ void fused_residual_ln(T* output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = ln::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = tb.size() * T_per_load; + + float sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + + T local_buffer[unRoll * T_per_load]; + + // Unlike a vanilla layernorm, since we're fusing the two adds as well + // an inner unRoll seems to be less valuable. If anything, a double unRoll + // makes the most sense if we find we are having performance issues. +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + T residual_buffer[T_per_load]; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + mem_access::load_global(residual_buffer, + residual_base + i * stride, + thread_offset + i * stride < elems_per_row); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + float res_up_cast = conversion::to(residual_buffer[j]); + vals_up_cast += res_up_cast; + sum = reduce::element(sum, vals_up_cast); + iteration_buffer[j] = conversion::to(vals_up_cast); + } + + if (preLnResidual && (thread_offset + i * stride < elems_per_row)) { + mem_access::store_global(res_output + base_offset + i * stride, + iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, sum); + const float mean = sum / elems_per_row; + + float mean_diff = reduce::init(); +#pragma unRoll + for (int i = 0; i < unRoll; i++) { +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); + } + } + } + + reduce::partitioned_block(tb, warp, mean_diff); + const float variance = mean_diff / elems_per_row; + const float denom = __frsqrt_rn(variance + epsilon); + + T* block_output = output + block_offset; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; + + T gamma_local[T_per_load], beta_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global(beta_local, beta + iter_idx, do_loads); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified. +#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \ + fused_residual_ln \ + <<>>( \ + output, nullptr, vals, residual, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_post_ln(T* output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \ + fused_residual_ln \ + <<>>( \ + norm_output, res_output, vals, residual, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_pre_ln(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_RES_LN(T) \ + template void launch_fused_post_ln( \ + T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +#define INSTANTIATE_PRE_LN_RES(T) \ + template void launch_fused_pre_ln( \ + T*, T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_RES_LN(__half); +INSTANTIATE_RES_LN(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_RES_LN(__nv_bfloat16); +#endif + +INSTANTIATE_PRE_LN_RES(__half); +INSTANTIATE_PRE_LN_RES(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_PRE_LN_RES(__nv_bfloat16); +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py new file mode 100644 index 000000000000..cd08409c0a7a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py new file mode 100644 index 000000000000..69aa9e8920e2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from ....logging import inference_logger +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ... import DSKernelBase + + +class CUDAWf6Af16Linear(DSKernelBase): + """ + Wrapper around the CUDA kernel of Wf6Af16 quantized linear. + + Performs z = x @ y + """ + supported_dtypes = [DtypeEnum.fp16] + + def __init__(self): + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.kernel = self.inf_module.cuda_wf6af16_linear + # The split_k_map is profiled on A100-80G GPU for some common shapes. + # It is an array of dictionaries, where the array index is the tokens chunk id. + # The dictionary is the mapping from the output channel to the split-K size. + self.split_k_map = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7 + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6 + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4 + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4 + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4 + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3 + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3 + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1 + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1 + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + } + ] + + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, + weights_4bit: torch.Tensor, scale: torch.Tensor, out_channels, tokens, in_channels) -> torch.Tensor: + """ + Matmul kernel of FP6 weight-only quantized linear. All inputs should be contiguous. + It does not support batched-matmul. + + Parameters: + output (torch.Tensor): Output tensor. Shape is of [token_number, out_features] + hidden_states (torch.Tensor): Input tensor. Shape is of [token_number, in_features] + weights_2bit (torch.Tensor): Input tensor of the 2-bit slice. Shape is of [out_features*2/8, in_features] + weights_4bit (torch.Tensor): Input tensor of the 4-bit slice. Shape is of [out_features*4/8, in_features] + scale (torch.Tensor): Input tensor. Shape is of [out_features], since the scale is per output channel + out_channels (int): The number of output channels + tokens (int): The number of tokens + in_channels (int): The number of input channels + """ + + if out_channels % 256 != 0 or in_channels % 64 != 0: + raise ValueError("The out and in channel should be multiple of 256 and 64 respectively.") + + # TODO: add a more general heuristic to determine the split-K. + split_k = -1 # not initialized + if tokens <= 768: + # Try to find the split-K from the pre-profiled map. + tokens_chunk_id = (tokens - 1) // 64 + split_k = self.split_k_map[tokens_chunk_id].get(out_channels, -1) + if split_k == -1: + split_k = 1 + inference_logger().warning( + f"The split-K setting may be suboptimal for shape {tokens}x{in_channels}x{out_channels}...") + + workspace = self.get_workspace(out_channels, tokens, in_channels, split_k, torch.float, hidden_states.device) + self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, out_channels, tokens, + in_channels, split_k) + + def get_workspace(self, out_channels: int, tokens: int, in_channels: int, split_k: int, dtype, + device) -> torch.Tensor: + """ + Allocate workspace for the kernel. The workspace is used to store the intermediate results of the matmul before + split-K. The split-K size is determined by the size of the matmul. + """ + workspace = torch.empty((split_k, out_channels, tokens), dtype=dtype, device=device) + + return workspace diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h new file mode 100644 index 000000000000..76e8eda2d35e --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef CONFIGS_H +#define CONFIGS_H + +// #define DEBUG_MODE +#define PIPELINE_LEVEL_GMEM 2 +#define PIPELINE_LEVEL_SMEM 2 // only support 2 + +/************************ Hardware Parameters ************************/ +#define WARP_SIZE 32 +#define REG_BIT_WIDTH 32 +// mma: M=16 K=16 N=8 +#define MMA_8 8 +#define MMA_16 16 +// for memory access +#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... +#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 + +/******************** Register Allocation For GEMM ********************/ +#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation +/********************** Memory Padding Parameters **********************/ +// Eliminating bank-conflict +#define PADDING_BYTES_16 16 // Padding 16 bytes each column +#define PADDING_SHARED_MEM_FOR_B_8 \ + 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B +#define PADDING_SHARED_MEM_FOR_C_4 \ + 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C +/************************* WARP Tiling part-1 *************************/ +#define WARP_ROW_MMA_TENSORS 4 +#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 +#define WARP_K_MMA_TENSORS 4 +#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 +template +struct TilingConfig { + // Depending on "n" dimension of the GEMM + static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; + static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; + static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; + /************************* WARP Tiling part-2 *************************/ + static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; + /*************************Thread Block Tiling *************************/ + static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; + static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; + static constexpr int TILE_K = WARP_K; + /********************** #Thread per Thread Block **********************/ + static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; + static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; + /******************************* Others *******************************/ + static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * + PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 + static constexpr int SMEM_SIZE_C_TILE = + TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 +}; + +/************************ General Config for Quant-LLM **********************/ +#define WEIGHT_FRAG1_BIT_WIDTH 2 +#define WEIGHT_FRAG2_BIT_WIDTH 4 +#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH + WEIGHT_FRAG2_BIT_WIDTH) // 6 +// #define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // +// QuantGroupSize: 4*64 = 256 +/*************************** 64*64 Weghts of A WARP *************************/ +#define WEIGHT_PER_UNIT (WARP_M * WARP_K) // 64*64 +#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 \ + (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / \ + 8) // 1024 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 \ + (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / \ + 8) // 2048 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_A1_TILE \ + (SMEM_SIZE_IN_BYTES_PER_WARP_A1 * 4 * \ + PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double + // buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_A2_TILE \ + (SMEM_SIZE_IN_BYTES_PER_WARP_A2 * 4 * \ + PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double + // buffer for 2-level pipeline A= 16 KB. +/******************** Global Memory Layout For QUANTIZED DATA ******************/ +#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 128) // 64 +#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 128) // 128 +/******************** Register Allocation For QUANTIZED DATA ******************/ +#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT / WARP_SIZE) // 128 +#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 2) // 8 +#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 4) // 16 +/******************** Register Allocation For QUANT Scales ******************/ +#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers +#define WARP_REG_QUANT_SCALE_DISTRIBUTED \ + 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for + // each thread + +#endif // CONFIGS_H diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh new file mode 100644 index 000000000000..860f70b226cb --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_KERNEL_MATMUL_CUH +#define DEEPSPEED_CUDA_LINEAR_KERNEL_MATMUL_CUH + +#include "configs.h" +#include "utils_core.cuh" +#include "utils_gmem.cuh" + +/* + * C = A*B + * A: row major with ahead-of-time layout transformation, FP6 + * B: col major, FP16 + * C: col major, FP16 + */ +template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + OutputDataType* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 900 + +#ifdef DEBUG_MODE + assert(K_Global % TilingConfig::TILE_K == 0); + assert(M_Global % TilingConfig::TILE_M == 0); + assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M)); +#endif + extern __shared__ __align__(128) + half smem[]; // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + reinterpret_cast( + smem + + (SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE) / 2); // Dynamic shared memory for FP16 B tiles + __shared__ half QuantScales[64 * TilingConfig::BLOCK_WARPS]; // static shared memory for + // quantization scales, 64 row per + // warp * 4 warps = 512 Bytes + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M); + const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = + blockIdx.y % + (M_Global / TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = (N_Global - Tile_Start_N) < TilingConfig::TILE_N + ? (N_Global - Tile_Start_N) + : TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global / TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K / Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + if (BatchID < ExtraNumBlock_K) NumIter++; + size_t StartBlockID_K = AverageNumBlock_K * BatchID; + if (BatchID < ExtraNumBlock_K) + StartBlockID_K += BatchID; + else + StartBlockID_K += ExtraNumBlock_K; + // Warp ID. + const int warpId = threadIdx.x / WARP_SIZE; + int WARP_i = + warpId / TilingConfig::BLOCK_COL_WARPS; // WARP_i: row number; WARP_j: column number + // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + // Global Memory Address for Matrix A (Weight) + // ///////////////////////////////////////////////////////////////////////// StartPTR for each + // ThreadBlock(TB) + const uint4* TB_StartGPTR_A1 = + Weight1 + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + const uint4* TB_StartGPTR_A2 = + Weight2 + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // StartPTR for each WARP. + const uint4* WARP_StartGPTR_A1 = + TB_StartGPTR_A1 + WARP_i * NumBlock_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + const uint4* WARP_StartGPTR_A2 = + TB_StartGPTR_A2 + WARP_i * NumBlock_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // StartPTR for each WARP, considering SplitK + const size_t WARP_Start_UnitID_K = StartBlockID_K; + WARP_StartGPTR_A1 += WARP_Start_UnitID_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + WARP_StartGPTR_A2 += WARP_Start_UnitID_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // Copying A tile from Global to Shared, using double-buffer + // ////////////////////////////////////////////////////////// StartSPTR for each ThreadBlock + uint32_t* AFrag_2BIT_SPTR = reinterpret_cast(smem); + uint32_t* AFrag_4BIT_SPTR = + AFrag_2BIT_SPTR + + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * TilingConfig::BLOCK_WARPS * + PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4; + // Pre-fetch of A tile + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared_A( + AFrag_2BIT_SPTR + i * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * 4, WARP_StartGPTR_A1); + CopyFromGlobalToShared_A( + AFrag_4BIT_SPTR + i * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4, WARP_StartGPTR_A2); + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; + } + // Global Memory Address for Matrix A (QuantScale) + // ///////////////////////////////////////////////////////////////////// + const half* TB_StartGPTR_A_Scale = Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64, WARP_StartGPTR_A_Scales); + // Copying B tile from Global to Shared, considering SplitK + // ///////////////////////////////////////////////////////////// + const half* BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared( + smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros + // ///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef PIPELINE_LEVEL_SMEM + uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM] + [4]; // double/Trible buffer is used // Registers to store decompressed FP6 + uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM] + [4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) +#endif + float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; + for (int i = 0; i < NumRegSets_a * NumRegSets_b; i++) + for (int j = 0; j < REG_PER_THREAD_C_TENSOR_16_16; j++) c[i][j] = 0.0f; + // + cp_async_wait_all(); + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales + ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i * 64); +#ifdef PIPELINE_LEVEL_SMEM + // Initializing the Software Pipeline: writing registers. + // //////////////////////////////////////////////////////////////////////////////////////////////// + initialize_mma_slice( + a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); +#endif +// The outer loop. +// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag1 = + AFrag_2BIT_SPTR + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag2 = + AFrag_4BIT_SPTR + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 +#ifdef PIPELINE_LEVEL_SMEM + uint32_t* __restrict__ read2_SPTR_Frag1 = + AFrag_2BIT_SPTR + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * 4; + uint32_t* __restrict__ read2_SPTR_Frag2 = + AFrag_4BIT_SPTR + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4; +#endif + uint32_t* __restrict__ write_SPTR_Frag1 = + AFrag_2BIT_SPTR + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag2 = + AFrag_4BIT_SPTR + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#ifdef PIPELINE_LEVEL_SMEM + half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#endif + half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter; + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + CopyFromGlobalToShared_A( + write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); + CopyFromGlobalToShared_A( + write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // copying B tile from GlobalMemory to SharedMemory + CopyFromGlobalToShared( + write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + cp_async_group_commit(); +#ifdef PIPELINE_LEVEL_SMEM + core_mma_slice(c, + a, + b, + read_SPTR_Frag1, + read_SPTR_Frag2, + read_SPTR, + Scales_RPTR, + 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each + // WARP; read_SPTR is shared among WARPs + core_mma_slice( + c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); + core_mma_slice( + c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + core_mma_slice( + c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A1 += + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; +#else + PipelinedCoreLoop( + c, + read_SPTR, + read_SPTR_Frag1, + read_SPTR_Frag2, + Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; + // read_SPTR is shared among WARPs + // Updating global PTRs + WARP_StartGPTR_A1 += + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); +#endif + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast(smem); + StoreToSharedMemoryFromRegister(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global memory. + OutputDataType* BlockGlobalPTR = + C + BatchID * (M_Global * N_Global) + Tile_Start_M + Tile_Start_N * M_Global; + for (size_t i = warpId; i < NumColumnToCopy; i += TilingConfig::BLOCK_WARPS) // i-th column +#pragma unroll + for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M; + j += WARP_SIZE) // j-th row + { + if constexpr (std::is_same::value) + BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]); + else + BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j]; + } + +#else + assert(("The FP6 functions are only available on Ampere GPUs.", false)); +#endif +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh new file mode 100644 index 000000000000..c417e6a46a7c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_KERNEL_REDUCTION_CUH +#define DEEPSPEED_CUDA_LINEAR_KERNEL_REDUCTION_CUH + +#include +#include +#include + +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 + +__global__ void SplitK_Reduction(half* C, + float* Reduction_Workspace, + size_t M_Global, + size_t N_Global, + int Split_K) +{ + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + // Initializing Thread-Local Results + float Results[HALF_PER_128BIT]; +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; + // Reduction + for (int i = 0; i < Split_K; i++) { +#pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; + THREAD_GPTR_R += M_Global * N_Global; + } +// Writing to global memory +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh new file mode 100644 index 000000000000..982d5a80010c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_PTX_CP_ASYNC_CUH +#define DEEPSPEED_CUDA_LINEAR_PTX_CP_ASYNC_CUH + +#include +#include +#include + +template +__device__ __forceinline__ void cp_async(half* smem_ptr, + const half* global_ptr, + bool pred_guard = true) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(SizeInBytes == 16, "Size is not supported"); + unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); + asm volatile( + "{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), + "l"(global_ptr), + "n"(SizeInBytes)); +#else + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); +#endif +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +__device__ __forceinline__ void cp_async_group_commit() +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n" ::); +#else + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); +#endif +} + +/// Blocks until all but previous cp.async.commit_group operations have committed. +template +__device__ __forceinline__ void cp_async_wait_group() +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); +#endif +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +// cp.async.wait_all is equivalent to : +// cp.async.commit_group; +// cp.async.wait_group 0; +__device__ __forceinline__ void cp_async_wait_all() +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_all;\n" ::); +#else + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); +#endif +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh new file mode 100644 index 000000000000..56f86a46f6b5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_PTX_MMA_CUH +#define DEEPSPEED_CUDA_LINEAR_PTX_MMA_CUH + +#include +#include +#include + +#include +#include "configs.h" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t (*__restrict__ Reg)[4], + half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int slice_id) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); +#endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory +#ifdef DEBUG_MODE + assert(warp_start_col == 0); +#endif + + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast( + __cvta_generic_to_shared(&read_SPTR[warp_start_col + col][slice_id * MMA_16 + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } else { +#pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +#else + assert( + ("The matrix load functions are only supported on Ampere and newer architectures", false)); +#endif +} +#else +// Debug: Whether ldmatrix.trans is required??? +// B is in column-major +template +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t (*__restrict__ Reg)[4], + half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int k_offset) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); +#endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory +#ifdef DEBUG_MODE + assert(warp_start_col == 0); +#endif + + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast( + __cvta_generic_to_shared(&read_SPTR[warp_start_col + col][k_offset + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } else { +#pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +#else + assert( + ("The matrix load functions are only supported on Ampere and newer architectures", false)); +#endif +} +#endif + +__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c, + uint32_t* __restrict__ a, + uint32_t* __restrict__ b) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), + "r"(a[1]), + "r"(a[2]), + "r"(a[3]), + "r"(b[0]), + "r"(b[1]), + "r"(c[0]), + "r"(c[1]), + "r"(c[2]), + "r"(c[3])); +#else + assert(("The mma functions are only implemented for Ampere and newer architectures", false)); +#endif +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh new file mode 100644 index 000000000000..bd8a009a02c6 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_UTILS_CORE_CUH +#define DEEPSPEED_CUDA_LINEAR_UTILS_CORE_CUH + +#include + +#include "configs.h" +#include "ptx_mma.cuh" +#include "utils_paralleldequant.cuh" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], + uint32_t* SPTR, + int slice_id) +{ + SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { Reg[i] = SPTR[lane_id + i * WARP_SIZE]; } +} + +template +__device__ __forceinline__ void initialize_mma_slice( + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) +{ + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2>(a_1, A1_SPTR_read, 0); + CopyFromSharedToRegister_AFrag<4>(a_2, A2_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at + // register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers +} + +template +__device__ __forceinline__ void core_mma_slice( + float c[][REG_PER_THREAD_C_TENSOR_16_16], + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching +{ +#ifdef DEBUG_MODE + assert( + (TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast( + c); // Registers for accumulated FP32 results + + // Setting RPTRs for double buffers + uint32_t(*a_read)[4] = a; + uint32_t(*a_write)[4] = a; + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + if (slice_id % 2 == 1) { + b_write += NumRegSets_b; + a_write += NumRegSets_a; + } else { + b_read += NumRegSets_b; + a_read += NumRegSets_a; + } + +// Reading registers and issuing core tensor core computations (a slice of A and B tile in shared +// memory) +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]); + } else { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a_read[i], + b_read[j] + 2); // c+4; b+2 + } + } + } + + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2>(a_1, A1_SPTR_read, slice_id); + CopyFromSharedToRegister_AFrag<4>(a_2, A2_SPTR_read, slice_id); + Dequant_32FP6_4Way( + a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register + // level, dequantizing a slice each time + B_FromSharedToReg( + b_write, B_SPTR_read, slice_id); // Loading B from shared to registers +} + +#else +// Old version with naive pipeline design +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) +{ + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { Reg[i] = SPTR[lane_id + i * WARP_SIZE]; } +} +template +__device__ __forceinline__ void PipelinedCoreLoop( + float c[][REG_PER_THREAD_C_TENSOR_16_16], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* __restrict__ read_SPTR_Frag1, + uint32_t* __restrict__ read_SPTR_Frag2, + uint32_t* RPTR_Scales) +{ +#ifdef DEBUG_MODE + assert( + (TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + + // Registers to store FP32 results + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast(c); + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2 * 2]; // double buffer is used + uint32_t a_2[4 * 2]; // double buffer is used + // Registers to store decompressed FP6 + uint32_t a[NumRegSets_a * 1][4]; // No double buffer + // Register to store FP16 B matrix (a slice) + uint32_t b[NumRegSets_b * 2][4]; // double buffer is used + + // Overlapped Smem and TC pipeline: pre-loading from shared to registers + CopyFromSharedToRegister_AFrag<2>(a_1, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2, read_SPTR_Frag2); + B_FromSharedToReg(b, read_SPTR, 0); + +#pragma unroll + for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + uint32_t* a_1_read = a_1; + uint32_t* a_1_write = a_1; + uint32_t* a_2_read = a_2; + uint32_t* a_2_write = a_2; + if (k % 2 == 0) { + b_write += NumRegSets_b; + a_1_write += 2; + a_2_write += 4; + } else { + b_read += NumRegSets_b; + a_1_read += 2; + a_2_read += 4; + } + // data loading + if (k + 1 < WARP_K_MMA_TENSORS) { + // updating SPTR for fragment1 and fragment2 + read_SPTR_Frag1 += 2 * WARP_SIZE; + read_SPTR_Frag2 += 4 * WARP_SIZE; + CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); + B_FromSharedToReg(b_write, read_SPTR, (k + 1) * MMA_16); + } + // SIMT Dequant + Tensor Core computations + Dequant_32FP6_4Way( + a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, + // dequantizing a slice each time +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) + MMA_FP16_M16N8K16(c_uint_ptr[i], a[i], b_read[0]); + else { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a[i], + b_read[j] + 2); // c+4; b+2 + } + } + } + } +} +#endif // #ifdef PIPELINE_LEVEL_SMEM + +template +__device__ __forceinline__ void StoreToSharedMemoryFromRegister( + float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) +{ + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; + j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS; + int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; +#pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) { + int row_offset = lane_id / 4; + if (r >= 2) row_offset += 8; + int col_offset = (lane_id % 4) * 2; + if (r % 2 == 1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = + c[RegSetID][r + RegOffset]; + } + } + } +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh new file mode 100644 index 000000000000..3dd7e9e0104e --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_UTILS_GMEM_CUH +#define DEEPSPEED_CUDA_LINEAR_UTILS_GMEM_CUH + +#include +#include "configs.h" +#include "ptx_cp.async.cuh" + +/* + * Copying A1/A2 from global memory to shared memory. + * Usually 1024 or 2048 Bytes + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, + const uint4* GPTR, + bool pred_guard = true) +{ +#ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0); +#endif + int lane_id = threadIdx.x % WARP_SIZE; + half* SPTR_HALF = reinterpret_cast(SPTR); + const half* GPTR_HALF = reinterpret_cast(GPTR); + SPTR_HALF += lane_id * 8; + GPTR_HALF += lane_id * 8; +#pragma unroll + for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) { + cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes + } +} + +/* + * Copying 64 Quant Scales (FP16) from global memory to shared memory. + */ +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, + const half* GPTR_A_Scales) +{ + int lane_id = threadIdx.x % WARP_SIZE; + int Offset_Shared = lane_id * 2; + int Offset_Global = lane_id / 4 + (lane_id % 4) * 16; + for (int i = 0; i < 2; i++) + SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8]; +} + +/* + * (1) Copying X rows * 64 columns of FP16 values, originally in row major + * (2) Copying 64 rows * X columns of FP16 values, originally in column major + * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared( + half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. + bool Pred = true) +{ + // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x % 8) * 8; + // PTR for source global memory and target shared memory + GlobalPTR += line_id * GlobalStride + line_offset; + SharedPTR += line_id; +#pragma unroll + for (int i = 0; i < MaxIteration; i++) { + bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred; + cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + // + GlobalPTR += NumOfGroups * GlobalStride; + SharedPTR += NumOfGroups; + } +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh new file mode 100644 index 000000000000..11603fcc576c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_UTILS_PARALLELDEQUANT_CUH +#define DEEPSPEED_CUDA_LINEAR_UTILS_PARALLELDEQUANT_CUH + +#include +#include +#include + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t* R1, uint32_t* R2) +{ + *R2 = *R1 & 0x80808080; + *R1 = *R1 >> 2; + *R1 = *R1 & 0x1f1f1f1f; + *R2 = *R2 | *R1; + *R1 = *R2 & 0x9f009f00; + *R2 = *R2 & 0x009f009f; + *R2 = *R2 << 8; +} + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is NOT applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t* R1, uint32_t* R2) +{ + //*R2 = *R1 & 0x80808080; + *R2 = *R1 & 0xc0c0c0c0; + *R1 = *R1 >> 2; + //*R1 = *R1 & 0x1f1f1f1f; + *R1 = *R1 & 0x0f0f0f0f; + *R2 = *R2 | *R1; + // + //*R1 = *R2 & 0x9f009f00; + //*R2 = *R2 & 0x009f009f; + *R1 = *R2 & 0xcf00cf00; + if (!(*R1 & 0x40000000) && (*R1 & 0x0c000000)) *R1 = *R1 | 0x30000000; + if (!(*R1 & 0x00004000) && (*R1 & 0x00000c00)) *R1 = *R1 | 0x00003000; + *R2 = *R2 & 0x00cf00cf; + if (!(*R2 & 0x00400000) && (*R2 & 0x000c0000)) *R2 = *R2 | 0x00300000; + if (!(*R2 & 0x00000040) && (*R2 & 0x0000000c)) *R2 = *R2 | 0x00000030; + // + *R2 = *R2 << 8; + //*R1 = 0x3c003c00; + //*R2 = 0x3c003c00; +} + +__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) +{ + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + uint32_t output; + half* output_half_ptr = reinterpret_cast(&output); + output_half_ptr[0] = __hmul(__hmul(*FP16_1, __float2half(4096.0f)), Scale); + output_half_ptr[1] = __hmul(__hmul(*FP16_2, __float2half(4096.0f)), Scale); + return output; +} + +__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (*__restrict__ Reg)[4], + uint32_t* __restrict__ read_RPTR_Frag1, + uint32_t* __restrict__ read_RPTR_Frag2, + uint32_t* Scales) +{ + uint32_t* OutputRegs = reinterpret_cast(Reg); + uint32_t* Frag1_PTR = read_RPTR_Frag1; + uint32_t* Frag2_PTR = read_RPTR_Frag2; + half* Scale_RPTR = reinterpret_cast(Scales); + uint32_t Packed_FP6 = 0; + uint32_t tmp = 0; +// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + // Frag1 + Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; + if (i % 4 == 3) + Frag1_PTR++; + else + (*Frag1_PTR) = (*Frag1_PTR) << 2; + // Frag2 + tmp = (*Frag2_PTR) & 0xf0f0f0f0; + tmp = tmp >> 2; + if (i % 2 == 1) + Frag2_PTR++; + else + (*Frag2_PTR) = (*Frag2_PTR) << 4; + // Packed_FP6 + Packed_FP6 = Packed_FP6 | tmp; + // + FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); + // + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0]); // Muliply FP16 scales + OutputRegs += 1; + *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + OutputRegs += 1; + // Updating offset for FP16 scales for every two iterations + if (i % 2 == 1) Scale_RPTR += 2; + } +} + +/* + * + */ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, + half* WARP_SPTR_Scales) +{ + int lane_id = threadIdx.x % WARP_SIZE; + uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); + uint32_t tmpReg = SPTR_uint[lane_id]; +#pragma unroll + for (int i = 0; i < 4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + } +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h new file mode 100644 index 000000000000..384e2f0b26a0 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_WEIGHT_PREPACKING_H +#define DEEPSPEED_CUDA_LINEAR_WEIGHT_PREPACKING_H + +#include +#include +#include + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], + unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0] << 6) | ((FP6_Array[1] >> 2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1] << 4) | ((FP6_Array[2] >> 4) & 0xfc); + Padded_FP6[3] = FP6_Array[2] << 2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3] << 6) | ((FP6_Array[4] >> 2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4] << 4) | ((FP6_Array[5] >> 4) & 0xfc); + Padded_FP6[7] = FP6_Array[5] << 2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, + unsigned char B2, + unsigned char B3, + unsigned char B4) +{ + unsigned char out; + out = (B1 & 0xc0) | ((B2 & 0xc0) >> 2) | ((B3 & 0xc0) >> 4) | ((B4 & 0xc0) >> 6); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6( + unsigned char B1, + unsigned char + B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ((B1 << 2) & 0xf0) | ((B2 >> 2) & 0x0f); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(std::vector Seg_2bit[], + std::vector Seg_4bit[], + unsigned char* PTR_1, + unsigned char* PTR_2, + unsigned char* PTR_3, + unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for (int t = 0; t < 4; t++) { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0 + t * 2], + Padded_8_FP8[0][1 + t * 2], + Padded_8_FP8[1][0 + t * 2], + Padded_8_FP8[1][1 + t * 2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0 + t * 2], + Padded_8_FP8[2][1 + t * 2], + Padded_8_FP8[3][0 + t * 2], + Padded_8_FP8[3][1 + t * 2]); + Seg2_Byte1_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0 + t * 2], Padded_8_FP8[0][1 + t * 2]); + Seg2_Byte2_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0 + t * 2], Padded_8_FP8[1][1 + t * 2]); + Seg2_Byte3_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0 + t * 2], Padded_8_FP8[2][1 + t * 2]); + Seg2_Byte4_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0 + t * 2], Padded_8_FP8[3][1 + t * 2]); + } + // + for (int t = 0; t < 4; t++) { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for + // bit-interleaving in QuantLLM + int order_2bit[16] = { + 2, 6, 10, 14, 4, 8, 12, 16, 1, 5, 9, 13, 3, 7, 11, 15}; // pre-defined order for + // bit-interleaving in QuantLLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. + for (int i = 0; i < 16; i++) Frags_2bit[i] = (input << 2 * (order_2bit[i] - 1)) & 0xc0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 16; i++) output |= (Frags_2bit[i] >> (i * 2)); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in QuantLLM + int order_4bit[8] = { + 2, 6, 4, 8, 1, 5, 3, 7}; // pre-defined order for bit-interleaving in QuantLLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for (int i = 0; i < 8; i++) Frags_4bit[i] = (input << 4 * (order_4bit[i] - 1)) & 0xf0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 8; i++) output |= (Frags_4bit[i] >> (i * 4)); + // + *PTR_UINT = output; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = Weight_6bit; + unsigned char* Weight_4bit = Weight_6bit + M * K * 2 / 8; + // + std::vector A_Segment_2bit[32]; + std::vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K * 6 / 8; + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) { + for (size_t k = 0; k < 64 / 16; k++) { + size_t row = i * 64 + k * 16; + size_t col = j * 16; + unsigned char* StartPTR_1 = Weight_6bit + row * BytesPerRow + col * 6 / 8; + unsigned char* StartPTR_2 = StartPTR_1 + 8 * BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8 * 6 / 8; + unsigned char* StartPTR_4 = StartPTR_2 + 8 * 6 / 8; + // Dealing with each 16*16 blocks then... + for (int l = 0; l < 8; l++) + Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l * 4], + &A_Segment_4bit[l * 4], + StartPTR_1 + l * BytesPerRow, + StartPTR_2 + l * BytesPerRow, + StartPTR_3 + l * BytesPerRow, + StartPTR_4 + l * BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M * K * 2 / 8 / 32; + size_t BytesPerThread_4bit = M * K * 4 / 8 / 32; + for (int i = 0; i < 32; i++) { + assert(A_Segment_2bit[i].size() == BytesPerThread_2bit); + assert(A_Segment_4bit[i].size() == BytesPerThread_4bit); + } + // Pass-2: Optimizing coleasced global memory access + for (size_t i = 0; i < BytesPerThread_2bit / 4; i++) + for (int t = 0; t < 32; t++) + for (int b = 0; b < 4; b++) + Weight_2bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_2bit[t] + [i * 4 + b]; // why (3-b): special byte order within a register + for (size_t i = 0; i < BytesPerThread_4bit / 4; i++) + for (int t = 0; t < 32; t++) + for (int b = 0; b < 4; b++) + Weight_4bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_4bit[t][i * 4 + b]; // why (3-b):special byte order within a register + // Pass-3: Bit-level interleaving + for (size_t i = 0; i < BytesPerThread_2bit * 32 / 4; i++) + BitInterleaving_2bit(Weight_2bit + 4 * i); + for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++) + BitInterleaving_4bit(Weight_4bit + 4 * i); +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp new file mode 100644 index 000000000000..3b4966eb822b --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "linear_kernels.h" + +namespace { + +// For bit-level debugging. +template +void print_bits(T num) +{ + char bits[sizeof(T) * 8 + 1] = {'\0'}; + for (int bit = 0; bit < (sizeof(T) * 8); bit++) { + bits[sizeof(T) * 8 - 1 - bit] = '0' + (num & 0x01); + num = num >> 1; + } + printf("%s\n", bits); +} + +void print_bits(half num) +{ + char bits[sizeof(half) * 8 + 1] = {'\0'}; + auto int_num = *reinterpret_cast(&num); + for (int bit = 0; bit < (sizeof(half) * 8); bit++) { + bits[sizeof(half) * 8 - 1 - bit] = '0' + (int_num & 0x01); + int_num = int_num >> 1; + } + printf("%s\n", bits); +} + +/* + * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. + */ +void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) +{ + // Constants for FP6 + constexpr int exponent_nbits_fp6 = 3; + constexpr int mantissa_nbits_fp6 = 2; + constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; + // Constants for FP16 + constexpr int exponent_nbits_fp16 = 5; + constexpr int mantissa_nbits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; + + int fp6_temp[4]; + + float absmin_nonzero_fp6 = 0.0625; + // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is + // the same with that in qtorch. + float absmax_fp6 = 28; + + for (int i = 0; i < 4; ++i) { + uint16_t source = FP16x4[i]; + float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); + if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || + fp6_value_abs > absmax_fp6) { + // TODO(zhen): a better way may be rounding it to the nearest FP6 value. + throw std::invalid_argument("Input value out of range for FP6."); + } + + // It is not safe to do shift operation on uint16_t. So we promote it to int. + int source_promote = int(source); + + int sign_bit = (source_promote >> 15); + // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' + int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; + // Extracting mantissa represented in FP16 + int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); + + int new_exp_bit; + int new_mant_bit; + + if (exp_bit == 0) { + // Subnormal FP16 number. Too small for FP6. + new_exp_bit = 0; + new_mant_bit = 0; + } else { + new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; + + // Deal with subnormal FP6 values. + int target_exp_val = exp_bit - exp_bias_fp16; + int min_fp6_exp_val = -exp_bias_fp6 + 1; + bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; + if (subnormal_fp6) { + // TODO(zhen): add the rounding logic. + new_exp_bit = 0; + // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we + // need to add it + new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> + (min_fp6_exp_val - target_exp_val); + } + } + + fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; + } + // Pack the values + FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); + FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); + FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; +} + +/* + * Function to prepack FP16 weights into continuous FP6 values. + * + * Parameters: + * weight_16bit: input weight in FP16, size M*K + * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 + * M, K: the shape of the weight + */ +void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, + uint8_t* weight_6bit_packed, + size_t M, + size_t K) +{ + // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). + if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } + size_t K_fp6_packed = K * 6 / 8; + // #pragma omp parallel for + for (auto m = 0; m < M; m++) { + uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; + uint16_t* ptr_16bit = weight_16bit + m * K; + for (auto k = 0; k < K; k += 4) { + cast_fp16_fp6(ptr_16bit, ptr_6bit); + ptr_16bit += 4; + ptr_6bit += 3; + } + } +} + +} // namespace + +/* + * Function to execute the FP6 linear kernel. + * + * Parameters: + * output: output tensor, size M*N + * hidden_states: input activation tensor, size N*K + * weights_2bit: packed 2bit weights, size M*K*2/8 + * weights_4bit: packed 4bit weights, size M*K*4/8 + * scales: scale tensor, size M + * workspace: workspace tensor, size M*N*split_k + * M: the output channel number of the weight + * N: the token number of the activation + * K: the input channel number of the weight + * split_k: the split size of the GEMM calculation + */ +void cuda_wf6af16_linear(torch::Tensor& output, + torch::Tensor& hidden_states, + torch::Tensor& weights_2bit, + torch::Tensor& weights_4bit, + torch::Tensor& scales, + torch::Tensor& workspace, + int M, + int N, + int K, + int split_k) +{ + TORCH_CHECK(weights_2bit.device().type() == torch::kCUDA, "weight_2bit must be on CUDA"); + TORCH_CHECK(weights_4bit.device().type() == torch::kCUDA, "weight_4bit must be on CUDA"); + TORCH_CHECK(hidden_states.device().type() == torch::kCUDA, "X must be on CUDA"); + TORCH_CHECK(scales.device().type() == torch::kCUDA, "scales must be on CUDA"); + + auto status = fp6_linear_kernel(at::cuda::getCurrentCUDAStream(), + (uint4*)(weights_2bit.data_ptr()), + (uint4*)(weights_4bit.data_ptr()), + (half*)(scales.data_ptr()), + (half*)(hidden_states.data_ptr()), + (half*)(output.data_ptr()), + M, + N, + K, + workspace.data_ptr(), + split_k); + if (status != cudaSuccess) { + AT_ERROR("fp6_linear_kernel failed with error: ", cudaGetErrorString(status)); + } +} + +/* + * Function to prepack the fake 6-bit-quantized FP16 weights into 2bit and 4bit. + * + * Parameters: + * weight: input weight in FP16 (containing the quantized FP6-ranged value), size M*K + * Returns: + * weight_2bit: output weight in 2bit, size M*K*2/8 + * weight_4bit: output weight in 4bit, size M*K*4/8 + */ +std::vector preprocess_weight(torch::Tensor& weight) +{ + TORCH_CHECK(weight.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(weight.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(weight.device().type() == torch::kCPU, "weight must be on CPU"); + auto M = weight.size(0); + auto K = weight.size(1); + TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); + + // Pack weight from FP16 to FP6. + uint16_t* weight_16bit_ptr = reinterpret_cast(weight.data_ptr()); + std::vector weight_6bit_packed(M * K * 6 / 8); + uint8_t* weight_6bit_ptr = weight_6bit_packed.data(); + weight_prepacking_fp16_to_fp6(weight_16bit_ptr, weight_6bit_ptr, M, K); + + // Split weight into 2bit and 4bit. + weight_matrix_prepacking(reinterpret_cast(weight_6bit_ptr), M, K); + uint8_t* weight_2bit_ptr = weight_6bit_ptr; + + // Make sure that the new split tensor does not share the underlying memory with the original + // one. Otherwise it will incur some problems when the original tensor is deleted. It also + // makes the memory flattern risky. + auto weight_2bit = + torch::from_blob(weight_2bit_ptr, {M * K * 2 / 8}, torch::kUInt8).clone().detach(); + uint8_t* weight_4bit_ptr = weight_2bit_ptr + M * K * 2 / 8; + auto weight_4bit = + torch::from_blob(weight_4bit_ptr, {M * K * 4 / 8}, torch::kUInt8).clone().detach(); + + return {weight_2bit, weight_4bit}; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.h new file mode 100644 index 000000000000..01a6b7c18af8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#ifndef DEEPSPEED_CUDA_LINEAR_KERNELS_H +#define DEEPSPEED_CUDA_LINEAR_KERNELS_H + +#include +#include +#include "ds_kernel_utils.h" + +#include "linear_kernels_cuda.h" + +void cuda_wf6af16_linear(torch::Tensor& output, + torch::Tensor& hidden_states, + torch::Tensor& weights_2bit, + torch::Tensor& weights_4bit, + torch::Tensor& scale, + torch::Tensor& workspace, + int M, + int N, + int K, + int split_k); + +std::vector preprocess_weight(torch::Tensor& Weight); + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu new file mode 100644 index 000000000000..ea0203c42f84 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu @@ -0,0 +1,318 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +// clang-format off +// Put the torch headers at the front to avoid conflict with other headers on +// `at::nullopt` and `at::optional`. +#include +#include +// clang-format on + +#include "include/kernel_matmul.cuh" +#include "include/kernel_reduction.cuh" +#include "include/weight_prepacking.h" + +#include +#include + +#include "linear_kernels_cuda.h" + +template +static void Kernel_Ex(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + OutputDataType* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ +#ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", + TilingConfig::TILE_M, + TilingConfig::TILE_K, + TilingConfig::TILE_N); +#endif + static size_t SHMEM_SZ = + max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, + TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SHMEM_SZ); + size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); + +#ifdef DEBUG_MODE + printf( + "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: " + "%d SHMEM_SZ: %d\n", + GridDim.x, + GridDim.y, + GridDim.z, + BlockDim.x, + BlockDim.y, + BlockDim.z, + SHMEM_SZ); + printf("\n"); +#endif + + QUANT_GEMM_Kernel<<>>( + Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); +} + +/* + * + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K) +{ + assert(M_Global % 256 == 0); + assert(K_Global % 64 == 0); + assert(N_Global > 0); + + // Work around to support more N shapes: + size_t N_PowerOf2; + if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8; + if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16; + if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32; + if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64; + if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128; + if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128; + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 16: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 32: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 64: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 128: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %lu!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + } + } else { + switch (N_PowerOf2) { + case 8: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 16: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 32: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 64: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 128: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %lu!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<>>( + C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} + +/* +Computes FP6-FP16 GEMM (PyTorch interface). + +[Mathematical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in +row-major. After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not +perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when +calling our CUDA kernel. + +[Inputs] + _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _scales: tensor of shape [OC]; // half + splitK: splitting the MatMul problem along K dimension for higher GPU utilization, default 1. +[Outputs] + _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_channels = _weights.size(0); + assert(num_in_channels % 64 == 0); + assert((num_in_channels / 16 * 3) == + _weights.size(1)); // Making sure the K dimension is matched. + // + int M = num_out_channels; + int K = num_in_channels; + int N = num_in_feats; + // Input Tensors + auto weight1 = reinterpret_cast( + _weights.data_ptr()); // weights is [OC, IC] but in FP6. + auto weight2 = weight1 + num_in_channels * num_out_channels * 2 / 128; + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + // Output Tensors + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + float* Reduction_Workspace = nullptr; + if (splitK != 1) { + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); + at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); + auto Reduction_Workspace = reinterpret_cast( + _out_feats.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * + // N_Global * sizeof(fp32) + } + + fp6_linear_kernel(0, // Using default stream here. + weight1, + weight2, + scales, + in_feats, + out_feats, + M, + N, + K, + Reduction_Workspace, + splitK); + + return _out_feats; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t OC, size_t IC) +{ + assert((OC % 256 == 0) && (IC % 64 == 0)); + assert((fp6_tensor.size(0) == OC) && (fp6_tensor.size(1) == IC / 16 * 3)); + // auto packed_tensor = torch::empty_like(fp6_tensor); + // auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(fp6_tensor_ptr, OC, IC); + return fp6_tensor; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.h new file mode 100644 index 000000000000..6a83290f0cb5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef DEEPSPEED_CUDA_LINEAR_FP6_LINEAR_CUH +#define DEEPSPEED_CUDA_LINEAR_FP6_LINEAR_CUH + +#include +#include +#include + +#include + +/* + * Computes FP6-FP16 GEMM (C++ interface). + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K); + +/* + * Computes FP6-FP16 GEMM (PyTorch interface). + */ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK = 1); + +/* + * In-place weight prepacking (C++ interface). + */ +void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K); + +/* + * Weight prepacking (Pytorch interface). + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t M, size_t K); + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py new file mode 100644 index 000000000000..640a72307650 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .rms_norm import CUDARMSNorm +from .rms_pre_norm import CUDARMSPreNorm diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp new file mode 100644 index 000000000000..c67712df438a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "rms_norm.h" + +#ifdef BF16_AVAILABLE +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using scalar_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#else +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#endif + +void rms_norm(torch::Tensor& norm_output, + torch::Tensor& norm_input, + torch::Tensor& gamma, + float epsilon) +{ + TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(), + "norm_output and norm_input should have the same data type"); + TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(), + "norm_output and gamma should have the same data type"); + + const int32_t rows = norm_input.size(0); + const int32_t cols = norm_input.size(1); + + TORCH_CHECK(norm_output.size(0) == rows, + "norm_output and norm_input should have the same first dimension"); + TORCH_CHECK(norm_output.size(1) == cols, + "norm_output and norm_input should have the same second dimension"); + + DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] { + scalar_t* norm_output_ptr = reinterpret_cast(norm_output.data_ptr()); + scalar_t* norm_input_ptr = reinterpret_cast(norm_input.data_ptr()); + scalar_t* gamma_ptr = reinterpret_cast(gamma.data_ptr()); + scalar_t* null_t = nullptr; + + launch_rms_norm(norm_output_ptr, + null_t, + norm_input_ptr, + null_t, + gamma_ptr, + epsilon, + rows, + cols, + at::cuda::getCurrentCUDAStream()); + }); +} + +void rms_pre_norm(torch::Tensor& norm_output, + torch::Tensor& residual_output, + torch::Tensor& norm_input, + torch::Tensor& residual_input, + torch::Tensor& gamma, + float epsilon) +{ + TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(), + "norm_output and norm_input should have the same data type"); + TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(), + "norm_output and gamma should have the same data type"); + + const int32_t rows = norm_input.size(0); + const int32_t cols = norm_input.size(1); + + TORCH_CHECK(norm_output.size(0) == rows, + "norm_output and norm_input should have the same first dimension"); + TORCH_CHECK(norm_output.size(1) == cols, + "norm_output and norm_input should have the same second dimension"); + + TORCH_CHECK(residual_output.size(0) == rows, + "residual_output and norm_input should have the same first dimension"); + TORCH_CHECK(residual_output.size(1) == cols, + "residual_output and norm_input should have the same second dimension"); + + TORCH_CHECK(residual_input.size(0) == rows, + "residual_input and norm_input should have the same first dimension"); + TORCH_CHECK(residual_input.size(1) == cols, + "residual_input and norm_input should have the same second dimension"); + + DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] { + scalar_t* norm_output_ptr = reinterpret_cast(norm_output.data_ptr()); + scalar_t* residual_output_ptr = reinterpret_cast(residual_output.data_ptr()); + const scalar_t* norm_input_ptr = reinterpret_cast(norm_input.data_ptr()); + const scalar_t* residual_input_ptr = + reinterpret_cast(residual_input.data_ptr()); + const scalar_t* gamma_ptr = reinterpret_cast(gamma.data_ptr()); + + launch_rms_norm(norm_output_ptr, + residual_output_ptr, + norm_input_ptr, + residual_input_ptr, + gamma_ptr, + epsilon, + rows, + cols, + at::cuda::getCurrentCUDAStream()); + }); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h new file mode 100644 index 000000000000..7867fb65964f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +void rms_norm(torch::Tensor& norm_output, + torch::Tensor& norm_input, + torch::Tensor& gamma, + float epsilon); + +void rms_pre_norm(torch::Tensor& norm_output, + torch::Tensor& residual_output, + torch::Tensor& norm_input, + torch::Tensor& residual_input, + torch::Tensor& gamma, + float epsilon); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py new file mode 100644 index 000000000000..deb5d33111a9 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .rms_norm_base import CUDARMSNormBase + + +class CUDARMSNorm(CUDARMSNormBase): + """ + Floating point layer norm kernel for CUDA/RoCM. + + Performs: z = ln(x) + """ + + def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor: + """ + output_z may alias input_x directly. All Tensors should have the same shape. + + Parameters: + output_z (torch.Tensor): Output tensor. + input_x (torch.Tensor): Input tensor. + gamma (torch.Tensor): Gamma tensor. + """ + self.inf_module.rms_norm(output_z, input_x, gamma, self.epsilon) + return output_z diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py new file mode 100644 index 000000000000..62bc9d056ade --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import elem_size +from deepspeed.ops.op_builder import InferenceCoreBuilder + + +class CUDARMSNormBase(DSKernelBase): + """ + Base class for CUDA LN kernels. They all same the same validation logic, + so we can share it here. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5): + """ + Parameters: + channels (int): Number of channels in the input tensor. Must be divisible to align + to 16 bytes. + fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + if fp_dtype not in CUDARMSNormBase.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, CUDARMSNormBase.supported_dtypes)) + + if elem_size(fp_dtype) * channels % 16 != 0: + raise ValueError("channels must be divisible by 16 bytes") + + self.inf_module = InferenceCoreBuilder().load() + self.epsilon = epsilon diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu new file mode 100644 index 000000000000..e69d3c36cc00 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu @@ -0,0 +1,262 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace rms { +constexpr int granularity = 16; +} // namespace rms + +template +__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + + mem_access::load_global(iteration_buffer, + input_base + (i * stride), + thread_offset + (i * stride) < elems_per_row); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + float up_cast = conversion::to(iteration_buffer[j]); + float sq_val = up_cast * up_cast; + var_sum = reduce::element(var_sum, sq_val); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +template +__global__ void pre_rms_norm(T* output, + T* res_out, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + T* res_output = res_out + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + T residual_buffer[T_per_load]; + + const int iter_offset = i * stride + thread_offset; + const bool do_loads = (iter_offset < elems_per_row); + + mem_access::load_global( + iteration_buffer, input_base + (i * stride), do_loads); + mem_access::load_global( + residual_buffer, residual_base + (i * stride), do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] += residual_buffer[j]; + float vals_up_cast = conversion::to(iteration_buffer[j]); + + var_sum = reduce::element(var_sum, vals_up_cast * vals_up_cast); + } + + if (do_loads) { + mem_access::store_global(res_output + i * stride, iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + rms_norm \ + <<>>(norm_output, vals, gamma, epsilon, elems_per_row); + +#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + pre_rms_norm<<>>( \ + norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row); + +#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + if (pre_norm) { \ + LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } else { \ + LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = rms::granularity / sizeof(T); + constexpr int maxThreads = 256; + constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threads_per_group, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threads_per_group * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + bool pre_norm = (residual == nullptr) ? false : true; + + if (is_subblock_schedule) { + // <=128 + if (threads_per_group == 1) { + LAUNCH_ALL_RMS_NORM(1, 1, maxThreads); + } else if (threads_per_group == 2) { + LAUNCH_ALL_RMS_NORM(1, 2, maxThreads); + } else if (threads_per_group == 4) { + LAUNCH_ALL_RMS_NORM(1, 4, maxThreads); + } else if (threads_per_group == 8) { + LAUNCH_ALL_RMS_NORM(1, 8, maxThreads); + } else if (threads_per_group == 16) { + LAUNCH_ALL_RMS_NORM(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_LAUNCH_RMS_NORM(T) \ + template void launch_rms_norm(T * norm_output, \ + T * res_output, \ + const T* vals, \ + const T* residual, \ + const T* gamma, \ + float epsilon, \ + int rows, \ + int elems_per_row, \ + cudaStream_t stream); + +INSTANTIATE_LAUNCH_RMS_NORM(float) +INSTANTIATE_LAUNCH_RMS_NORM(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py new file mode 100644 index 000000000000..3b040d88b50f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch + +from .rms_norm_base import CUDARMSNormBase + + +class CUDARMSPreNorm(CUDARMSNormBase): + """ + Floating point pre-LayerNorm kernel for CUDA/RoCM. + + Performs: z_res = x_res + y_hid + z_hid = ln(z_hid) + """ + + def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor, + gamma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + z_res can alias x_res. All non-parameter input/output tensors + must have the same shape. z_hid can alias y_hid. + + Parameters: + z_res (torch.Tensor): Output residual. + z_hid (torch.Tensor): Output hidden states. + x_res (torch.Tensor): Input residual. + y_hid (torch.Tensor): Input hidden states. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + + Returns: + output (torch.Tensor): Output tensor. + """ + self.inf_module.rms_pre_norm(z_hid, z_res, y_hid, x_res, gamma, self.epsilon) + return z_res, z_hid diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py b/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py new file mode 100644 index 000000000000..05479d86c906 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .gated_activation import * diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py new file mode 100644 index 000000000000..ca1b62ba5c36 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, elem_size +from deepspeed.ops.op_builder import InferenceCoreBuilder + + +class CUDAGatedActivation(DSKernelBase): + """ + CUDA implementation of gated activation kernel. This kernel assumes that the input + tensor has gate and activation values in adjacent channels. The output tensor should + have half the dimensionality of the input tensor. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + supported_act_fns = [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU] + + def __init__(self, channels: int, fp_dtype: torch.dtype, act_fn: ActivationType) -> None: + """ + Compile and validate for the gated activation function. + + Args: + channels (int): Number of columns in the output tensor. Must be divisible to align + to 8 bytes. + fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + act_fn (ActivationType): Activation function to use. Only GEGLU is supported. + """ + if fp_dtype not in CUDAGatedActivation.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, CUDAGatedActivation.supported_dtypes)) + + act_fn = ActivationType(act_fn) + if act_fn not in CUDAGatedActivation.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, CUDAGatedActivation.supported_act_fns)) + + if elem_size(fp_dtype) * channels % 8 != 0: + raise ValueError("Channels must be divisible by 16 bytes") + + if elem_size(fp_dtype) * channels > 98304: + raise ValueError( + "Kernel only compiled to support 98304 bytes per row, please file an issue if your model requires more." + ) + + self.inf_module = InferenceCoreBuilder().load() + self.act_fn = act_fn + self.kernel = self.inf_module.gated_activation + + def __call__(self, output: torch.Tensor, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None: + """ + Performs gated activation on the input tensor, writing the result to the output tensor. + + Args: + output (torch.Tensor): Output tensor. Can be of [T, C // 2] or [B, S, C // 2] + input (torch.Tensor): Input tensor. Can be of [T, C] or [B, S, C] + """ + self.kernel(output, input, bias, self.act_fn.value) diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp new file mode 100644 index 000000000000..05463c75138c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "gated_activation_kernels.h" + +#ifdef BF16_AVAILABLE +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using scalar_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#else +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#endif + +void ds_gated_activation(at::Tensor& output, + at::Tensor& input, + c10::optional& bias, + int activation_type_raw) +{ + bool ragged_input = input.dim() == 2; + + const ActivationType activation_type = static_cast(activation_type_raw); + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int cols = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_FOR_FLOAT(input.scalar_type(), [&] { + scalar_t* bias_ptr = nullptr; + if (bias.has_value()) { + TORCH_CHECK(bias.value().scalar_type() == input.scalar_type(), + "Bias type must match input type"); + TORCH_CHECK(bias.value().numel() == cols, + "Bias must have the same number of elements as the input channels"); + bias_ptr = reinterpret_cast(bias.value().data_ptr()); + } + + scalar_t* output_ptr = reinterpret_cast(output.data_ptr()); + const scalar_t* input_ptr = reinterpret_cast(input.data_ptr()); + + launch_gated_activation(output_ptr, + input_ptr, + bias_ptr, + rows, + cols, + activation_type, + c10::cuda::getCurrentCUDAStream()); + }); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.h b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.h new file mode 100644 index 000000000000..6ae01e99679a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "activation_type.h" +#include "ds_kernel_utils.h" + +template +void launch_gated_activation(T* output, + const T* vals, + const T* bias, + int rows, + int cols, + ActivationType activation_type, + cudaStream_t stream); + +void ds_gated_activation(at::Tensor& output, + at::Tensor& input, + c10::optional& bias, + int activation_type_raw); diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu new file mode 100644 index 000000000000..fc14b1831361 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "activation_type.h" +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace gated_act { + +constexpr int access_size = 16; +constexpr int threads = 1024; + +template +DS_D_INLINE float gated_act_fn(float x, float y); + +template <> +DS_D_INLINE float gated_act_fn(float x, float y) +{ + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715; + return y * x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +template <> +DS_D_INLINE float gated_act_fn(float x, float y) +{ + return y * (x > 0.0f ? x : 0.0f); +} + +template <> +DS_D_INLINE float gated_act_fn(float x, float y) +{ + return y * (x / (1.0f + expf(-x))); +} + +} // namespace gated_act + +template +__global__ void gated_activation_kernel(T* output, + const T* input, + const T* bias, + int rows, + int cols) +{ + constexpr int read_vector = gated_act::access_size / sizeof(T); + constexpr int write_vector = read_vector / 2; + + const int row = blockIdx.x; + const int col = threadIdx.x * read_vector; + + const T* input_row = input + row * cols; + T* output_row = output + row * cols / 2; + +#pragma unroll + for (int i = 0; i < loopUnroll; i++) { + T read[read_vector]; + T bias_read[read_vector]; + T store[write_vector]; + + const int read_offset = col + gated_act::threads * read_vector * i; + const int write_offset = col / 2 + gated_act::threads * write_vector * i; + + if (i != loopUnroll - 1 || read_offset < cols) { + mem_access::load_global(read, input_row + read_offset); + mem_access::load_global( + bias_read, bias + read_offset, bias != nullptr); + + for (int j = 0; j < write_vector; j++) { + float g_val = + conversion::to(read[j * 2]) + conversion::to(bias_read[j * 2]); + float a_val = conversion::to(read[j * 2 + 1]) + + conversion::to(bias_read[j * 2 + 1]); + + float act_val = gated_act::gated_act_fn(g_val, a_val); + store[j] = conversion::to(act_val); + } + + mem_access::store_global(output_row + write_offset, store); + } + } +} + +#define DISPATCH_UNROLL(unroll_val) \ + gated_activation_kernel \ + <<>>(output, input, bias, rows, cols); + +template +void launch_gated_activation_impl(T* output, + const T* input, + const T* bias, + int rows, + int cols, + cudaStream_t stream) +{ + constexpr int read_vector = gated_act::access_size / sizeof(T); + constexpr int cols_per_unroll = gated_act::threads * read_vector; + const int req_threads = (cols + read_vector - 1) / read_vector; + const int threads = std::min(req_threads, gated_act::threads); + + const dim3 grid(rows); + const dim3 block(threads); + const int unroll = (cols + cols_per_unroll - 1) / cols_per_unroll; + + if (unroll == 1) { + DISPATCH_UNROLL(1); + } else if (unroll == 2) { + DISPATCH_UNROLL(2); + } else if (unroll == 3) { + DISPATCH_UNROLL(3); + } else if (unroll == 4) { + DISPATCH_UNROLL(4); + } else if (unroll == 5) { + DISPATCH_UNROLL(5); + } else if (unroll == 6) { + DISPATCH_UNROLL(6); + } else if (unroll == 7) { + DISPATCH_UNROLL(7); + } else { + // TODO: provide a kernel with an outer loop to handle larger columns. + throw std::runtime_error( + "Called with more columns than supported, please report this bug and this limit will " + "be increased."); + } +} + +template +void launch_gated_activation(T* output, + const T* input, + const T* bias, + int rows, + int cols, + ActivationType act_type, + cudaStream_t stream) +{ + switch (act_type) { + case ActivationType::GEGLU: + launch_gated_activation_impl( + output, input, bias, rows, cols, stream); + break; + case ActivationType::ReGLU: + launch_gated_activation_impl( + output, input, bias, rows, cols, stream); + break; + case ActivationType::SiGLU: + launch_gated_activation_impl( + output, input, bias, rows, cols, stream); + break; + default: throw std::runtime_error("Unsupported activation type"); + } +} + +#define INSTANTIATE_FOR_TYPE(T) \ + template void launch_gated_activation(T * output, \ + const T* input, \ + const T* bias, \ + int rows, \ + int cols, \ + ActivationType act_type, \ + cudaStream_t stream); + +INSTANTIATE_FOR_TYPE(float) +INSTANTIATE_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/LICENSE b/deepspeed/inference/v2/kernels/cutlass_ops/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/__init__.py b/deepspeed/inference/v2/kernels/cutlass_ops/__init__.py new file mode 100644 index 000000000000..44b9adbae794 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .mixed_gemm import * +from .moe_gemm import * diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/cutlass_ops.cpp b/deepspeed/inference/v2/kernels/cutlass_ops/cutlass_ops.cpp new file mode 100644 index 000000000000..18e834f3e60a --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/cutlass_ops.cpp @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "mixed_gemm.h" +#include "moe_gemm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // mixed_gemm.h + m.def("mixed_gemm", &mixed_gemm, "Mixed-precision GEMM"); + + // moe_gemm.h + m.def("moe_gemm", &moe_gemm, "MultiGEMM for MoE (16-bit weights)"); + m.def("mixed_moe_gemm", &mixed_moe_gemm, "MultiGEMM for MoE (4-bit/8-bit weights)"); +} diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/__init__.py b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/__init__.py new file mode 100644 index 000000000000..14ccf2ce5354 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .mixed_gemm import * diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu new file mode 100644 index 000000000000..7c522203bb48 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "mixed_gemm.h" +#include "mixed_gemm_api.h" +#include "weight_variant.h" + +// Switch helpers inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#define ACT_DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using ActivationDtype = __half; \ + return __VA_ARGS__(); \ + } else { \ + using ActivationDtype = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + }() + +#define WEIGHT_VARIANT_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + constexpr WeightVariant WVariant = WeightVariant::kFP8; \ + return __VA_ARGS__(); \ + } else { \ + constexpr WeightVariant WVariant = WeightVariant::kFP4; \ + return __VA_ARGS__(); \ + } \ + }() + +void mixed_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + int num_bits, + int activation_raw) +{ + TORCH_CHECK(output.dtype() == hidden_states.dtype(), + "Output and hidden states must have the same dtype"); + TORCH_CHECK(num_bits == 4 || num_bits == 8, "Data width must be 4 or 8"); + TORCH_CHECK(output.size(0) == hidden_states.size(0), "Token dimension mismatch"); + + int32_t m = output.size(0); + int32_t k = hidden_states.size(1); + int32_t n = weight.size(1); + + TORCH_CHECK(weight.size(0) == k, "Weight dimension mismatch"); + + ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] { + WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] { + fastertransformer::CutlassFpAIntBGemmRunner runner = + *MixedGemmContext::Instance().GeMM_Runner(); + + ActivationType activation_type = (ActivationType)activation_raw; + if (!bias.has_value() && activation_type == ActivationType::IDENTITY) { + runner.gemm((ActivationDtype*)hidden_states.data_ptr(), + (const char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + (ActivationDtype*)output.data_ptr(), + m, + n, + k, + nullptr, + 0, + at::cuda::getCurrentCUDAStream()); + return; + } else { + ActivationDtype* bias_ptr = nullptr; + if (bias.has_value()) { bias_ptr = (ActivationDtype*)bias.value().data_ptr(); } + runner.gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + bias_ptr, + (ActivationDtype*)output.data_ptr(), + m, + n, + k, + activation_type, + nullptr, + 0, + at::cuda::getCurrentCUDAStream()); + return; + } + }); + }); +} diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.h b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.h new file mode 100644 index 000000000000..1fc3831e9084 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +void mixed_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + int num_bits, + int activation_raw); diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.py b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.py new file mode 100644 index 000000000000..dddb555e267a --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCutlassBuilder + +from typing import Optional + + +class MixedGEMM(DSKernelBase): + """ + CUTLASS implementation of MoE GEMM. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY] + + def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None: + + if not isinstance(fp_dtype, DtypeEnum): + fp_dtype = DtypeEnum(fp_dtype) + + if fp_dtype not in MixedGEMM.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, MixedGEMM.supported_dtypes)) + + if act_fn not in MixedGEMM.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, MixedGEMM.supported_act_fns)) + + if num_bits != 4 and num_bits != 8: + raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits)) + + inf_module = InferenceCutlassBuilder().load() + self.num_bits = num_bits + self.kernel = inf_module.moe_gemm + self.act_fn = act_fn + + def __call__(self, + output: torch.Tensor, + hidden_states: torch.Tensor, + weights: torch.Tensor, + scales: torch.Tensor, + biases: Optional[torch.Tensor] = None) -> None: + """ + Performs a MoE GEMM. Note that the stride between token inputs must be even (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2). + + Arguments: + output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons]. + hidden_states (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons]. + weights (torch.Tensor): The weights of shape [in_neurons, out_neurons]. These weights must be contiguous. + scales (torch.Tensor): The scales of shape [out_neurons]. These scales must be contiguous. + biases (torch.Tensor): The biases of shape [out_neurons]. These biases must be contiguous. + + Returns: + output + """ + self.kernel(output, hidden_states, weights, biases, self.num_bits, self.act_fn) + return output diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm_api.h b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm_api.h new file mode 100644 index 000000000000..74fc07ffc4a2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm_api.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "activation_type.h" +#include "weight_variant.h" + +namespace fastertransformer { + +template +class CutlassFpAIntBGemmRunner { +public: + void gemm(const T* A, + const char* B, + const T* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + void gemm_bias_act(const T* A, + const char* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + ActivationType activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); +}; + +} // namespace fastertransformer + +template +class MixedGemmContext { +public: + MixedGemmContext() { _runner = new fastertransformer::CutlassFpAIntBGemmRunner(); } + + virtual ~MixedGemmContext() { delete _runner; } + + static MixedGemmContext& Instance() + { + static MixedGemmContext _ctx; + return _ctx; + } + + fastertransformer::CutlassFpAIntBGemmRunner* GeMM_Runner() const { return _runner; } + + fastertransformer::CutlassFpAIntBGemmRunner* _runner; +}; diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/__init__.py b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/__init__.py new file mode 100644 index 000000000000..aff4e77bba98 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .mixed_moe_gemm import * +from .moe_gemm import * diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/mixed_moe_gemm.py b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/mixed_moe_gemm.py new file mode 100644 index 000000000000..9c55ce341532 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/mixed_moe_gemm.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCutlassBuilder + +from typing import Optional + + +class MixedMoEGEMM(DSKernelBase): + """ + CUTLASS implementation of MoE GEMM. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY] + + def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None: + + if not isinstance(fp_dtype, DtypeEnum): + fp_dtype = DtypeEnum(fp_dtype) + + if fp_dtype not in MixedMoEGEMM.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, MixedMoEGEMM.supported_dtypes)) + + if act_fn not in MixedMoEGEMM.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, MixedMoEGEMM.supported_act_fns)) + + if num_bits != 4 and num_bits != 8: + raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits)) + + inf_module = InferenceCutlassBuilder().load() + self.num_bits = num_bits + self.kernel = inf_module.moe_gemm + self.act_fn = act_fn + + def __call__(self, + ordered_output: torch.Tensor, + ordered_input: torch.Tensor, + weights: torch.Tensor, + scales: torch.Tensor, + total_rows_before_expert: torch.Tensor, + biases: Optional[torch.Tensor] = None) -> None: + """ + Performs a MoE GEMM. Note that the stride between token inputs must be even (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2). + + Arguments: + ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons]. + ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons]. + weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous. + scales (torch.Tensor): The scales of shape [n_experts, out_neurons]. These scales must be contiguous. + total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts]. + biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous. + + Returns: + ordered_output + """ + self.kernel(ordered_output, ordered_input, weights, scales, biases, total_rows_before_expert, self.num_bits, + self.act_fn) + return ordered_output diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu new file mode 100644 index 000000000000..d1cafc9fff4c --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "moe_gemm.h" +#include "moe_gemm_api.h" +#include "weight_variant.h" + +// Switch helpers inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#define HIDDEN_DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using ActivationDtype = __half; \ + constexpr WeightVariant WVariant = WeightVariant::kFP16; \ + return __VA_ARGS__(); \ + } else { \ + using ActivationDtype = __nv_bfloat16; \ + constexpr WeightVariant WVariant = WeightVariant::kBF16; \ + return __VA_ARGS__(); \ + } \ + }() + +void moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int activation_raw) +{ + TORCH_CHECK(output.dtype() == hidden_states.dtype(), + "Output and hidden states must have the same dtype"); + TORCH_CHECK(output.dtype() == weight.dtype(), "Output and weight must have the same dtype"); + + int64_t total_rows = hidden_states.size(0); + int64_t gemm_k = hidden_states.size(1); + int64_t gemm_n = weight.size(2); + int num_experts = weight.size(0); + + TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch"); + TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch"); + TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch"); + TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch"); + + HIDDEN_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] { + fastertransformer::MoeGemmRunner runner = + *MoeGemmContext::Instance().GeMM_Runner(); + + ActivationType activation_type = (ActivationType)activation_raw; + if (!bias.has_value() && activation_type == ActivationType::IDENTITY) { + runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + nullptr, + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + at::cuda::getCurrentCUDAStream()); + return; + } else { + ActivationDtype* bias_ptr = nullptr; + if (bias.has_value()) { + bias_ptr = (ActivationDtype*)bias.value().data_ptr(); + TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch"); + TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch"); + } + runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + nullptr, + bias_ptr, + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + activation_type, + at::cuda::getCurrentCUDAStream()); + return; + } + }); +} + +#define ACT_DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using ActivationDtype = __half; \ + return __VA_ARGS__(); \ + } else { \ + using ActivationDtype = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + }() + +#define WEIGHT_VARIANT_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + constexpr WeightVariant WVariant = WeightVariant::kFP8; \ + return __VA_ARGS__(); \ + } else { \ + constexpr WeightVariant WVariant = WeightVariant::kFP4; \ + return __VA_ARGS__(); \ + } \ + }() + +void mixed_moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int num_bits, + int activation_raw) +{ + TORCH_CHECK(output.dtype() == hidden_states.dtype(), + "Output and hidden states must have the same dtype"); + + int64_t total_rows = hidden_states.size(0); + int64_t gemm_k = hidden_states.size(1); + int64_t gemm_n = weight.size(2); + int num_experts = weight.size(0); + + TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch"); + TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch"); + TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch"); + TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch"); + + ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] { + WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] { + fastertransformer::MoeGemmRunner runner = + *MoeGemmContext::Instance().GeMM_Runner(); + + ActivationType activation_type = (ActivationType)activation_raw; + if (!bias.has_value() && activation_type == ActivationType::IDENTITY) { + runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + at::cuda::getCurrentCUDAStream()); + return; + } else { + ActivationDtype* bias_ptr = nullptr; + if (bias.has_value()) { + bias_ptr = (ActivationDtype*)bias.value().data_ptr(); + TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch"); + TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch"); + } + runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + bias_ptr, + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + activation_type, + at::cuda::getCurrentCUDAStream()); + return; + } + }); + }); +} diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.h b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.h new file mode 100644 index 000000000000..dfd3d4561567 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +void moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int activation_raw); + +void mixed_moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int num_bits, + int activation_raw); diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.py b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.py new file mode 100644 index 000000000000..0cc233e8d87a --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCutlassBuilder + +from typing import Optional + + +class MoEGEMM(DSKernelBase): + """ + CUTLASS implementation of MoE GEMM. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY] + + def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType) -> None: + + if not isinstance(fp_dtype, DtypeEnum): + fp_dtype = DtypeEnum(fp_dtype) + + if fp_dtype not in MoEGEMM.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, MoEGEMM.supported_dtypes)) + + if act_fn not in MoEGEMM.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, MoEGEMM.supported_act_fns)) + + inf_module = InferenceCutlassBuilder().load() + self.kernel = inf_module.moe_gemm + self.act_fn = act_fn + + def __call__(self, + ordered_output: torch.Tensor, + ordered_input: torch.Tensor, + weights: torch.Tensor, + total_rows_before_expert: torch.Tensor, + biases: Optional[torch.Tensor] = None) -> None: + """ + Performs a MoE GEMM. Note that the stride between token inputs must be even (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2). + + Arguments: + ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons]. + ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons]. + weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous. + total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts]. + biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous. + + Returns: + ordered_output + """ + self.kernel(ordered_output, ordered_input, weights, biases, total_rows_before_expert, self.act_fn) + return ordered_output diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm_api.h b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm_api.h new file mode 100644 index 000000000000..7ad92070b35f --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm_api.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "activation_type.h" +#include "weight_variant.h" + +namespace fastertransformer { + +template +class MoeGemmRunner { +public: + MoeGemmRunner(); + + void moe_gemm_bias_act(const T* A, + const char* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + ActivationType activation_type, + cudaStream_t stream); + + void moe_gemm(const T* A, + const char* B, + const T* weight_scales, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + cudaStream_t stream); + +private: + int sm_; + int multi_processor_count_; +}; + +} // namespace fastertransformer + +template +class MoeGemmContext { +public: + MoeGemmContext() { _runner = new fastertransformer::MoeGemmRunner(); } + + virtual ~MoeGemmContext() { delete _runner; } + + static MoeGemmContext& Instance() + { + static MoeGemmContext _ctx; + return _ctx; + } + + fastertransformer::MoeGemmRunner* GeMM_Runner() const { return _runner; } + + fastertransformer::MoeGemmRunner* _runner; +}; diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/shared_resources/weight_variant.h b/deepspeed/inference/v2/kernels/cutlass_ops/shared_resources/weight_variant.h new file mode 100644 index 000000000000..4d17c799f726 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/shared_resources/weight_variant.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// Data structure that allows us to abstract internal CUTLASS datatypes/mappings +// to the DeepSpeed-Kernels repo. + +#pragma once + +enum WeightVariant { kFP16, kBF16, kFP8, kFP4 }; diff --git a/deepspeed/inference/v2/kernels/ds_kernel.py b/deepspeed/inference/v2/kernels/ds_kernel.py new file mode 100644 index 000000000000..8dbfa1de86a6 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ds_kernel.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod + + +class DSKernelBase(ABC): + + @abstractmethod + def __init__(self, *args, **kwargs): + """ + If necessary trigger compilation and warmup + Autotuning of the kernel would happen at this stage to + eliminate any potential hangs that might occur mid-deployment + Validate that the desired run configuration is compatible. + + It is not necessary to call super on this method. + """ + raise NotImplementedError() + + @abstractmethod + def __call__(self, *args, **kwargs): + """ + However the kernel needs to be called, it can be called here. Auto-tuning + should never be performed here. + + All inputs/outputs should be passed as arguments to this function. No allocations + should be performed here. + """ + raise NotImplementedError() diff --git a/deepspeed/inference/v2/kernels/includes/activation_type.h b/deepspeed/inference/v2/kernels/includes/activation_type.h new file mode 100644 index 000000000000..a44921d5d650 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/activation_type.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +enum ActivationType { + GELU = 0, + RELU = 1, + SILU = 2, + GEGLU = 3, + ReGLU = 4, + SiGLU = 5, + IDENTITY = 6, + InvalidType = -1 +}; diff --git a/deepspeed/inference/v2/kernels/includes/conversion_utils.h b/deepspeed/inference/v2/kernels/includes/conversion_utils.h new file mode 100644 index 000000000000..d6d8f11e0854 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/conversion_utils.h @@ -0,0 +1,708 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +namespace conversion { + +// Basic primitive for constructing conversions +template +DS_D_INLINE TO to(FROM val) +{ + return to(val); +} + +// Specializations + +/********************* Identity Conversions *********************/ +/* +Identity conversions are useful in templated functions where we might have +a fixed destination type. For example, I might have a kernel that accepts +__half, __nv_bfloat16, and float but always want to do the core computation +at floating point: + +T mem_value = input[idx]; +float compute_value = conversion::to(mem_value); + +In practice, we should be able to elide the second template parameter: +float compute_val = conversion::to(mem_value); + +In this case, we need an implementation to handle the T = float case + +NOTE: The type inferencing system appears to be unable to handle inferring the first +template parameter, even in the trivial case. +*/ + +// Floating point types +template <> +DS_D_INLINE double to(double val) +{ + return val; +} +template <> +DS_D_INLINE float to(float val) +{ + return val; +} +template <> +DS_D_INLINE __half to(__half val) +{ + return val; +} +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) +{ + return val; +} +#endif + +// Integer types +template <> +DS_D_INLINE int8_t to(int8_t val) +{ + return val; +} +template <> +DS_D_INLINE uint8_t to(uint8_t val) +{ + return val; +} +template <> +DS_D_INLINE int16_t to(int16_t val) +{ + return val; +} +template <> +DS_D_INLINE uint16_t to(uint16_t val) +{ + return val; +} +template <> +DS_D_INLINE int32_t to(int32_t val) +{ + return val; +} +template <> +DS_D_INLINE uint32_t to(uint32_t val) +{ + return val; +} +template <> +DS_D_INLINE int64_t to(int64_t val) +{ + return val; +} +template <> +DS_D_INLINE uint64_t to(uint64_t val) +{ + return val; +} + +// TODO: evaluate if we want bools + +/********************* To Double Conversions *********************/ + +// * to double variants + +// Would normally like to not use C cast, but this is an important enough conversion +// to keep +template <> +DS_D_INLINE double to(float val) +{ +#ifdef PTX_AVAILABLE + double ret_val; + asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val)); + return ret_val; +#else + return double(val); +#endif +} +// Note: there is a CVT instruction for __half -> double, but there's no inline interface +// for passing a single half value +template <> +DS_D_INLINE double to(__half val) +{ + return to(__half2float(val)); +} +template <> +DS_D_INLINE double to(int64_t val) +{ + return __ll2double_rn(val); +} +template <> +DS_D_INLINE double to(int32_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int16_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int8_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(uint64_t val) +{ + return __ull2double_rn(val); +} +template <> +DS_D_INLINE double to(uint32_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint16_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint8_t val) +{ + return __uint2double_rn(val); +} + +// Same applies here +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE double to(__nv_bfloat16 val) +{ + return to(__bfloat162float(val)); +} +#endif + +/********************* To Float Conversions *********************/ + +template <> +DS_D_INLINE float to(double val) +{ + return __double2float_rn(val); +} +template <> +DS_D_INLINE float to(__half val) +{ + return __half2float(val); +} +template <> +DS_D_INLINE float to(int64_t val) +{ + return __ll2float_rn(val); +} +template <> +DS_D_INLINE float to(int32_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int16_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int8_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(uint64_t val) +{ + return __ull2float_rn(val); +} +template <> +DS_D_INLINE float to(uint32_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint16_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint8_t val) +{ + return __uint2float_rn(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float to(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} +#endif + +/********************* To Float2 Conversions *********************/ +template <> +DS_D_INLINE float2 to(__half2 val) +{ + return __half22float2(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float2 to(__nv_bfloat162 val) +{ + return __bfloat1622float2(val); +} +#endif + +/********************* To Half Conversions *********************/ +template <> +DS_D_INLINE __half to(double val) +{ +#ifdef __HIP_PLATFORM_AMD__ + float val_f = __double2float_rn(val); + return __float2half(val_f); +#else + return __double2half(val); +#endif +} +template <> +DS_D_INLINE __half to(float val) +{ + return __float2half(val); +} +template <> +DS_D_INLINE __half to(int64_t val) +{ + return __ll2half_rn(val); +} +template <> +DS_D_INLINE __half to(int32_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(int16_t val) +{ + return __short2half_rn(val); +} +template <> +DS_D_INLINE __half to(int8_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint64_t val) +{ + return __ull2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint32_t val) +{ + return __uint2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint16_t val) +{ + return __ushort2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint8_t val) +{ + return __uint2half_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half to(__nv_bfloat16 val) +{ + return to<__half>(to(val)); +} +#endif + +/********************* To Half2 Conversions *********************/ +template <> +DS_D_INLINE __half2 to(float2 val) +{ + return __float22half2_rn(val); +} +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half2 to(__nv_bfloat162 val) +{ + return to<__half2>(to(val)); +} +#endif + +/********************* To BF16 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(double val) +{ + return __double2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(float val) +{ + return __float2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int64_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ll2double_rn(val)); +#else + return __ll2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(int32_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else + return __int2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(int16_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else + return __short2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(int8_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else + return __int2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint64_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ull2double_rn(val)); +#else + return __ull2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint32_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else + return __uint2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint16_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else + return __ushort2bfloat16_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint8_t val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else + return __uint2bfloat16_rn(val); +#endif +} +#endif + +/********************* To BF162 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 to(float2 val) +{ + return __float22bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __bfloat162bfloat162(__float2bfloat16(val)); +#else + return __float2bfloat162_rn(val); +#endif +} +template <> +DS_D_INLINE __nv_bfloat162 to(__half2 val) +{ + return to<__nv_bfloat162>(to(val)); +} +#endif + +/********************* To INT64_T Conversions *********************/ +template <> +DS_D_INLINE int64_t to(double val) +{ + return __double2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(float val) +{ + return __float2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(__half val) +{ + return __half2ll_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int64_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2ll_rn(__bfloat162float(val)); +#else + return __bfloat162ll_rn(val); +#endif +} +#endif + +/********************* To INT32_T Conversions *********************/ +template <> +DS_D_INLINE int32_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int32_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else + return __bfloat162int_rn(val); +#endif +} +#endif + +/********************* To INT16_T Conversions *********************/ +template <> +DS_D_INLINE int16_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int16_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else + return __bfloat162int_rn(val); +#endif +} +#endif + +/********************* To INT8_T Conversions *********************/ +template <> +DS_D_INLINE int8_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int8_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else + return __bfloat162int_rn(val); +#endif +} +#endif + +/********************* To UINT64_T Conversions *********************/ +template <> +DS_D_INLINE uint64_t to(double val) +{ + return __double2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(float val) +{ + return __float2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(__half val) +{ + return __half2ull_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint64_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2ull_rn(__bfloat162float(val)); +#else + return __bfloat162ull_rn(val); +#endif +} +#endif + +/********************* To UINT32_T Conversions *********************/ +template <> +DS_D_INLINE uint32_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint32_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else + return __bfloat162uint_rn(val); +#endif +} +#endif + +/********************* To UINT16_T Conversions *********************/ +template <> +DS_D_INLINE uint16_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint16_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else + return __bfloat162uint_rn(val); +#endif +} +#endif + +/********************* To UINT8_T Conversions *********************/ +template <> +DS_D_INLINE uint8_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint8_t to(__nv_bfloat16 val) +{ +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else + return __bfloat162uint_rn(val); +#endif +} +#endif + +} // namespace conversion diff --git a/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h new file mode 100644 index 000000000000..f8b16ee6a315 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Centralized header file for preprocessor macros and constants +used throughout the codebase. +*/ + +#pragma once + +#include +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +#define DS_HD_INLINE __host__ __device__ __forceinline__ +#define DS_D_INLINE __device__ __forceinline__ + +#ifdef __HIP_PLATFORM_AMD__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; +#define HALF_PRECISION_AVAILABLE = 1 +#include +#include + +#else // !__HIP_PLATFORM_AMD__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 32; + +#if __CUDA_ARCH__ >= 530 +#define HALF_PRECISION_AVAILABLE = 1 +#define PTX_AVAILABLE +#endif // __CUDA_ARCH__ >= 530 + +#if __CUDA_ARCH__ >= 800 +#define ASYNC_COPY_AVAILABLE +#endif // __CUDA_ARCH__ >= 800 + +#include +#include + +#endif //__HIP_PLATFORM_AMD__ + +inline int next_pow2(const int val) +{ + int rounded_val = val - 1; + rounded_val |= rounded_val >> 1; + rounded_val |= rounded_val >> 2; + rounded_val |= rounded_val >> 4; + rounded_val |= rounded_val >> 8; + return rounded_val + 1; +} diff --git a/deepspeed/inference/v2/kernels/includes/memory_access_utils.h b/deepspeed/inference/v2/kernels/includes/memory_access_utils.h new file mode 100644 index 000000000000..6789714d27c7 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/memory_access_utils.h @@ -0,0 +1,1115 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "ds_kernel_utils.h" + +/////////////////////////////// Memory Access Utils /////////////////////////////// +namespace mem_access { + +enum class LoadPolicy { + CacheAll, // Cache at all levels + CacheGlobal, // Cache at L2 only + CacheStreaming // Cache with evict first policy +}; + +enum class StorePolicy { + Writeback, // Cache in L1, write-back on eviction + CacheGlobal, // Bypass L1, write-back on eviction + CacheStreaming // Allocate cache line with evict first policy +}; + +template +__device__ __forceinline__ void load_global(void* dst, const void* src); + +template +__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access); + +// Shared accesses have no cache policy +template +__device__ __forceinline__ void load_shared(void* dst, const void* src); + +template +__device__ __forceinline__ void load_shared(void* dst, const void* src, bool do_access); + +template +__device__ __forceinline__ void store_global(void* dst, const void* src); + +// Shared accesses have no cache policy +template +__device__ __forceinline__ void store_shared(void* dst, const void* src); + +#ifdef ASYNC_COPY_AVAILABLE +template +__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl); + +template +__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate); + +template +__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate); + +__device__ __forceinline__ void memcpy_async_fence(); + +template +__device__ __forceinline__ void memcpy_async_wait(); + +template +__device__ __forceinline__ void tail_complete_wait(int remaining_stages); +#endif + +// Util for tracking pipeline buffers +// TODO: Evaluate whether this should also be guarded by ASYNC_COPY_AVAILABLE +template +class BufferTracker { +public: + int current_state; + + __device__ __forceinline__ BufferTracker() : current_state(0) {} + + __device__ __forceinline__ int get() + { + int return_val = current_state++; + current_state = (current_state == max ? 0 : current_state); + return return_val; + } +}; + +__device__ __forceinline__ uint32_t lane_id() +{ +#ifdef PTX_AVAILABLE + unsigned int lane_id; + asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +#else + return threadIdx.x & (warpSize - 1); // Portable +#endif +} + +/////////// Load Global /////////// +template <> +__device__ __forceinline__ void load_global<16>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.cg.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.cs.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2>(void* dst, const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2>(void* dst, const void* src, bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.cg.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.cs.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +/////////// Load Shared /////////// +namespace internal { + +#ifdef PTX_AVAILABLE +__device__ __forceinline__ unsigned convert_to_shared(const void* ptr) +{ +#if __CUDACC_VER_MAJOR__ >= 11 + // In CUDA 11 we have a builtin intrinsic + return __cvta_generic_to_shared(ptr); +#else + unsigned ret_val; + asm volatile( + "{\n" + "\t.reg .u64 p1;\n" + "\tcvta.to.shared.u64 p1, %1\n" + "\tcvt.u32.u64 %0, p1;\n" + "}\n" + : "=r"(ret_val) + : "l"(ptr)); + return ret_val; +#endif +} +#endif + +} // namespace internal + +template <> +__device__ __forceinline__ void load_shared<16>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "r"(src_shr)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<16>(void* dst, const void* src, bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "r"(src_shr), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_shared<8>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "r"(src_shr)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<8>(void* dst, const void* src, bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.shared.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "r"(src_shr), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_shared<4>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.u32 {%0}, [%1];\n" : "=r"(*data) : "r"(src_shr)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<4>(void* dst, const void* src, bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.shared.u32 %0, [%1];\n" + "}\n" + : "=r"(data[0]) + : "r"(src_shr), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +/////////// Store Global /////////// + +template <> +__device__ __forceinline__ void store_global<16>(void* dst, const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<16, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<16, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8>(void* dst, const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4>(void* dst, const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +/////////// Store Shared /////////// + +template <> +__device__ __forceinline__ void store_shared<16>(void* dst, const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(dst_int), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_shared<8>(void* dst, const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : + : "r"(dst_int), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_shared<4>(void* dst, const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(dst_int), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +/////////// Asynchronous Memory Copy /////////// + +#ifdef ASYNC_COPY_AVAILABLE +template +__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" + : + : "r"(shr_int), "l"(gbl), "n"(AccessSize)); +} + +template +__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" + : + : "r"((int)predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize)); +} + +template +__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (predicate ? AccessSize : 0); + + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" + : + : "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy)); +} + +template +__device__ __forceinline__ void memcpy_async_zero_nop(void* shr, + const void* gbl, + bool zero_predicate, + bool nop_predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (zero_predicate ? AccessSize : 0); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n" + "}\n" + : + : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy)); +} + +// Cache global variants. Separate interface to require deliberate use of them. +__device__ __forceinline__ void memcpy_async_cg(void* shr, const void* gbl) +{ + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" : : "r"(shr_int), "l"(gbl)); +} + +__device__ __forceinline__ void memcpy_async_nop_cg(void* shr, const void* gbl, bool predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], 16;\n" + "}\n" + : + : "r"((int)predicate), "r"(shr_int), "l"(gbl)); +} + +__device__ __forceinline__ void memcpy_async_zero_cg(void* shr, const void* gbl, bool predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (predicate ? 16 : 0); + + asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" + : + : "r"(shr_int), "l"(gbl), "r"(bytes_to_copy)); +} + +__device__ __forceinline__ void memcpy_async_zero_nop_cg(void* shr, + const void* gbl, + bool zero_predicate, + bool nop_predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (zero_predicate ? 16 : 0); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], 16, %3;\n" + "}\n" + : + : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "r"(bytes_to_copy)); +} + +__device__ __forceinline__ void memcpy_async_fence() { asm volatile("cp.async.commit_group;\n"); } + +template +__device__ __forceinline__ void memcpy_async_wait() +{ + static_assert(stages <= 8); + + asm volatile("cp.async.wait_group %0;\n" : : "n"(stages)); +} + +// TODO: The tail complete should be a known compile time artifact, should try and induce this +// without all of the branches from the call-site. This is a hacky solution. +template <> +__device__ __forceinline__ void tail_complete_wait<1>(int remaining_stages) +{ + if (remaining_stages == 0) memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<2>(int remaining_stages) +{ + if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<3>(int remaining_stages) +{ + if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<4>(int remaining_stages) +{ + if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<5>(int remaining_stages) +{ + if (remaining_stages == 4) + memcpy_async_wait<4>(); + else if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<6>(int remaining_stages) +{ + if (remaining_stages == 5) + memcpy_async_wait<5>(); + else if (remaining_stages == 4) + memcpy_async_wait<4>(); + else if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} +#endif + +} // namespace mem_access diff --git a/deepspeed/inference/v2/kernels/includes/reduction_utils.h b/deepspeed/inference/v2/kernels/includes/reduction_utils.h new file mode 100644 index 000000000000..eb8efab77ac1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/reduction_utils.h @@ -0,0 +1,778 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace reduce { + +enum class ROpType { + // Addition + Add, + + // Maximum reduction + Max, + + // Minimum reduction + Min, +}; + +constexpr int max_threads = 1024; +constexpr int max_warps = max_threads / hw_warp_size; + +/* +High level API. The API takes in a set of operations and variables +and performs that reduction operation on that variable. The reductions +of each of the arguments are completely independent of each other ( +i.e., the val1-op1 combination has no impact on val2-op2). + +Example usage: +``` cpp +float max_val; +float min_val; +reduce::block(tb, warp, max_val, min_val); +``` + +TODO(cmikeh2): In theory, we might be able to do this sequentially with +device functions and rely on the assembler correctly behaving. My initial +instinct is this won't work, but if it does it would reduce implementation +cost significantly. + +TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic +currently supports this (more incidentally than anything else). It is not +uncommon in something like softmax or a fused attention kernel to map multiple +reductions to a thread block, but each reduction itself is only scoped +to part of the threads (i.e block size = 512, 128 threads per reduction). +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +The partitioned block is a special case of the above where in the warps of a threadblock are +partitioned into separate independent reductions. For example, I might have an 8 warp thread block +in which each pair of warps is processing an independent piece of data. I would then reduce that +data with the something like the following: +``` cpp +float max_val; +reduce::partitioned_block(tb, warp, max_val); +``` +After which, each pair of warps would have coherent data with each other. Note, this API will not +provide correct results if the number of warps per partition is not a power of 2. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +Single element reduction primitives. Used inside serial collection +loops. + +Example usage: +using rop = reduce::OpType; +float min = init(); +for (int i = 0; i < 4; i++) { + min = reduce::element(min, data[i]); +} +*/ + +template +DS_D_INLINE T element(const T lhs, const T rhs); + +template +DS_D_INLINE T init(); + +/********************** Internal reduction APIs **********************/ + +/* +Single element "reductions". TODO(cmikeh2): this sort of "op" concept +should be refactored into its own implementation at some point. This interface +may be easily expanded for new types/operations, but the typical reductions +we need are covered with min/max/add on float. + +NOTE: there is no mean reduction because that relies on knowledge of how +many values were already reduced into each scalar. Implementing this on top +of reduce should be straightforward (can just wrap the sum reduction) and +would be a good extension of the header. +*/ + +DS_D_INLINE int _warp_rank() +{ + const int thread_rank = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + return thread_rank / hw_warp_size; +} + +/* Float element reduce implementations */ +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fmaxf(lhs, rhs); +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fminf(lhs, rhs); +} + +/* __half element reduce implementation */ +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmin(lhs, rhs); +#else + return (lhs < rhs) ? lhs : rhs; +#endif +} + +/* __half2 element reduce implementation */ +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmin2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +/* +Reduction initialization primitives +*/ +template <> +DS_D_INLINE float init() +{ + return 0.0f; +} + +template <> +DS_D_INLINE float init() +{ + // Positive infinity + return INFINITY; +} + +template <> +DS_D_INLINE float init() +{ + // Negative infinity + return -INFINITY; +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw zero = {0x0000}; + return __half(zero); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw inf = {0x7C00}; + return __half(inf); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw neg_inf = {0xFC00}; + return __half(neg_inf); +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x0000, 0x0000}}; +#else + constexpr __half2_raw zero = {0x0000, 0x0000}; + return __half2(zero); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x7C00, 0x7C00}}; +#else + constexpr __half2_raw inf = {0x7C00, 0x7C00}; + return __half2(inf); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0xFC00, 0xFC00}}; +#else + constexpr __half2_raw neg_inf = {0xFC00, 0xFC00}; + return __half2(neg_inf); +#endif +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x7FFFFFFF; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x80000000; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0xFFFFFFFF; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x8000000000000000; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0xFFFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); + data[3] = init(); +} + +/* +Warp reduction primitives + +`reduction_width` is an unsafe template parameter, that is that +when using `reduction_width` < hw_warp_size the warp is partitioned +into `hw_warp_size` / `reduction_width` groups of partial sums. + +If someone can figure out how to use variadic templates in a reasonable way +here (fold is C++17 only and I don't think helps and recursion feels like +huge overkill that harms readability) that would be wonderful. +*/ + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + data[3] = element(data[3], warp.shfl_xor(data[3], i)); + } +} + +/* +Implementation for primary block reduction that serves both `block` and +`partitioned_block`. + +Total warps refers to the reduction width of the reduction, not +the number of warps in the block (which may exceed that +if the block is partitioned or if we do a conservative bound at +compile time). +*/ +template +DS_D_INLINE void _block(cg::thread_block& tb, + cg::thread_block_tile& warp_arg, + T* data) +{ + constexpr int elems = sizeof...(Ops); + constexpr int bytes = sizeof(T); + // Unused when `partition_size == 1` or total_warps == 1 + __shared__ T reduce_buffer[max_warps * elems]; + +#ifdef __HIP_PLATFORM_AMD__ + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + const int running_warps = total_threads / hw_warp_size; +#else + const int running_warps = warp_arg.meta_group_size(); +#endif + + // Always perform warp-scope reduction + _warp(warp_arg, data); + + // If max_warps == 1 let's skip the runtime check + if (total_warps != 1) { + if (warp_arg.thread_rank() == 0) { +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::store_shared(reduce_buffer + elems * _warp_rank() + i, data + i); + } + } + + // Synchronization inside block-uniform conditional is safe + tb.sync(); + + if (_warp_rank() == 0) { + if (warp_arg.thread_rank() < running_warps) { +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::load_shared( + data + i, reduce_buffer + elems * warp_arg.thread_rank() + i); + } + } else { + init(data); + } + + _warp(warp_arg, data); + +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::store_shared(reduce_buffer + elems * warp_arg.thread_rank() + i, + data + i); + } + } + + // Synchronization inside block-uniform conditional is safe + tb.sync(); + +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::load_shared(data + i, reduce_buffer + _warp_rank() * elems + i); + } + } +} + +/* +Main API implementations. For the most part, they just convert the individual +variables into arrays, which makes working with them easier with a single +implementation. In theory, we could use the `_block` implementation as another +option, but the nature of using a pointer is a little less safe and this allows +us to obfuscate the details of the partitioned implementation. +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val) +{ + _block(tb, warp, &val); +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2) +{ + float data[2] = {val1, val2}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3) +{ + float data[3] = {val1, val2, val3}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4) +{ + float data[4] = {val1, val2, val3, val4}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; + val4 = data[3]; +} + +/* +Note: for the partitioned blocks, the implementation does not support non-power of 2 blocks in order +to shorten block scale reduction length. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val) +{ + if (num_threads <= hw_warp_size) { + _warp(warp, &val); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, &val); + } +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2) +{ + float data[2] = {val1, val2}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3) +{ + float data[3] = {val1, val2, val3}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4) +{ + float data[4] = {val1, val2, val3, val4}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; + val4 = data[3]; +} + +/* +Arg-reduce is a specialization of the above. We only support this with a single reduction +parameter. This only works for max/min reductions. +*/ + +__align__(8) struct IdxReduceResult { + /* + NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits + and the val is the most significant. Changing the order of this declaration + will break the code. + */ + int idx; + float val; +}; + +template +DS_D_INLINE IdxReduceResult +idx_reduce(cg::thread_block& tb, cg::thread_block_tile& warp, float val, int idx) +{ + IdxReduceResult res = {idx, val}; + + // Clear out the nan. This shouldn't be an issue for our initial applications + if (isnan(val)) res.val = init(); + + // Can do float compares as integers. By packing the index into the lower bits + // we can just do a single int64 rather than a branch, compare, and select. + // One side benefit of this is that it is by nature a stable algorithm and + // will always bias ties to the higher index. + int64_t* res_as_int = reinterpret_cast(&res); + + // The way floating point compare works is normally to perform a sign comparison + // and if they match, then do a comparison of the rest of the bits as unsigned + // integers. Since we are bundling these, that means for negative values we need + // to reverse the sort order, which we can do with an XOR. + if (val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + _block(tb, warp, res_as_int); + + // Sign bit is preserved, so we can check if we need to invert the mantissa back + if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + return res; +} + +} // namespace reduce diff --git a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py new file mode 100644 index 000000000000..38a4ebd6fba3 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .atom_builder import * +from .blocked_flash import * +from .embed import * +from .linear_blocked_kv_rotary import * +from .logits_gather import * +from .moe_gather import * +from .moe_scatter import * +from .top_k_gating import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/__init__.py new file mode 100644 index 000000000000..c79201cdf165 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .atom_builder import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp new file mode 100644 index 000000000000..7ad4dc5faa20 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "atom_builder.h" +#include "attention_atom.h" +#include "ragged_dtypes.h" + +int32_t build_atoms(torch::Tensor& atoms_ten, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& kv_ptrs, + const int32_t q_block_size, + const int32_t kv_block_size) +{ + const RaggedBatchDescriptor* batch_desc = + reinterpret_cast(batch_metadata.data_ptr()); + + const InflightSeqDescriptor* seq_desc = + reinterpret_cast(seq_metadata.data_ptr()); + + int32_t** kv_ptr_list = reinterpret_cast(kv_ptrs.data_ptr()); + + AttentionAtom* atoms = reinterpret_cast(atoms_ten.data_ptr()); + + int32_t n_atoms = 0; + for (int i = 0; i < batch_desc->n_sequences; i++) { + const int seq_atoms = (seq_desc[i].n_tokens + q_block_size - 1) / q_block_size; + int32_t cur_start_idx = seq_desc[i].start_idx; + int32_t global_start_idx = seq_desc[i].seen_tokens; + int32_t remaining_toks = seq_desc[i].n_tokens; + + for (int j = 0; j < seq_atoms; j++) { + atoms[n_atoms].block_idx_list = kv_ptr_list[i]; + atoms[n_atoms].q_start_idx = cur_start_idx; + atoms[n_atoms].q_len = std::min(remaining_toks, q_block_size); + atoms[n_atoms].global_q_idx = global_start_idx; + + const int32_t end_toks = global_start_idx + atoms[n_atoms].q_len; + // TODO(cmikeh2): This logic needs to be changed for sparse implementations + atoms[n_atoms].kv_blocks = (end_toks + kv_block_size - 1) / kv_block_size; + atoms[n_atoms].total_extent = end_toks; + + cur_start_idx += atoms[n_atoms].q_len; + global_start_idx += atoms[n_atoms].q_len; + remaining_toks -= atoms[n_atoms].q_len; + n_atoms++; + } + } + + return n_atoms; +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.h b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.h new file mode 100644 index 000000000000..a3342d0e6695 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +/* +Construct the attention atoms given the ragged metadata for the current batch. +This could largely be done at the Python level, but since we pack the KV ptr +alongside the int32_t metadata, it gets very ugly to handle the mixed-width +data structures (since we're packing them in a single tensor). +*/ +int32_t build_atoms(torch::Tensor& atoms_ten, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& kv_ptrs, + const int32_t q_block_size, + const int32_t kv_block_size); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.py b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.py new file mode 100644 index 000000000000..3355ca76c6a4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch + +from ... import DSKernelBase +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....ragged import RaggedBatchWrapper + + +class AtomBuilder(DSKernelBase): + """ + C++ implementation to populate the attention atoms for the blocked attention + kernel. + """ + + def __init__(self) -> None: + """ + Triggers compilation of the C++ implementation. + """ + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.build_atoms + + def __call__(self, atoms: torch.Tensor, ragged_batch: RaggedBatchWrapper, q_block_size: int, + kv_block_size: int) -> Tuple[torch.Tensor, int]: + """ + Populates the attention atoms for the blocked attention kernel. + + Args: + atoms (torch.Tensor): Pre-allocated int32 tensor of shape [max_atoms, 8] + ragged_batch (torch.Tensor): Wrapper for the ragged batch. + q_block_size (int): The block size for the queries (as determined by the + attention implementation) + kv_block_size (int): The block size for the keys/values (as determined by the + attention implementation) + + Returns: + + """ + if atoms.device != torch.device("cpu"): + raise RuntimeError("AtomBuilder must be called on tensors") + + n_atoms = self.kernel(atoms, ragged_batch.batch_metadata_buffer(on_device=False), + ragged_batch.inflight_seq_descriptors(on_device=False), + ragged_batch.kv_ptrs(on_device=False), q_block_size, kv_block_size) + return atoms, n_atoms diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/__init__.py new file mode 100644 index 000000000000..87b2b3d68777 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blocked_flash import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/attention_atom.h b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/attention_atom.h new file mode 100644 index 000000000000..ed8eb9e19b3d --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/attention_atom.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "cuda.h" + +struct AttentionAtom { + /* + The attention atom describes the workload of a particular query. The attention + kernel will execute each ``AttentionAtom`` for each head of the model. + */ + + // Pointer to a list of KV block indices. + int32_t* block_idx_list; + + // Index of first token in the ragged batch associated with this atom. + int32_t q_start_idx; + + // Number of tokens in the ragged batch associated with this atom. + int32_t q_len; + + // Number of key/value blocks associated with this atom. All but the last are + // assumed to be fully dense. + int32_t kv_blocks; + + // Number of tokens in the last key/value block. + int32_t total_extent; + + // Global index of the first token in the atom. For example, in a prompt continuation + // in which we have already processed 768 tokens, this would be 768. + int32_t global_q_idx; + + // Unused + int32_t unused; +}; diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp new file mode 100644 index 000000000000..6cb60ad00f9b --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +#include "blocked_flash.h" +#include "flash.h" + +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") + +void flash_attn_by_atoms(at::Tensor& out, + at::Tensor& q, + at::Tensor& k, + at::Tensor& v, + at::Tensor& attention_atoms, + const float softmax_scale, + const bool is_causal) +{ + auto dprops = at::cuda::getCurrentDeviceProperties(); + + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const int total_q = q.size(0); + const int head_size = k.size(-1); + const int num_heads_kv = k.size(-2); + const int num_heads_q = q.size(-1) / head_size; + + TORCH_CHECK(head_size <= 256, "head_size must be <= 256"); + TORCH_CHECK(head_size % 8 == 0, "head_size must be divisible by 8"); + TORCH_CHECK(num_heads_q % num_heads_kv == 0, "num_heads_q must be divisible by num_heads_kv"); + + Flash_fwd_params params; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.o_ptr = out.data_ptr(); + params.atoms = reinterpret_cast(attention_atoms.data_ptr()); + + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(0); + params.k_row_stride = k.stride(1); + params.v_row_stride = v.stride(1); + params.o_row_stride = out.stride(0); + + // Assume heads are contiguous. + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_head_stride = head_size; + + // Head params + params.h = num_heads_q; + params.h_k = num_heads_kv; + params.h_h_k_ratio = num_heads_q / num_heads_kv; + params.d = head_size; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.d_rounded = round_multiple(head_size, 32); + params.num_atoms = attention_atoms.size(0); + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + params.is_causal = is_causal; + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.h b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.h new file mode 100644 index 000000000000..68037b425113 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +void flash_attn_by_atoms(at::Tensor& out, + at::Tensor& q, + at::Tensor& k, + at::Tensor& v, + at::Tensor& attention_atoms, + const float softmax_scale, + const bool is_causal); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.py b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.py new file mode 100644 index 000000000000..54d465698b4e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.accelerator import get_accelerator +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder + +from ... import DSKernelBase + + +def get_q_block_size(head_size: int) -> int: + """ + Returns the query block size required by the kernel given a head size. + """ + cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda + + if cc_major < 8: + raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0") + + if head_size <= 64: + return 128 + elif head_size <= 160: + if cc_minor != 0: + return 64 + else: + return 128 + elif head_size == 192: + return 128 + elif head_size == 224: + if cc_minor != 0: + return 64 + else: + return 128 + else: + if cc_major == 8 and cc_minor == 0: + return 128 + else: + return 64 + + +def get_kv_block_size(head_size: int) -> int: + """ + Return preferred granulatity for blocked KV-cache implementation. + """ + cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda + + if cc_major < 8: + raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0") + + if head_size <= 64: + return 128 + elif head_size != 160 or cc_minor != 0: + return 64 + else: + return 32 + + +class BlockedFlashAttn(DSKernelBase): + """ + Modified implementation of flash-attn-2 tuned for inference on blocked KV-cache and wider + range of input sequence lengths. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + + def __init__(self, head_size: int, dtype: DtypeEnum) -> None: + """ + Triggers any compilation of the kernels. + """ + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in BlockedFlashAttn.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported data types are {}".format( + dtype, BlockedFlashAttn.supported_dtypes)) + + # For testing, need to revert to 32 + if head_size % 16 != 0: + raise ValueError("Head size must be divisible by 32 (configured with {})".format(head_size)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.flash_attn_by_atoms + + def __call__(self, out: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, atoms: torch.Tensor, + softmax_scale: float) -> torch.Tensor: + """ + Flash attention implementation atop a blocked KV-cache. Atoms should be pre-populated. + See attention_atom.h for further details on the structure of the information. + + Arguments: + out (torch.Tensor): Output tensor of shape [tokens, hidden_size] + q (torch.Tensor): Query tensor of shape [tokens, hidden_size] + k (torch.Tensor): Key cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension. + v (torch.Tensor): Value cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension. + atoms (torch.Tensor): Atom information tensor of shape [num_atoms, 8] and type int32. + Not all data is readable in this format. See attention_atom.h for further details. + softmax_scale (float): Softmax scale factor. + + Returns: + out (torch.Tensor): Output tensor of shape [tokens, hidden_size] + """ + self.kernel(out, q, k, v, atoms, softmax_scale, True) + return out diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/flash.h b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/flash.h new file mode 100644 index 000000000000..b4a53e6d7f52 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/flash.h @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/****************************************************************************** +Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "attention_atom.h" + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + + // The attention metadata + AttentionAtom* __restrict__ atoms; + + // Total attention atoms + int num_atoms; + + // The stride between rows of O. + index_t o_row_stride; + index_t o_head_stride; + + // The dimensions + int d, d_rounded; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/embed/__init__.py new file mode 100644 index 000000000000..d6b8e6047d74 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .embed import RaggedEmbeddingKernel diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cpp b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cpp new file mode 100644 index 000000000000..04b72bf948db --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "embed.h" +#include "ragged_kernel_helpers.h" + +#ifdef BF16_AVAILABLE +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using float_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using float_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using float_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dispatch type"); \ + } \ + }() +#else +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using float_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using float_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dispatch type"); \ + } \ + }() +#endif + +#define DISPATCH_FOR_INT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kInt32) { \ + using int_t = int32_t; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kInt64) { \ + using int_t = int64_t; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dispatch type"); \ + } \ + }() + +/* +Embeddings kernel aware of ragged batch structure. +*/ +void ragged_embed(torch::Tensor& embedded_tokens, + torch::Tensor& input_ids, + torch::Tensor& embedding_weight, + c10::optional& position_embedding_weight, + int32_t pos_embed_offset, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + // We don't care about KV cache here, so just hardcoding 0s for block_size/num_blocks + BatchWrapperCPP batch_wrapper = + make_cpp_batch_wrapper(batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, 0, 0); + + const int32_t n_tokens = input_ids.numel(); + const int32_t embed_dim = embedding_weight.size(1); + const int32_t vocab_size = embedding_weight.size(0); + + DISPATCH_FOR_INT(input_ids.scalar_type(), [&] { + DISPATCH_FOR_FLOAT(embedding_weight.scalar_type(), [&] { + float_t* pos_embed_ptr = nullptr; + int32_t max_position_embed_idx = 0; + if (position_embedding_weight.has_value()) { + TORCH_CHECK( + position_embedding_weight.value().options().dtype() == + embedding_weight.options().dtype(), + "position_embedding_weight and embedding_weight must have the same dtype"); + pos_embed_ptr = + reinterpret_cast(position_embedding_weight.value().data_ptr()); + max_position_embed_idx = position_embedding_weight.value().size(0) - 1; + } + + launch_ragged_embed_kernel((float_t*)embedded_tokens.data_ptr(), + (const int_t*)input_ids.data_ptr(), + (const float_t*)embedding_weight.data_ptr(), + pos_embed_ptr, + batch_wrapper, + n_tokens, + embed_dim, + vocab_size, + max_position_embed_idx, + pos_embed_offset, + at::cuda::getCurrentCUDAStream()); + }); + }); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cuh b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cuh new file mode 100644 index 000000000000..94c397439b80 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +#ifdef BF16_AVAILABLE +#include +#endif + +template +void launch_ragged_embed_kernel(EmbedType* embedded_tokens, + const TokenType* input_ids, + const EmbedType* embedding_weight, + const EmbedType* position_weight, + const BatchWrapperCPP batch_desc, + const int32_t n_tokens, + const int32_t embed_dim, + const int32_t vocab_size, + const int32_t max_position_embed_idx, + const int32_t position_embed_offset, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.h b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.h new file mode 100644 index 000000000000..7897c1362669 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "embed.cuh" + +/* +Embeddings kernel aware of ragged batch structure. +*/ +void ragged_embed(torch::Tensor& embedded_tokens, + torch::Tensor& input_ids, + torch::Tensor& embedding_weight, + c10::optional& position_weight, + int32_t position_embed_offset, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.py b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.py new file mode 100644 index 000000000000..0443ce3fdd8e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from ... import DSKernelBase +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....inference_utils import elem_size +from ....ragged import RaggedBatchWrapper + + +class RaggedEmbeddingKernel(DSKernelBase): + """ + Ragged-aware CUDA kernel implementation for an embedding lookup. This will only lookup + the necessary tokens for a padded batch (i.e. if we are CGed and running with a slightly + larger batch size than the actual tokens). + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + supported_token_dtypes = [torch.int32, torch.int64] + + def __init__(self, embed_dtype: torch.dtype, token_dtype: torch.dtype, embed_dim: int) -> None: + """ + Args: + fp_dtype (torch.dtype): Data type of the embedding table and output dtype. + Supported values are torch.float16, torch.bfloat16, and torch.float32. + token_dtype (torch.dtype): Data type of the token ids. Supported values are + torch.int32 and torch.int64. + embed_dim (int): Embedding dimension. Must be aligned to 16 bytes. + """ + if embed_dtype not in RaggedEmbeddingKernel.supported_dtypes: + raise ValueError("Unsupported embedding data type: {}, supported_dtypes are {}".format( + embed_dtype, RaggedEmbeddingKernel.supported_dtypes)) + + if token_dtype not in RaggedEmbeddingKernel.supported_token_dtypes: + raise ValueError("Unsupported token data type: {}, supported_dtypes are {}".format( + token_dtype, RaggedEmbeddingKernel.supported_token_dtypes)) + + if elem_size(embed_dtype) * embed_dim % 16 != 0: + raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(embed_dim)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.ragged_embed + + def __call__(self, + embedded_tokens: torch.Tensor, + ragged_wrapper: RaggedBatchWrapper, + embedding_weight: torch.Tensor, + position_embed_weight: Optional[torch.Tensor] = None, + position_embed_offset: int = 0) -> torch.Tensor: + """ + Ragged aware embedding lookup. + + Args: + embedded_tokens (torch.Tensor): Output tensor of shape [num_tokens, embed_dim] + ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch. + embedding_weight (torch.Tensor): Embedding table of shape [vocab_size, embed_dim] + """ + self.kernel(embedded_tokens, ragged_wrapper.input_ids(), + embedding_weight, position_embed_weight, position_embed_offset, + ragged_wrapper.batch_metadata_buffer(), ragged_wrapper.inflight_seq_descriptors(), + ragged_wrapper.tokens_to_seq(), ragged_wrapper.kv_ptrs()) + return embedded_tokens diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed_cuda.cu new file mode 100644 index 000000000000..81d6d534ddf5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed_cuda.cu @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "embed.cuh" +#include "memory_access_utils.h" +#include "ragged_dtypes.h" + +namespace embed { + +constexpr int granularity = 16; +constexpr int threads = 512; + +} // namespace embed + +template +__global__ void ragged_embed_kernel(EmbedType* embedded_tokens, + const TokenType* input_ids, + const EmbedType* embedding_weight, + const EmbedType* position_weight, + const BatchWrapperCPP batch_desc, + const int32_t embed_dim, + const int32_t vocab_size, + const int32_t max_position_embed_idx, + const int32_t position_embed_offset) +{ + constexpr int T_vector = embed::granularity / sizeof(EmbedType); + + const int32_t token_idx = blockIdx.y; + + // It's possible our batch is padded (under CG conditions typically) + if (token_idx >= batch_desc.batch_metadata->n_tokens) return; + + TokenType token_value = input_ids[token_idx]; + + if (token_value >= vocab_size || token_value < 0) { + // TODO(cmikeh2): This is invalid, but not sure how we want to handle it being invalid + // yet. + return; + } + + const EmbedType* embedding_row = embedding_weight + token_value * embed_dim; + EmbedType* dest_row = embedded_tokens + token_idx * embed_dim; + + const int channel_offset = (threadIdx.x + embed::threads * blockIdx.x) * T_vector; + + if (channel_offset < embed_dim) { + EmbedType reg_buf[T_vector]; + + mem_access::load_global(reg_buf, embedding_row + channel_offset); + + if (position_weight != nullptr) { + // Map the token to its global idx (indirect memory accesses aren't great but whatever) + const int32_t seq_idx = batch_desc.tokens_to_seq[token_idx]; + const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_idx]; + int32_t pos_emb_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx); + + // Position embed offset is an OPT-specific feature I think? + pos_emb_idx = pos_emb_idx + position_embed_offset; + + // This clamping is technically + pos_emb_idx = (pos_emb_idx < 0) ? 0 : pos_emb_idx; + pos_emb_idx = (pos_emb_idx >= max_position_embed_idx) ? max_position_embed_idx + : pos_emb_idx; + + const EmbedType* position_embedding_row = position_weight + pos_emb_idx * embed_dim; + + EmbedType pos_buf[T_vector]; + mem_access::load_global(pos_buf, + position_embedding_row + channel_offset); + +#pragma unroll + for (int i = 0; i < T_vector; i++) { reg_buf[i] += pos_buf[i]; } + } + + mem_access::store_global(dest_row + channel_offset, reg_buf); + } +} + +template +void launch_ragged_embed_kernel(EmbedType* embedded_tokens, + const TokenType* input_ids, + const EmbedType* embedding_weight, + const EmbedType* position_weight, + const BatchWrapperCPP batch_desc, + const int32_t n_tokens, + const int32_t embed_dim, + const int32_t vocab_size, + const int32_t max_position_embed_idx, + const int32_t position_embed_offset, + cudaStream_t stream) +{ + constexpr int T_vector = embed::granularity / sizeof(EmbedType); + constexpr int elems_per_block = embed::threads * T_vector; + const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block; + + const dim3 grid_dim(parallel_blocks, n_tokens, 1); + const dim3 block_dim(embed::threads, 1, 1); + + ragged_embed_kernel + <<>>(embedded_tokens, + input_ids, + embedding_weight, + position_weight, + batch_desc, + embed_dim, + vocab_size, + max_position_embed_idx, + position_embed_offset); +} + +#define INSTANTIATE_EMBED_FOR_TYPES(TOKEN_TYPE, EMBED_TYPE) \ + template void launch_ragged_embed_kernel( \ + EMBED_TYPE * embedded_tokens, \ + const TOKEN_TYPE* input_ids, \ + const EMBED_TYPE* embedding_weight, \ + const EMBED_TYPE* position_weight, \ + const BatchWrapperCPP batch_descriptor, \ + const int32_t n_tokens, \ + const int32_t embed_dim, \ + const int32_t vocab_size, \ + const int32_t max_position_embed_idx, \ + const int32_t position_embed_offset, \ + cudaStream_t stream); + +INSTANTIATE_EMBED_FOR_TYPES(int32_t, float) +INSTANTIATE_EMBED_FOR_TYPES(int64_t, float) + +INSTANTIATE_EMBED_FOR_TYPES(int32_t, __half) +INSTANTIATE_EMBED_FOR_TYPES(int64_t, __half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_EMBED_FOR_TYPES(int32_t, __nv_bfloat16) +INSTANTIATE_EMBED_FOR_TYPES(int64_t, __nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h new file mode 100644 index 000000000000..f5104f899d9c --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#define TOP_K_SWITCH(N_TOP_K, ...) \ + [&] { \ + if (1 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 1; \ + __VA_ARGS__(); \ + } else if (2 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 2; \ + __VA_ARGS__(); \ + } else if (4 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 4; \ + __VA_ARGS__(); \ + } else if (8 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 8; \ + __VA_ARGS__(); \ + } \ + }() diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/__init__.py new file mode 100644 index 000000000000..0e239dd6b4c7 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blocked_kv_rotary import * +from .blocked_trained_kv_rotary import * +from .linear_blocked_kv_copy import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp new file mode 100644 index 000000000000..634a63b81a31 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "blocked_kv_rotary.h" +#include "ragged_kernel_helpers.h" + +#define DISPATCH_KV_ROTARY(T_TYPE, C_TYPE) \ + if (q.options().dtype() == torch::T_TYPE) { \ + launch_kv_rotary_kernel((C_TYPE*)kv_cache.data_ptr(), \ + (C_TYPE*)q.data_ptr(), \ + (C_TYPE*)k.data_ptr(), \ + (C_TYPE*)v.data_ptr(), \ + (C_TYPE*)inv_freq_ptr, \ + rotary_dim, \ + theta_base, \ + batch_wrapper, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + inv_freq_stride, \ + q_ratio, \ + head_size, \ + n_tokens, \ + n_q_heads, \ + at::cuda::getCurrentCUDAStream()); \ + } + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be ready from global memory rather than +synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] + inv_freq: [max_seq_len, head_size // 2] +*/ +void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& inv_freq, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + const int32_t n_tokens = q.size(0); + TORCH_CHECK(n_tokens == k.size(0)); + TORCH_CHECK(n_tokens == v.size(0)); + + const float theta_base = 0.f; + const int32_t rotary_dim = inv_freq.size(0) * 2; + + // Dimensions + const int32_t block_size = kv_cache.size(1); + const int32_t n_kv_heads = kv_cache.size(3); + const int32_t head_size = kv_cache.size(4); + + // Strides + const int32_t qkv_stride = q.stride(0); // Per token + const int32_t kv_cache_stride = kv_cache.stride(1); // Per token + const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache + const int32_t inv_freq_stride = inv_freq.stride(0); // Per token idx + + const int n_q_heads = q.size(1) / head_size; + const int q_ratio = n_q_heads / n_kv_heads; + + void* inv_freq_ptr = (void*)inv_freq.data_ptr(); + + BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper( + batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0)); + + DISPATCH_KV_ROTARY(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16); +#endif +} + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] +*/ +void kv_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + const int32_t rotary_dim, + const float theta_base, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + const int32_t n_tokens = q.size(0); + TORCH_CHECK(n_tokens == k.size(0)); + TORCH_CHECK(n_tokens == v.size(0)); + + // Dimensions + const int32_t block_size = kv_cache.size(1); + const int32_t n_kv_heads = kv_cache.size(3); + const int32_t head_size = kv_cache.size(4); + + // Strides + const int32_t qkv_stride = q.stride(0); // Per token + const int32_t kv_cache_stride = kv_cache.stride(1); // Per token + const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache + const int32_t inv_freq_stride = 0; // Per token idx + + const int n_q_heads = q.size(1) / head_size; + const int q_ratio = n_q_heads / n_kv_heads; + + void* inv_freq_ptr = nullptr; + + BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper( + batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0)); + + DISPATCH_KV_ROTARY(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16); +#endif +} + +#define DISPATCH_KV_COPY(T_TYPE, C_TYPE) \ + if (q.options().dtype() == torch::T_TYPE) { \ + launch_kv_copy_kernel((C_TYPE*)kv_cache.data_ptr(), \ + (C_TYPE*)q.data_ptr(), \ + (C_TYPE*)k.data_ptr(), \ + (C_TYPE*)v.data_ptr(), \ + batch_wrapper, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + q_ratio, \ + head_size, \ + n_tokens, \ + n_q_heads, \ + at::cuda::getCurrentCUDAStream()); \ + } + +/* +Copy into linear KV cache. +*/ +void linear_kv_copy(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + const int32_t n_tokens = q.size(0); + TORCH_CHECK(n_tokens == k.size(0)); + TORCH_CHECK(n_tokens == v.size(0)); + + // Dimensions + const int32_t block_size = kv_cache.size(1); + const int32_t n_kv_heads = kv_cache.size(3); + const int32_t head_size = kv_cache.size(4); + + // Strides + const int32_t qkv_stride = q.stride(0); // Per token + TORCH_CHECK(qkv_stride == k.stride(0)); + TORCH_CHECK(qkv_stride == v.stride(0)); + + const int32_t kv_cache_stride = kv_cache.stride(1); // Per token + const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache + + const int n_q_heads = q.size(1) / head_size; + + TORCH_CHECK(n_q_heads % n_kv_heads == 0); + const int q_ratio = n_q_heads / n_kv_heads; + + BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper( + batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0)); + + DISPATCH_KV_COPY(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_KV_COPY(kBFloat16, __nv_bfloat16); +#endif +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh new file mode 100644 index 000000000000..ff24b3f5bd80 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +#ifdef BF16_AVAILABLE +#include +#endif + +template +void launch_kv_rotary_kernel(T* kv_cache, + T* q, + T* k, + T* v, + T* inv_freq, + const int32_t rotary_dim, + const float theta_base, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int inv_freq_stride, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream); + +template +void launch_kv_copy_kernel(T* kv_cache, + T* q, + T* k, + T* v, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h new file mode 100644 index 000000000000..c0700eda7147 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "blocked_kv_rotary.cuh" + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be ready from global memory rather than +synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] + inv_freq: [max_seq_len, head_size // 2] +*/ +void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& inv_freq, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] +*/ +void kv_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + const int32_t rotary_dim, + const float theta_base, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); + +/* +Copy into linear KV cache. +*/ +void linear_kv_copy(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py new file mode 100644 index 000000000000..aacbec0bd3ae --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....ragged import RaggedBatchWrapper +from ... import DSKernelBase + + +class BlockedRotaryEmbeddings(DSKernelBase): + """ + CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys + before copying into a blocked KV cache. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_head_sizes = [64, 80, 96, 128] + supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71] + + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int, + theta_base: float) -> None: + """ + Args: + head_size: The size of the attention head. + q_ratio: Ratio of q heads to kv heads (for GQA) + dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16. + """ + + q_ratio = n_q_heads // n_kv_heads + + if head_size not in BlockedRotaryEmbeddings.supported_head_sizes: + raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format( + head_size, BlockedRotaryEmbeddings.supported_head_sizes)) + + if q_ratio not in BlockedRotaryEmbeddings.supported_q_ratios: + raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format( + q_ratio, BlockedRotaryEmbeddings.supported_q_ratios)) + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in BlockedRotaryEmbeddings.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, BlockedRotaryEmbeddings.supported_dtypes)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.kv_rotary_embeddings + self.head_size = head_size + self.n_q_heads = n_q_heads + self.n_kv_heads = n_kv_heads + self.rotary_dim = rotary_dim + self.theta_base = theta_base + + def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: + """ + Perform rotary embeddings on the queries and keys before copying into a blocked KV cache. + + Args: + kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] + qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] + ragged_batch: Wrapper for the ragged batch. + """ + + q = qkv[:, :self.head_size * self.n_q_heads] + k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] + v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] + + self.kernel(kv_cache, q, k, v, self.rotary_dim, self.theta_base, ragged_batch.batch_metadata_buffer(), + ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu new file mode 100644 index 000000000000..f7bc693eefee --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "blocked_kv_rotary.cuh" +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace kv_rot { + +constexpr int granularity = 16; +constexpr int threads = 256; + +} // namespace kv_rot + +/* +Supports head size 32, 64, 128, 256 +*/ + +template +__global__ void kv_rotary_pos_kernel(T* kv_cache, + T* q, + T* k, + T* v, + const T* inv_freq, + const int32_t rotary_dim, + const float theta_base, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int inv_freq_stride) +{ + // Derived constexpr + constexpr int vector_T = kv_rot::granularity / sizeof(T); + constexpr int real_threads_per_head = headSize / vector_T; + constexpr int threads_per_head = paddedHeadSize / vector_T; + + constexpr int tokens_per_block = kv_rot::threads / threads_per_head; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + cg::thread_block_tile head_group = cg::tiled_partition(tb); + + // Parallelize on the head dimension for X blocks + const int head_idx = blockIdx.x; + + const int block_seq_idx = threadIdx.x / threads_per_head; + const int base_neuron_idx = head_group.thread_rank() * vector_T; + const int half_rotary_size = rotary_dim / 2; + const int half_dim_lanes = half_rotary_size / vector_T; + const int half_idx = base_neuron_idx % half_rotary_size; + + // Multiple tokens processed by the same threadblock + const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx; + const bool valid_token = token_idx < batch_desc.batch_metadata->n_tokens; + + const bool valid_thread = valid_token && (head_group.thread_rank() < real_threads_per_head); + const bool load_inv_freq = (inv_freq != nullptr) && valid_thread; + + // If we have GQA, then only one of the Q heads needs to do rotary + copy + // for each of the heads in the group. + bool need_kv = head_idx % qRatio == 0; + // Make sure the following code is warp uniform + need_kv = warp.shfl(need_kv, 0); + + const int kv_head_idx = head_idx / qRatio; + + // Ensure we don't access invalid portions of the seq_metadata + const int32_t seq_id = (valid_thread) ? batch_desc.tokens_to_seq[token_idx] : 0; + const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_id]; + // This will give an invalid index if valid_thread is false, but should never affect memory. + const int32_t global_token_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx); + + T* q_row = q + token_idx * qkv_stride + head_idx * headSize; + T q_reg[vector_T]; + + if (need_kv) { + // The following logic assumes a linearly blocked KV cache. This means that no sparsity has + // been introduced into cache history. + const KVCacheDescriptor kv_desc = batch_desc.kv_desc; + const int32_t seq_kv_block_idx = global_token_idx / kv_desc.block_size; + const int32_t mapped_kv_block_idx = + (valid_thread) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0; + + const int32_t kv_block_offset = global_token_idx % kv_desc.block_size; + const int32_t kv_offset = + (mapped_kv_block_idx * kv_desc.block_size + kv_block_offset) * kv_cache_stride + + kv_head_idx * headSize; + + // Load indices from QKV output + T* k_row = k + token_idx * qkv_stride + kv_head_idx * headSize; + T* v_row = v + token_idx * qkv_stride + kv_head_idx * headSize; + + T k_reg[vector_T], v_reg[vector_T], inv_freq_reg[vector_T]; + + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_thread); + mem_access::load_global(k_reg, k_row + base_neuron_idx, valid_thread); + mem_access::load_global(v_reg, v_row + base_neuron_idx, valid_thread); + mem_access::load_global( + inv_freq_reg, inv_freq + half_idx, load_inv_freq); + if constexpr (doRotary) { +#pragma unroll + for (int i = 0; i < vector_T; i++) { + const int head_neuron_idx = base_neuron_idx + i; + + float inv_freq_flt; + if (inv_freq != nullptr) { + inv_freq_flt = conversion::to(inv_freq_reg[i]) * (float)global_token_idx; + } else { + inv_freq_flt = + (float)((head_neuron_idx % half_rotary_size) * 2) / (float)rotary_dim; + // Conversion to T and back means that both branches of this if statement + // will produce the same results if using the same algo for producing the + // freqs. + T trunc_freq = conversion::to(1.0 / powf(theta_base, inv_freq_flt)); + inv_freq_flt = conversion::to(trunc_freq) * (float)global_token_idx; + } + + float rotary_sign = (head_neuron_idx >= half_rotary_size) ? -1.0f : 1.0f; + float q_f = conversion::to(q_reg[i]); + float k_f = conversion::to(k_reg[i]); + float q_rot = q_f * rotary_sign; + float k_rot = k_f * rotary_sign; + + const int target_lane = (head_neuron_idx < half_rotary_size) + ? head_group.thread_rank() + half_dim_lanes + : head_group.thread_rank() - half_dim_lanes; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); + + if (base_neuron_idx < rotary_dim) { + q_reg[i] = conversion::to(q_f * cosf(inv_freq_flt) + + q_rot_temp * sinf(inv_freq_flt)); + k_reg[i] = conversion::to(k_f * cosf(inv_freq_flt) + + k_rot_temp * sinf(inv_freq_flt)); + } + } + } + + if (valid_thread) { + mem_access::store_global(kv_cache + kv_offset + base_neuron_idx, + k_reg); + mem_access::store_global( + kv_cache + kv_offset + base_neuron_idx + v_offset, v_reg); + } + } else { + T inv_freq_reg[vector_T]; + + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_thread); + mem_access::load_global( + inv_freq_reg, inv_freq + half_idx, load_inv_freq); + + if constexpr (doRotary) { +#pragma unroll + for (int i = 0; i < vector_T; i++) { + const int head_neuron_idx = base_neuron_idx + i; + + float inv_freq_flt; + if (inv_freq != nullptr) { + inv_freq_flt = conversion::to(inv_freq_reg[i]) * (float)global_token_idx; + } else { + inv_freq_flt = + (float)((head_neuron_idx % half_rotary_size) * 2) / (float)rotary_dim; + inv_freq_flt = 1.0 / powf(theta_base, inv_freq_flt) * (float)global_token_idx; + } + + float rotary_sign = (head_neuron_idx >= half_rotary_size) ? -1.0f : 1.0f; + float q_f = conversion::to(q_reg[i]); + float q_rot = q_f * rotary_sign; + + const int target_lane = (head_neuron_idx < half_rotary_size) + ? head_group.thread_rank() + half_dim_lanes + : head_group.thread_rank() - half_dim_lanes; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + if (base_neuron_idx < rotary_dim) + q_reg[i] = conversion::to(q_f * cosf(inv_freq_flt) + + q_rot_temp * sinf(inv_freq_flt)); + } + } + } + + if (valid_thread && doRotary) { + mem_access::store_global(q_row + base_neuron_idx, q_reg); + } +} + +#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, PADDED_HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + inv_freq, \ + rotary_dim, \ + theta_base, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + inv_freq_stride); + +#define LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, HEAD_SIZE) \ + if (padded_head_size == 64) { \ + DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, 64); \ + } else if (padded_head_size == 128) { \ + DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, 128); \ + } else { \ + assert(false); \ + } + +#define LAUNCH_KV_ROTARY_FOR_Q_RATIO(Q_RATIO) \ + if (head_size == 64) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \ + } else if (head_size == 80) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \ + } else if (head_size == 96) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 96); \ + } else if (head_size == 128) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \ + } else { \ + assert(false); \ + } + +template +void launch_kv_rotary_kernel(T* kv_cache, + T* q, + T* k, + T* v, + T* inv_freq, + const int32_t rotary_dim, + const float theta_base, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int inv_freq_stride, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream) +{ + constexpr int vector_T = kv_rot::granularity / sizeof(T); + + const int padded_head_size = next_pow2(head_size); + const int threads_per_head = padded_head_size / vector_T; + + const int tokens_per_block = kv_rot::threads / threads_per_head; + + const dim3 block(kv_rot::threads); + const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; + const dim3 grid(n_q_heads, token_blocks); + + LAUNCH_KV_ROTARY_FOR_Q_RATIO(1) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(2) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(4) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(5) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(6) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(7) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(8) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(16) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(29) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(35) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(36) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(71) +} + +#define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \ + template void launch_kv_rotary_kernel(TYPE * kv_cache, \ + TYPE * q, \ + TYPE * k, \ + TYPE * v, \ + TYPE * inv_freq, \ + const int32_t rotary_dim, \ + const float theta_base, \ + const BatchWrapperCPP batch_desc, \ + const int qkv_stride, \ + const int kv_cache_stride, \ + const int v_offset, \ + const int inv_freq_stride, \ + const int q_ratio, \ + const int head_size, \ + const int n_tokens, \ + const int n_q_heads, \ + cudaStream_t stream); + +INSTANTIATE_KV_ROTARY_KERNEL(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16) +#endif + +#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, PADDED_HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + nullptr, \ + -1, \ + 0.f, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + 0); + +#define LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, HEAD_SIZE) \ + if (padded_head_size == 64) { \ + DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, 64); \ + } else if (padded_head_size == 128) { \ + DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, 128); \ + } else { \ + assert(false); \ + } + +#define LAUNCH_KV_COPY_FOR_Q_RATIO(Q_RATIO) \ + if (head_size == 64) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \ + } else if (head_size == 80) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \ + } else if (head_size == 96) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 96); \ + } else if (head_size == 128) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \ + } else { \ + assert(false); \ + } + +template +void launch_kv_copy_kernel(T* kv_cache, + T* q, + T* k, + T* v, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream) +{ + constexpr int vector_T = kv_rot::granularity / sizeof(T); + const int padded_head_size = next_pow2(head_size); + const int threads_per_head = padded_head_size / vector_T; + const int tokens_per_block = kv_rot::threads / threads_per_head; + + const dim3 block(kv_rot::threads); + const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; + const dim3 grid(n_q_heads, token_blocks); + + LAUNCH_KV_COPY_FOR_Q_RATIO(1) + LAUNCH_KV_COPY_FOR_Q_RATIO(2) + LAUNCH_KV_COPY_FOR_Q_RATIO(4) + LAUNCH_KV_COPY_FOR_Q_RATIO(5) + LAUNCH_KV_COPY_FOR_Q_RATIO(8) +} + +#define INSTANTIATE_KV_COPY_KERNEL(TYPE) \ + template void launch_kv_copy_kernel(TYPE * kv_cache, \ + TYPE * q, \ + TYPE * k, \ + TYPE * v, \ + const BatchWrapperCPP batch_desc, \ + const int qkv_stride, \ + const int kv_cache_stride, \ + const int v_offset, \ + const int q_ratio, \ + const int head_size, \ + const int n_tokens, \ + const int n_q_heads, \ + cudaStream_t stream); + +INSTANTIATE_KV_COPY_KERNEL(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_KV_COPY_KERNEL(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py new file mode 100644 index 000000000000..f527be227ce1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....ragged import RaggedBatchWrapper +from ... import DSKernelBase + + +class BlockedTrainedRotaryEmbeddings(DSKernelBase): + """ + CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys + before copying into a blocked KV cache. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_head_sizes = [64, 80, 96, 128] + supported_q_ratios = [1, 2, 4, 5, 8] + + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + """ + Args: + head_size: The size of the attention head. + dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16. + """ + + q_ratio = n_q_heads // n_kv_heads + + if head_size not in BlockedTrainedRotaryEmbeddings.supported_head_sizes: + raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format( + head_size, BlockedTrainedRotaryEmbeddings.supported_head_sizes)) + + if q_ratio not in BlockedTrainedRotaryEmbeddings.supported_q_ratios: + raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format( + q_ratio, BlockedTrainedRotaryEmbeddings.supported_q_ratios)) + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in BlockedTrainedRotaryEmbeddings.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, BlockedTrainedRotaryEmbeddings.supported_dtypes)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.kv_trained_rotary_embeddings + self.head_size = head_size + self.n_q_heads = n_q_heads + self.n_kv_heads = n_kv_heads + + def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper, + inverse_freqs: torch.Tensor) -> None: + """ + Perform rotary embeddings on the queries and keys before copying into a blocked KV cache. + + Args: + kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] + qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] + ragged_batch: Wrapper for the ragged batch. + inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, rotary_dim // 2] + """ + + q = qkv[:, :self.head_size * self.n_q_heads] + k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] + v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] + + self.kernel(kv_cache, q, k, v, inverse_freqs, ragged_batch.batch_metadata_buffer(), + ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py new file mode 100644 index 000000000000..4b2ad858a1bf --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from ....ragged import RaggedBatchWrapper +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ... import DSKernelBase + + +class LinearBlockedKVCopy(DSKernelBase): + """ + CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys + before copying into a blocked KV cache. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_head_sizes = [64, 80, 96, 128] + supported_q_ratios = [1, 2, 4, 5, 8] + + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + """ + Args: + head_size: The size of the attention head. + dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16. + """ + + q_ratio = n_q_heads // n_kv_heads + + if head_size not in LinearBlockedKVCopy.supported_head_sizes: + raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format( + head_size, LinearBlockedKVCopy.supported_head_sizes)) + + if q_ratio not in LinearBlockedKVCopy.supported_q_ratios: + raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format( + q_ratio, LinearBlockedKVCopy.supported_q_ratios)) + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in LinearBlockedKVCopy.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, LinearBlockedKVCopy.supported_dtypes)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.linear_kv_copy + self.head_size = head_size + self.n_q_heads = n_q_heads + self.n_kv_heads = n_kv_heads + + def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: + """ + Perform rotary embeddings on the queries and keys before copying into a blocked KV cache. + + Args: + kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] + qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] + ragged_batch: Wrapper for the ragged batch. + """ + + q = qkv[:, :self.head_size * self.n_q_heads] + k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] + v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] + + self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(), + ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/__init__.py new file mode 100644 index 000000000000..72103a0d82a1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .logits_gather import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp new file mode 100644 index 000000000000..1a7e7c0a2167 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "logits_gather.h" + +#define DISPATCH_TO_LOGITS_GATHER(T_TYPE, C_TYPE) \ + if (all_acts.options().dtype() == torch::T_TYPE) { \ + launch_logits_gather((C_TYPE*)final_token_acts.data_ptr(), \ + (const C_TYPE*)all_acts.data_ptr(), \ + batch_metadata_raw, \ + seq_metadata_raw, \ + n_seqs, \ + embed_dim, \ + at::cuda::getCurrentCUDAStream()); \ + } + +/* +Logits gather will parse the ragged batch data structure and gather only the logits that +will be used for token sampling. +*/ +void gather_for_logits(torch::Tensor& final_token_acts, + torch::Tensor& all_acts, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata) +{ + const RaggedBatchDescriptor* batch_metadata_raw = + reinterpret_cast(batch_metadata.data_ptr()); + + const InflightSeqDescriptor* seq_metadata_raw = + reinterpret_cast(seq_metadata.data_ptr()); + + const int n_seqs = final_token_acts.size(0); + const int embed_dim = final_token_acts.size(1); + + TORCH_CHECK(all_acts.scalar_type() == final_token_acts.scalar_type(), + "all_acts and final_token_acts must have the same scalar type"); + + DISPATCH_TO_LOGITS_GATHER(kFloat, float) + DISPATCH_TO_LOGITS_GATHER(kHalf, half) +#ifdef BF16_AVAILABLE + DISPATCH_TO_LOGITS_GATHER(kBFloat16, __nv_bfloat16) +#endif +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cuh new file mode 100644 index 000000000000..c4e84c05e6d8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cuh @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +#ifdef BF16_AVAILABLE +#include +#endif + +template +void launch_logits_gather(T* final_token_acts, + const T* all_acts, + const RaggedBatchDescriptor* batch_metadata, + const InflightSeqDescriptor* seq_metadata, + const int32_t n_seqs, + const int32_t embed_dim, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.h new file mode 100644 index 000000000000..73a855984daa --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "logits_gather.cuh" +#include "ragged_dtypes.h" + +/* +Logits gather will parse the ragged batch data structure and gather only the logits that +will be used for token sampling. +*/ +void gather_for_logits(torch::Tensor& final_token_acts, + torch::Tensor& all_acts, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.py new file mode 100644 index 000000000000..64b453e9e9e3 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....inference_utils import elem_size +from ....ragged import RaggedBatchWrapper + + +class RaggedLogitsGather(DSKernelBase): + """ + CUDA Kernel implementation for gather the hidden states of the final token + of each sequence. This is used to reduce the cost of the performing the unembedding. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def __init__(self, model_dim: int, fp_dtype: torch.dtype): + """ + Parameters: + fp_dtype (torch.dtype): Data type for the input/output. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + if fp_dtype not in RaggedLogitsGather.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, RaggedLogitsGather.supported_dtypes)) + + if elem_size(fp_dtype) * model_dim % 16 != 0: + raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(model_dim)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.gather_for_logits + + def __call__(self, final_token_activations: torch.Tensor, all_activations: torch.Tensor, + ragged_wrapper: RaggedBatchWrapper) -> torch.Tensor: + """ + Gather the hidden states of the final token of each sequence from `all_activations` into + `final_token_activations`. + + Args: + final_token_activations (torch.Tensor): Output tensor of shape [num_seqs, model_dim] + all_activations (torch.Tensor): Input tensor of shape [num_tokens, model_dim] + ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch. + """ + + self.kernel(final_token_activations, all_activations, ragged_wrapper.batch_metadata_buffer(), + ragged_wrapper.inflight_seq_descriptors()) + return final_token_activations diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather_cuda.cu new file mode 100644 index 000000000000..a539888ff904 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather_cuda.cu @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "logits_gather.cuh" +#include "memory_access_utils.h" +#include "ragged_dtypes.h" + +namespace logits_gather { + +constexpr int granularity = 16; +constexpr int threads = 512; + +} // namespace logits_gather + +template +__global__ void logits_gather_kernel(T* final_token_acts, + const T* token_acts, + const RaggedBatchDescriptor* ragged_batch, + const InflightSeqDescriptor* inflight_batch, + const int32_t embed_dim) +{ + constexpr int T_vector = logits_gather::granularity / sizeof(T); + + const int32_t seq_id = blockIdx.y; + + // It's possible we've padded the output Tensor (under CG conditions) + if (seq_id >= ragged_batch->n_sequences) return; + + const InflightSeqDescriptor seq = inflight_batch[seq_id]; + const int final_token_idx = seq.start_idx + seq.n_tokens - 1; + + const int token_offset = final_token_idx * embed_dim; + const int thread_offset = + threadIdx.x * T_vector + blockIdx.x * logits_gather::threads * T_vector; + + const int final_token_offset = seq_id * embed_dim; + + T reg_buf[T_vector]; + + if (thread_offset < embed_dim) { + mem_access::load_global( + reg_buf, token_acts + token_offset + thread_offset); + + mem_access::store_global( + final_token_acts + final_token_offset + thread_offset, reg_buf); + } +} + +template +void launch_logits_gather(T* final_token_acts, + const T* all_acts, + const RaggedBatchDescriptor* ragged_batch, + const InflightSeqDescriptor* inflight_batch, + const int32_t n_seqs, + const int32_t embed_dim, + cudaStream_t stream) +{ + constexpr int T_vector = logits_gather::granularity / sizeof(T); + constexpr int elems_per_block = logits_gather::threads * T_vector; + const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block; + + const dim3 grid(parallel_blocks, n_seqs, 1); + const dim3 block(logits_gather::threads, 1, 1); + + logits_gather_kernel<<>>( + final_token_acts, all_acts, ragged_batch, inflight_batch, embed_dim); +} + +#define INSTANTIATE_FOR_TYPE(T) \ + template void launch_logits_gather(T * final_token_acts, \ + const T* all_acts, \ + const RaggedBatchDescriptor* ragged_batch, \ + const InflightSeqDescriptor* inflight_batch, \ + const int32_t n_seqs, \ + const int32_t embed_dim, \ + cudaStream_t stream); + +INSTANTIATE_FOR_TYPE(float) +INSTANTIATE_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/__init__.py new file mode 100644 index 000000000000..096c0d984a5a --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .moe_gather import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp new file mode 100644 index 000000000000..506629406f0d --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "moe_gather.h" +#include + +#define DISPATCH_MOE_GATHER(T_TYPE, C_TYPE) \ + if (layer_output.options().dtype() == torch::T_TYPE) { \ + launch_moe_gather((C_TYPE*)layer_output.data_ptr(), \ + (const C_TYPE*)moe_output.data_ptr(), \ + (const float*)scores.data_ptr(), \ + (const int32_t*)mapped_slots.data_ptr(), \ + (int32_t*)expert_count.data_ptr(), \ + n_channels, \ + n_experts, \ + n_tokens, \ + n_top_k, \ + normalize_scales, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +/* +Re-gather the outputs of MoE and scale them by the gating score. +*/ +void moe_gather(torch::Tensor& layer_output, + const torch::Tensor& moe_output, + const torch::Tensor& scores, + const torch::Tensor& mapped_slots, + const torch::Tensor& expert_count, + const bool normalize_scales) +{ + const int32_t n_channels = layer_output.size(1); + const int32_t n_experts = expert_count.size(0); + const int32_t n_tokens = layer_output.size(0); + const int32_t n_top_k = mapped_slots.size(1); + + TORCH_CHECK(moe_output.size(0) == n_tokens * n_top_k); + TORCH_CHECK(moe_output.size(1) == n_channels); + TORCH_CHECK(scores.size(0) == n_tokens); + TORCH_CHECK(mapped_slots.size(0) == n_tokens); + + TORCH_CHECK(scores.size(1) == n_top_k); + + TORCH_CHECK(layer_output.scalar_type() == moe_output.scalar_type()); + TORCH_CHECK(scores.scalar_type() == torch::kFloat32); + TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); + TORCH_CHECK(expert_count.scalar_type() == torch::kInt32); + + DISPATCH_MOE_GATHER(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_GATHER(kBFloat16, __nv_bfloat16); +#endif + + TORCH_CHECK(false, "Unsupported data type for MoE gather"); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh new file mode 100644 index 000000000000..b348d0cfb330 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +template +void launch_moe_gather(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + int32_t* expert_counts, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h new file mode 100644 index 000000000000..ec9e03057eb8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "moe_gather.cuh" + +/* +Re-gather the outputs of MoE and scale them by the gating score. +*/ +void moe_gather(torch::Tensor& layer_output, + const torch::Tensor& moe_output, + const torch::Tensor& scores, + const torch::Tensor& mapped_slots, + const torch::Tensor& expert_counts, + const bool normalize_scales); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py new file mode 100644 index 000000000000..f03938171ba4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder + + +class MoEGather(DSKernelBase): + """ + CUDA implementation of MoE gather. This will bring the tokens back + to their original indices and perform the output scaling. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + + def __init__(self, dtype: DtypeEnum, channels: int, normalize_scores: bool = False) -> None: + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in MoEGather.supported_dtypes: + raise RuntimeError(f"Unsupported dtype {dtype}") + + if channels % 8 != 0: + raise RuntimeError(f"Channels {channels} must be divisible by 8") + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.moe_gather + self.normalize_scores = normalize_scores + + def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: torch.Tensor, + mapped_slots: torch.Tensor, expert_counts: torch.Tensor) -> torch.Tensor: + """ + Reorders the moe_output tokens into their original order and scales them by their + gating scale. This will be a no-op for padded tokens. + + Arguments: + layer_output (torch.Tensor): The output of the layer of shape [n_tokens, hidden_size]. This has been scaled appropriately. + moe_output (torch.Tensor): The output of the MoE of shape [n_tokens * n_top_k, hidden_size]. + scores (torch.Tensor): The gating scores of shape [n_tokens]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. The indices of token ``i`` in layer_output is ``mapped_slots[i]``. + expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. This is passed to fuse the clearing of this data structure into the gather. + + Returns: + layer_output + """ + self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts, self.normalize_scores) + return layer_output diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather_cuda.cu new file mode 100644 index 000000000000..4153a2a3636f --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather_cuda.cu @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "moe_gather.cuh" +#include "reduction_utils.h" +#include "top_k_gating.cuh" +#include "top_k_utils.h" + +namespace gather { + +constexpr int access_granularity = 16; +constexpr int threads = 256; + +} // namespace gather + +template +__global__ void moe_gather_kernel(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + int32_t* expert_counts, + const int32_t n_channels, + const int32_t n_experts, + const bool normalize_scales) +{ + constexpr int32_t vector_size = gather::access_granularity / sizeof(T); + constexpr int32_t stride = vector_size * gather::threads; + + const int32_t token_idx = blockIdx.x; + int32_t token_mapped_slots[N_TOP_K]; + + bool all_slots_invalid = true; + for (int i = 0; i < N_TOP_K; i++) { + token_mapped_slots[i] = mapped_slots[token_idx * N_TOP_K + i]; + all_slots_invalid &= (token_mapped_slots[i] == gating::unassigned); + } + + if (token_idx == 0) { + // Reset expert counts for its next use. + if (threadIdx.x < n_experts) { expert_counts[threadIdx.x] = 0; } + } + + if (all_slots_invalid) { + // This token was not assigned to anything. + // TODO(cmikeh2): It's possible we want different behavior here moving forward. + return; + } + + float token_scores[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] = scores[token_idx * N_TOP_K + i]; } + + if (normalize_scales) { + // Normalize the scores so that they sum to 1. + float sum = 0.0f; + for (int i = 0; i < N_TOP_K; i++) { sum += token_scores[i]; } + + if (sum > 0.0f) { + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] /= sum; } + } + } + + const int32_t channel_offset = threadIdx.x * vector_size; + + const T* moe_output_bases[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + moe_output_bases[i] = moe_output + token_mapped_slots[i] * n_channels + channel_offset; + } + + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + if (i * stride + channel_offset < n_channels) { + float accum_buffer[vector_size]; + for (int j = 0; j < vector_size; j++) { + accum_buffer[j] = reduce::init(); + } + +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + T reg_buffer[vector_size]; + mem_access::load_global( + reg_buffer, moe_output_bases[j] + i * stride); + +#pragma unroll + for (int k = 0; k < vector_size; k++) { + float up_cast = conversion::to(reg_buffer[k]); + accum_buffer[k] += up_cast * token_scores[j]; + } + } + + T store_buffer[vector_size]; +#pragma unroll + for (int j = 0; j < vector_size; j++) { + store_buffer[j] = conversion::to(accum_buffer[j]); + } + + mem_access::store_global(layer_output_base + i * stride, + store_buffer); + } + } +} + +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_gather_kernel<<>>(layer_output, \ + moe_output, \ + scores, \ + mapped_slots, \ + expert_counts, \ + n_channels, \ + n_experts, \ + normalize_scales); \ + break; + +template +void launch_moe_gather(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + int32_t* expert_counts, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = gather::threads * gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(gather::threads); + const dim3 grid(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1) + LAUNCH_FOR_UNROLL(2) + LAUNCH_FOR_UNROLL(3) + LAUNCH_FOR_UNROLL(4) + LAUNCH_FOR_UNROLL(5) + LAUNCH_FOR_UNROLL(6) + } + }); +} + +#define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \ + template void launch_moe_gather(TYPE * layer_output, \ + const TYPE* moe_output, \ + const float* scores, \ + const int32_t* mapped_slots, \ + int32_t* expert_counts, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + const int32_t n_tokens, \ + const int32_t n_top_k, \ + const bool normalize_scales, \ + cudaStream_t stream); + +INSTANTIATE_GATHER_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_GATHER_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/__init__.py new file mode 100644 index 000000000000..a7ca91fe5363 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .moe_scatter import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp new file mode 100644 index 000000000000..8f7ecbd1a287 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "moe_scatter.h" +#include + +#define DISPATCH_MOE_SCATTER(T_TYPE, C_TYPE) \ + if (activations.options().dtype() == torch::T_TYPE) { \ + launch_moe_scatter((C_TYPE*)moe_input.data_ptr(), \ + (int64_t*)expert_count_cumsums.data_ptr(), \ + (int32_t*)mapped_slots.data_ptr(), \ + (const C_TYPE*)activations.data_ptr(), \ + (const int32_t*)expert_counts.data_ptr(), \ + (const int32_t*)assignments.data_ptr(), \ + (const int32_t*)offsets.data_ptr(), \ + n_channels, \ + n_tokens, \ + n_experts, \ + n_top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +/* +Performs a cumsum on the expert counts and copies the hidden states to the +appropriate spot to ensure that each experts inputs are contiguous. +*/ +void moe_scatter(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& assignments, + torch::Tensor& offsets) +{ + const int32_t n_tokens = activations.size(0); + const int32_t n_channels = activations.size(1); + const int32_t n_top_k = assignments.size(1); + + // Should have a lot of matching buffer sizes here. + TORCH_CHECK(n_tokens == assignments.size(0)); + TORCH_CHECK(n_tokens == offsets.size(0)); + TORCH_CHECK(n_channels == moe_input.size(1)); + + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k * n_tokens == moe_input.size(0)); + TORCH_CHECK(n_top_k == mapped_slots.size(1)); + + const int32_t n_experts = expert_count_cumsums.size(0); + + TORCH_CHECK(moe_input.scalar_type() == activations.scalar_type()); + TORCH_CHECK(expert_count_cumsums.scalar_type() == torch::kInt64); + TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); + TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); + TORCH_CHECK(assignments.scalar_type() == torch::kInt32); + TORCH_CHECK(offsets.scalar_type() == torch::kInt32); + + DISPATCH_MOE_SCATTER(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_SCATTER(kBFloat16, __nv_bfloat16); +#endif + + TORCH_CHECK(false, "Unsupported dtype for moe_scatter") +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh new file mode 100644 index 000000000000..d9756c80f05a --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +template +void launch_moe_scatter(T* moe_input, + int64_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + const int32_t* expert_counts, + const int32_t* assignments, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_tokens, + const int32_t n_experts, + const int32_t n_top_k, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.h new file mode 100644 index 000000000000..59597f63d123 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "moe_scatter.cuh" +#include "ragged_dtypes.h" + +/* +Performs a cumsum on the expert counts and copies the hidden states to the +appropriate spot to ensure that each experts inputs are contiguous. +*/ +void moe_scatter(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& assignments, + torch::Tensor& offsets); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py new file mode 100644 index 000000000000..7efcedb4e880 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Tuple + +from ... import DSKernelBase +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder + + +class MoEScatter(DSKernelBase): + """ + CUDA implementation of MoE scatter + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + + def __init__(self, dtype: DtypeEnum, channels: int) -> None: + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in MoEScatter.supported_dtypes: + raise RuntimeError(f"Unsupported dtype {dtype}") + + if channels % 8 != 0: + raise RuntimeError(f"Channels {channels} must be divisible by 8") + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.moe_scatter + + def __call__(self, moe_input: torch.Tensor, expert_cumsum: torch.Tensor, mapped_slots: torch.Tensor, + activations: torch.Tensor, expert_counts: torch.Tensor, assignments: torch.Tensor, + offsets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Scatters the hidden states such that the token stride for each expert's input is contiguous. + + Arguments: + moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens * n_top_k, hidden_size]. + expert_cumsum (torch.Tensor): The cumulative sum of the expert counts of shape [n_experts]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. + hidden_states (torch.Tensor): The hidden states of shape [n_tokens, hidden_size]. + expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. + assignments (torch.Tensor): The expert assignments of shape [n_tokens, n_top_k]. + offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens, n_top_K]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input (with scattered values), the cumsum of the offsets (for the MoE kernels themselves), and the assignments Tensor modified in place to show which row that token was mapped to in the input. + """ + self.kernel(moe_input, expert_cumsum, mapped_slots, activations, expert_counts, assignments, offsets) + return moe_input, expert_cumsum, mapped_slots diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter_cuda.cu new file mode 100644 index 000000000000..d3eb4f649e79 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter_cuda.cu @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "reduction_utils.h" +#include "top_k_gating.cuh" +#include "top_k_utils.h" + +using ROp = reduce::ROpType; + +namespace scatter { + +constexpr int access_granularity = 16; +constexpr int threads = 256; +constexpr int warps = threads / hw_warp_size; +constexpr int max_experts = 1024; + +} // namespace scatter + +template +__global__ void moe_scatter_kernel(T* moe_input, + int64_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + const int32_t* assignments, + const int32_t* expert_counts, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_experts) +{ + constexpr int32_t vector_size = scatter::access_granularity / sizeof(T); + constexpr int32_t load_stride = vector_size * scatter::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t tidx = threadIdx.x; + const int32_t warp_rank = tidx / hw_warp_size; + + // Bank aligned and sufficient + __shared__ int32_t red_buffer[32]; + __shared__ int32_t expert_offsets[scatter::max_experts]; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Fetch the assigned experts for this token. + int assigned_experts[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { + assigned_experts[i] = assignments[token_idx * N_TOP_K + i]; + } + + bool all_unassigned = true; + for (int i = 0; i < N_TOP_K; i++) { + if (assigned_experts[i] != gating::unassigned) { + all_unassigned = false; + } else { + mapped_slots[token_idx * N_TOP_K + i] = gating::unassigned; + } + } + if (all_unassigned && token_idx != 0) return; + + // Do a prefix scan on the expert counts to get the base offsets. Here we use the + // single up-sweep variant. + int32_t expert_vals; + if (tidx < n_experts) { + expert_vals = expert_counts[tidx]; + } else { + expert_vals = 0; + } + +#pragma unroll + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(expert_vals, i); + expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; + } + + if (warp.thread_rank() == hw_warp_size - 1) { + mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); + } + + tb.sync(); + + int32_t phase_2_val = 0; + if (warp.thread_rank() < scatter::warps) { + mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); + } + +#pragma unroll + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(phase_2_val, i); + phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; + } + + int warp_offset = 0; + if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } + const int32_t expert_cumsum = warp_offset + expert_vals; + + // Token 0 will write the + if (token_idx == 0 && tidx < n_experts) { + int64_t expert_cumsum_64 = (int64_t)expert_cumsum; + expert_count_cumsums[tidx] = expert_cumsum_64; + } + + // Since token 0 has now written the expert cumsum to global memory, + // if it has no valid experts, we can early return. + if (token_idx == 0 && all_unassigned) return; + + if (tidx < n_experts) { expert_offsets[tidx] = expert_cumsum; } + + // Ensure all the expert offsets are written in shared memory. + tb.sync(); + + // Data copy to appropriate location + const int32_t thread_offset = tidx * vector_size; + + const int32_t base_load_offset = token_idx * n_channels + thread_offset; + const T* load_base_ptr = activations + base_load_offset; + + int32_t store_rows[N_TOP_K]; + T* store_base_ptrs[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + const int32_t cur_expert_offset = + (assigned_experts[i] > 0) ? expert_offsets[assigned_experts[i] - 1] : 0; + store_rows[i] = cur_expert_offset + offsets[token_idx * N_TOP_K + i]; + const int32_t base_store_offset = store_rows[i] * n_channels + thread_offset; + store_base_ptrs[i] = moe_input + base_store_offset; + } + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T tmp_buf[vector_size]; + + if (i * load_stride + thread_offset < n_channels) { + mem_access::load_global(tmp_buf, + load_base_ptr + i * load_stride); +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + mem_access::store_global( + store_base_ptrs[j] + i * load_stride, tmp_buf); + } + } + } + + if (threadIdx.x == 0) { + for (int i = 0; i < N_TOP_K; i++) { mapped_slots[token_idx * N_TOP_K + i] = store_rows[i]; } + } +} + +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_scatter_kernel \ + <<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + activations, \ + assignments, \ + expert_counts, \ + offsets, \ + n_channels, \ + n_experts); \ + break; + +template +void launch_moe_scatter(T* moe_input, + int64_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + const int32_t* expert_counts, + const int32_t* assignments, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_tokens, + const int32_t n_experts, + const int32_t n_top_k, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter::threads * scatter::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(scatter::threads); + const dim3 grid(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } + }); +} + +#define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \ + template void launch_moe_scatter(TYPE*, \ + int64_t*, \ + int32_t*, \ + const TYPE*, \ + const int32_t*, \ + const int32_t*, \ + const int32_t*, \ + const int32_t, \ + const int32_t, \ + const int32_t, \ + const int32_t, \ + cudaStream_t); + +INSTANTIATE_SCATTER_FOR_TYPE(__half); + +#ifdef BF16_AVAILABLE +INSTANTIATE_SCATTER_FOR_TYPE(__nv_bfloat16); +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_dtypes.h b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_dtypes.h new file mode 100644 index 000000000000..7876b354af0d --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_dtypes.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +struct +#ifdef __CUDA_CC__ + __align__(8) +#endif +{ + int32_t n_tokens; + int32_t n_sequences; +} +typedef RaggedBatchDescriptor; + +struct +#ifdef __CUDA_CC__ + __align__(16) +#endif +{ + int32_t start_idx; + int32_t n_tokens; + int32_t seen_tokens; + int32_t UNUSED; // Explicit padding to match the Python code pattern. +} +typedef InflightSeqDescriptor; + +struct +#ifdef __CUDA_CC__ + __align__(8) +#endif +{ + int32_t** block_lists; + int32_t block_size; + int32_t n_blocks; +} +typedef KVCacheDescriptor; + +struct { + const RaggedBatchDescriptor* batch_metadata; // Offset 0 + const InflightSeqDescriptor* seq_metadata; // Offset 8 + const int32_t* tokens_to_seq; // Offset 16 + const KVCacheDescriptor kv_desc; // Offset 24 +} typedef BatchWrapperCPP; diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp new file mode 100644 index 000000000000..a6cb7f275366 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ragged_kernel_helpers.h" + +BatchWrapperCPP make_cpp_batch_wrapper(torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_cache_desc, + int32_t block_size, + int32_t n_blocks) +{ + const RaggedBatchDescriptor* batch_metadata_raw = + reinterpret_cast(batch_metadata.data_ptr()); + + const InflightSeqDescriptor* seq_metadata_raw = + reinterpret_cast(seq_metadata.data_ptr()); + + const int32_t* tokens_to_seq_raw = tokens_to_seq.data_ptr(); + + int32_t** kv_ptrs_raw = reinterpret_cast(kv_cache_desc.data_ptr()); + KVCacheDescriptor kv_desc = {kv_ptrs_raw, block_size, n_blocks}; + + BatchWrapperCPP wrapper = {batch_metadata_raw, seq_metadata_raw, tokens_to_seq_raw, kv_desc}; + return wrapper; +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.h b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.h new file mode 100644 index 000000000000..7ce082d31853 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "ragged_dtypes.h" + +BatchWrapperCPP make_cpp_batch_wrapper(torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_cache_desc, + int32_t block_size, + int32_t n_blocks); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp new file mode 100644 index 000000000000..f320f46e2620 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "atom_builder.h" +#include "blocked_flash.h" +#include "blocked_kv_rotary.h" +#include "embed.h" +#include "logits_gather.h" +#include "moe_gather.h" +#include "moe_scatter.h" +#include "top_k_gating.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // atom_builder.h + m.def("build_atoms", &build_atoms, "Host kernel for building the atoms."); + + // blocked_flash.h + m.def("flash_attn_by_atoms", + &flash_attn_by_atoms, + "Blocked flash attention scheduled with atoms"); + + // blocked_kv_rotary.h + m.def("kv_rotary_embeddings", &kv_rotary_embeddings, "KV rotary embedding for blocked KV"); + m.def("kv_trained_rotary_embeddings", + &kv_trained_rotary_embeddings, + "KV rotary embeddings for blocked KV"); + m.def("linear_kv_copy", &linear_kv_copy, "Linear copy for blocked KV"); + + // embed.h + m.def("ragged_embed", &ragged_embed, "Embedding lookup for ragged batch"); + + // logits_gather.h + m.def("gather_for_logits", &gather_for_logits, "Sparse gather from ragged batch"); + + // moe_gather.h + m.def("moe_gather", &moe_gather, "MoE gather for top-1-gating."); + + // moe_scatter.h + m.def("moe_scatter", &moe_scatter, "MoE scatter for top-1-gating."); + + // top_k_gating.h + m.def("top_k_gating", &top_k_gating, "Top-1 gating for MoE with ragged batch awareness."); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py new file mode 100644 index 000000000000..487735b015b0 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .top_k_gating import RaggedTopKGating diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp new file mode 100644 index 000000000000..5eec7e2b955f --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "top_k_gating.h" +#include + +#define DISPATCH_TOP_K_GATING(T_TYPE, C_TYPE) \ + if (logits.options().dtype() == torch::T_TYPE) { \ + launch_top_k_gating((int32_t*)expert_counts.data_ptr(), \ + (float*)scores.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + (const C_TYPE*)logits.data_ptr(), \ + batch_metadata_ptr, \ + n_tokens, \ + n_experts, \ + n_top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +/* +Perform softmax plus atomics in order to do first pass of top_k_gating. +*/ +void top_k_gating(torch::Tensor& expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& logits, + torch::Tensor& batch_metadata) +{ + const int32_t n_tokens = scores.size(0); + const int32_t n_top_k = scores.size(1); + + // Should have the same buffer size for scores, offsets, and assignments + TORCH_CHECK(n_tokens == offsets.size(0)); + TORCH_CHECK(n_tokens == logits.size(0)); + TORCH_CHECK(n_tokens == assignments.size(0)); + + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k == assignments.size(1)); + + TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); + TORCH_CHECK(scores.scalar_type() == torch::kFloat); + TORCH_CHECK(assignments.scalar_type() == torch::kInt32); + TORCH_CHECK(offsets.scalar_type() == torch::kInt32); + + const int32_t n_experts = logits.size(1); + const RaggedBatchDescriptor* batch_metadata_ptr = + reinterpret_cast(batch_metadata.data_ptr()); + + DISPATCH_TOP_K_GATING(kFloat, float) + DISPATCH_TOP_K_GATING(kHalf, __half) +#ifdef BF16_AVAILABLE + DISPATCH_TOP_K_GATING(kBFloat16, __nv_bfloat16) +#endif + + TORCH_CHECK(false, "Unsupported dtype for logits in top_k_gating"); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh new file mode 100644 index 000000000000..c525cc5f524e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +namespace gating { +constexpr int unassigned = -1; +} // namespace gating + +template +void launch_top_k_gating(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const T* logits, + const RaggedBatchDescriptor* batch_metadata, + const int32_t n_tokens, + const int32_t n_experts, + const int32_t n_top_k, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h new file mode 100644 index 000000000000..00840c3c93b5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ragged_dtypes.h" +#include "top_k_gating.cuh" + +/* +Perform softmax plus atomics to get token mapping. +*/ +void top_k_gating(torch::Tensor& expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& logits, + torch::Tensor& batch_metadata); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py new file mode 100644 index 000000000000..72ba2b6019bb --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Tuple + +from ... import DSKernelBase +from ....inference_utils import DtypeEnum +from ....ragged import RaggedBatchWrapper +from deepspeed.ops.op_builder import RaggedOpsBuilder + + +class RaggedTopKGating(DSKernelBase): + """ + CUDA implementation of top-1 gating. This will perform a softmax on the logits, + and return the scale as well as its idx within that expert's allocation. + """ + + supported_logit_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32] + + def __init__(self, logit_dtype: DtypeEnum) -> None: + + if not isinstance(logit_dtype, DtypeEnum): + logit_dtype = DtypeEnum(logit_dtype) + + if logit_dtype not in RaggedTopKGating.supported_logit_dtypes: + raise RuntimeError(f"Unsupported logit dtype {logit_dtype}") + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.top_k_gating + + def __call__(self, expert_counts: torch.Tensor, scores: torch.Tensor, assignments: torch.Tensor, + offsets: torch.Tensor, logits: torch.Tensor, + batch: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Perform the ragged top_k_gating. + + Arguments: + expert_counts (torch.Tensor): Tensor of 0s of shape [n_experts] to be filled with + number of tokens assigned to each expert. This must be filled with 0s else + the copy kernel will buffer overflow. In order to minimize the zero-fill cost, + it is recommended to write to 0 during the MoE output remapping. + scores (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place expert scaling + value. + expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place + which expert a token has been assigned to. + expert_offset (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which + offset within an experts group a token is. + logits (torch.Tensor): Raw logits of gating function. + batch (RaggedBatchWrapper): Batch information for ragged tensor. + + Returns: + tuple of (expert_counts, scores, expert_assignment, expert_offset) + """ + self.kernel(expert_counts, scores, assignments, offsets, logits, batch.batch_metadata_buffer()) + return expert_counts, scores, assignments, offsets diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating_cuda.cu new file mode 100644 index 000000000000..58f95c045593 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating_cuda.cu @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" +#include "top_k_gating.cuh" +#include "top_k_utils.h" + +using ROp = reduce::ROpType; + +template +__global__ void top_k_gating_kernel(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const T* logits, + const RaggedBatchDescriptor* batch_metadata, + const int32_t n_experts) +{ + const int32_t token_idx = blockIdx.x; + const int32_t expert_idx = threadIdx.x; + const int32_t max_warps = 1024 / hw_warp_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Padding tokens do not require + if (token_idx >= batch_metadata->n_tokens) { + if (threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + assignments[token_idx * TOP_K + i] = gating::unassigned; + offsets[token_idx * TOP_K + i] = gating::unassigned; + } + } + return; + } + + const T* token_logits = logits + token_idx * n_experts; + + float logit_val; + if (expert_idx < n_experts) { + logit_val = conversion::to(token_logits[expert_idx]); + } else { + reduce::init(&logit_val); + } + float reduce_val = logit_val; + + int32_t local_assigned_experts[TOP_K]; + float local_assigned_logits[TOP_K]; + + // Training code tends to use ``torch.argmax`` to select the expert, which + // which has ties broken by the lower index. Since our fused comparison algorithm + // breaks ties by the higher index (since it's the lower 32-bits of the 64-bit + // comparison), we invert the expert index to break ties by the lower index. + int32_t inverted_expert = n_experts - expert_idx - 1; + + // Find the top k logits + for (int i = 0; i < TOP_K; ++i) { + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, reduce_val, inverted_expert); + local_assigned_experts[i] = n_experts - res.idx - 1; + local_assigned_logits[i] = res.val; + + // Set the max logit to -inf so that it is not selected again + if (threadIdx.x == n_experts - res.idx - 1) { reduce::init(&reduce_val); } + } + + const float max_logit = local_assigned_logits[0]; + float softmax_sum = __expf(logit_val - max_logit); + reduce::block(tb, warp, softmax_sum); + + for (int i = 0; i < TOP_K; ++i) { + const float softmax = __expf(local_assigned_logits[i] - max_logit) / softmax_sum; + + if (threadIdx.x == 0) { + scores[token_idx * TOP_K + i] = softmax; + assignments[token_idx * TOP_K + i] = local_assigned_experts[i]; + offsets[token_idx * TOP_K + i] = + atomicAdd(expert_counts + local_assigned_experts[i], 1); + } + } +} + +template +void launch_top_k_gating(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const T* logits, + const RaggedBatchDescriptor* batch_metadata, + const int32_t n_tokens, + const int32_t n_experts, + const int32_t n_top_k, + cudaStream_t stream) +{ + const dim3 grid(n_tokens); + const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size); + + TOP_K_SWITCH(n_top_k, [&] { + top_k_gating_kernel<<>>( + expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + }); +} + +#define INSTANTIATE_top_k_KERNEL(T) \ + template void launch_top_k_gating(int32_t * expert_counts, \ + float* scores, \ + int32_t* assignments, \ + int32_t* offsets, \ + const T* logits, \ + const RaggedBatchDescriptor* batch_metadata, \ + const int32_t n_tokens, \ + const int32_t n_experts, \ + const int32_t n_top_k, \ + cudaStream_t stream); + +INSTANTIATE_top_k_KERNEL(float) INSTANTIATE_top_k_KERNEL(__half) +#ifdef BF16_AVAILABLE + INSTANTIATE_top_k_KERNEL(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/logging.py b/deepspeed/inference/v2/logging.py new file mode 100644 index 000000000000..77afe351cbea --- /dev/null +++ b/deepspeed/inference/v2/logging.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import logging + +from deepspeed.utils.logging import LoggerFactory + +inf_logger = None + + +def inference_logger(level: int = logging.INFO) -> logging.Logger: + """ + Create the inference logger. NOTE: Logging is not cost free. On a 3960X, + there is a cost of about 6 us per call to a no-op logger, so this should + be used during setup only and not during the inference loop. + + Args: + level (int, optional): The logging level. Defaults to logging.INFO. + """ + global inf_logger + if inf_logger is None: + inf_logger = LoggerFactory.create_logger(name="DS-Inference", level=level) + inf_logger.debug("Inference logger created.") + return inf_logger diff --git a/deepspeed/inference/v2/model_implementations/AddingAModel.md b/deepspeed/inference/v2/model_implementations/AddingAModel.md new file mode 100644 index 000000000000..8fe27297080b --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/AddingAModel.md @@ -0,0 +1,84 @@ +# Adding Support for a New Model in DeepSpeed Inference V2 + +Adding supoprt for a new model in DeepSpeed Inference requires developing three related components: +- Containers: These describe the parameters contained in the model +- Model implementation: How should the model be computed. +- Policy: The map for adding parameters to your containers and creating the model implementation. + +In this tutorial, we will assume that you'd like to use a relatively traditionally styled Transformer model and will be able to inherit from `DSTransformerModelBase` and can take advantage of the utilities that provides. + +## Defining Your Containers + +A container is the bridge between the original model's parameters and how to transform them to serve them for inference. For a model implementation, there are two primary kinds of containers: transformer containers and non-transformer containers. A transformer container consists of the parameters for a single Transformer layer in the model. So this includes your traditional parameters like the projections for the fully connected network, or query-key-value projections. The non-transformer container will contain basically everything else! However, before defining these containers, we need to understand how to define an individual parameter. + +In DeepSpeed inference, the original model parameters are populated into the model and mapped as dependencies to a parameter. A `Parameter` has two primary components: its dependencies and its `finalize` method. Let's do an example. In Llama models, the native format is for the `query`, `key`, and `value` projections to be performed independently. However, we can achieve higher throughput by fusing them into a single larger projection. We can define this fusion with a parameter: + +```python +from deepspeed.inference.module_implementations.parameter_base import ParameterBase + +class UnfusedQKVParameter(ParameterBase): + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + + def finalize(self) -> torch.Tensor: + fused_param = torch.cat([self.query, self.key, self.value], dim=0) + return self.inference_model.transform_qkv_param(fused_param) +``` + +Let's walk through each part of this implementation. First, parameters should inherit from `ParameterBase`. This will allow it to automatically determine when its dependencies are met and set the appropriate components of a parent `LayerContainer`. The second key component is the type annotations on the class itself. Each type annotation represents a dependency of the parameter. Since the original Llama mode has separate query, key, and value dependencies, our fused parameter will declare dependencies for each. Finally, we have the `finalize` method. This method is automatically called once all dependencies on the layer are met and should return the final parameter. + +In this `finalize` method, we are doing two things: the first is the act of fusing the parameters together through the concatenate method. Note that each of the dependencies can be accessed via `self.{name}`. The second is calling `self.inference_model.transform_qkv_param`. A parameter's finalize method always has access to the inference model. In this case we are using that to use a feature provided by `DSTransformerBase`. This method will automatically shard the parameter for tensor parallelism and then pass it to the linear module implementation to perform additional optimizations or shape transformations, like quantization. + +Since many patterns are very common in Transformer models, `model_implementations.common_parameters` provides implementations for many of the patterns (all compatible with `DSTransformerBase`) to help accelerate development. + +Once all parameters are created, we need to compose them into a layer container. In our simplified Llama model, let's assume there's only QKV and attention output projection matrices. A layer container would appear as the following: + +```python +from deepspeed.inference.module_implementations.layer_container_base import LayerContainer + +class ExampleContainer(LayerContainer): + qkvw: UnfusedQKVParameter + + attn_o: AttentionOutputParameter + + PARAM_MAPPING: { + "self_attn.q_proj.weight": "qkvw.query", + "self_attn.k_proj.weight": "qkvw.key", + "self_attn.v_proj.weight": "qkvw.value", + "self_attn.o_proj.weight": "attn_o.params", + } +``` + +Once again, we have a couple of key components. The first are parameter type annotations. Each annotation corresponds to a parameter that can be used in the model implementation. In the model implementation, I can simply write `container.qkvw` to access my fused and transformed QKV parameter. The second key component is the `PARAM_MAPPING` dictionary. This is our explicit mapping of the names of parameters in the source model to a parameter dependency. This mapping dictionary will be used by the policy to automatically populate dependencies. + +Once you have written `LayerContainer`s for both the transformer and non-transformer parameters, it's time to work on the model implementation! + +## Building a Model Implementation that Inherits from `DSTransformerBase` + +By inheriting from `DSTransformerBase`, most of the implementation work for sharding and transforming parameters will be automatically handled for you. However, there are four key tasks that still need to be completed. + +1. Defining the abstract properties based on your model configuration. +2. Configuring embedding and unembedding modules and the forward implementations for them. +3. Configuring the attention configuration and desired KV cache behaviors. +4. Writing the forward implementation for your layer. + +## Writing a Policy + +The `InferenceV2Policy` is the level of composition. This is the object that will be passed directly to the inference engine and will compose the model implementation and your containers to create an end-to-end solution. There are two main components to be implemented: the first is to create the model that you defined earlier. This is done by implementing the `instantiate_model` method of the policy. In general, this can just be implemented by calling the constructor for your model and passing the engine config, tensor-parallel communication object, and your custom model config. + +The second component is to define how the parameters from the checkpoint will map to each container. From the section on `LayerContainer`s above, you may remember that the `LayerContainer` can handle the internal routing of a checkpoint parameter to its dependency. In order to find the correct `LayerContainer` though, we need a second abstraction: the `ContainerMap`. + +A `ContainerMap` performs this mapping by categorizing checkpoint prefix strings to the type of container they map to. Typically, the easiest way to do this is through iterating over a model checkpoint's state dict or by iterating over the `named_parameters` of a PyTorch model. There are three types of mappings to define: the transformer mappings, the non-transformer mappings, and the what we'll call the rest. Let's work through an example: + +```python +from deepspeed.inference.module_implementations.inference_policy_base import ContainerMap + +def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [MyTransformerContainer(self.model) for _ in range(self.model.num_layers)] + map.set_transformer_params("model.layers", transformer_containers) + + non_transformer_container = MyNonTransformerContainer(self.model) +``` diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py new file mode 100644 index 000000000000..d696368e2c25 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .inference_model_base import DSInferenceModelBase +from .inference_transformer_base import DSTransformerModelBase, DSMoETransformerModelBase +from .inference_policy_base import InferenceV2Policy, ContainerMap +from .sharding import * + +# Model Implementations +from .llama_v2 import * +from .opt import * +from .mistral import * +from .mixtral import * +from .falcon import * +from .phi import * +from .phi3 import * +from .qwen import * +from .qwen_v2 import * +from .qwen_v2_moe import * +from .exaone4 import * diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/__init__.py b/deepspeed/inference/v2/model_implementations/common_parameters/__init__.py new file mode 100644 index 000000000000..60963011cd66 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attn_output_parameters import * +from .embedding_parameters import * +from .mlp_parameters import * +from .moe_parameters import * +from .norm_parameters import * +from .qkv_parameters import * +from .unembed_parameters import * +from .invfreq_parameters import * diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/attn_output_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/attn_output_parameters.py new file mode 100644 index 000000000000..f220cf7a7125 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/attn_output_parameters.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Common Attention Output Parameter Patterns +""" + + +class AttentionOutputParameter(ParameterBase): + """ + Attention output parameter container. + + Note: The differentiation for something like GQA for this matrix is primarily + encompassed in the sharding logic, which is currently expected to be performed by + the model implementation. + """ + + params: torch.Tensor + """ + Unsharded attention output parameter of shape [model_dim, model_dim] + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_attn_out_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/embedding_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/embedding_parameters.py new file mode 100644 index 000000000000..2ed34b5fd259 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/embedding_parameters.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Embedding containers. +""" + + +class EmbeddingParameter(ParameterBase): + """ + Embedding container. This should be safe to use for all types of embeddings (i.e. word, position, + and token type). + """ + + params: torch.Tensor + """ + Vocabulary parameter of shape [vocab_size, model_dim]. + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_embedding_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/invfreq_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/invfreq_parameters.py new file mode 100644 index 000000000000..163f9de81d98 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/invfreq_parameters.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Common InvFreq Parameter Patterns +""" + + +class InvFreqParameter(ParameterBase): + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.params.to(self.inference_model.activation_dtype.value) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/mlp_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/mlp_parameters.py new file mode 100644 index 000000000000..17def1fa021f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/mlp_parameters.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +MLP Parameter Containers +""" + + +class MLP1Parameter(ParameterBase): + """ + First MLP projection weight container. This performs a straight pass-through to the + model implementation for transformation. + """ + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + # NOTE(cmikeh2): If we are gated but not in the format specified below, we should trigger a permutation here. + # I am not currently aware of any models that use this format (or how we should even detect it; probably should + # just be a different param entirely, but until then we'll just assume the format is correct). + return self.inference_model.transform_mlp_1_param(self.params) + + +class GatedMLPParameter(ParameterBase): + """ + Gated MLP projection container. + """ + + gate_params: torch.Tensor + """ + Weight parameter for the gating matrix. + """ + + up_params: torch.Tensor + """ + For lack of a better name, the non-gating weight parameters. + """ + + def finalize(self) -> torch.Tensor: + """ + Our gated format (this is different from InferenceV1!) is to have the gate and activated neurons + interleaved. So if we have 4 output neurons (two effective neurons) with 4 input neurons, the finalized + parameter will look like: + [g0_0, g0_1, g0_2, g0_3] + [a0_0, a0_1, a0_2, a0_3] + [g1_0, g1_1, g1_2, g1_3] + [a1_0, a1_1, a1_2, a1_3] + + As a reference, in inference v1, the format is: + [g0_0, g0_1, g0_2, g0_3] + [g1_0, g1_1, g1_2, g1_3] + [a0_0, a0_1, a0_2, a0_3] + [a1_0, a1_1, a1_2, a1_3] + """ + assert self.gate_params.shape[0] == self.up_params.shape[ + 0], "Gated MLP parameters must have the same number of neurons." + total_neurons = self.gate_params.shape[0] + self.up_params.shape[0] + + # flip the order if even with the correct tokenizer we get wrong output + #fused_param = torch.cat([self.up_params, self.gate_params], dim=-1).reshape(total_neurons, -1) + fused_param = torch.cat([self.gate_params, self.up_params], dim=-1).reshape(total_neurons, -1) + return self.inference_model.transform_mlp_1_param(fused_param) + + +class FusedGatedMLPParameter(ParameterBase): + """ + Gated MLP projection container. + """ + + params: torch.Tensor + """ + Weight parameter for the fused gating and non-gating weight parameters. + """ + + def finalize(self) -> torch.Tensor: + gate_params = self.params[:self.params.shape[0] // 2] + up_params = self.params[self.params.shape[0] // 2:] + total_neurons = gate_params.shape[0] + up_params.shape[0] + fused_param = torch.cat([gate_params, up_params], dim=-1).reshape(total_neurons, -1) + return self.inference_model.transform_mlp_1_param(fused_param) + + +class MLP2Parameter(ParameterBase): + """ + Second MLP projection weight container. This performs a straight pass-through to the + model implementation for transformation. + """ + + params: torch.Tensor + """ + Full weight parameter. + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_mlp_2_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py new file mode 100644 index 000000000000..8ababf567ba9 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase, ParamList +""" +Moe Parameters + +These parameters are compatible with any model inheriting from ``DSMoETransformerModelBase``. +""" + + +class MoEGatingWeightParameter(ParameterBase): + """ + Gating weight matrix. + """ + + params: torch.Tensor + """ + Projection matrix from the input activations to the gate logits. + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_moe_gate_param(self.params) + + +class UnfusedMoEMLP1Parameter(ParameterBase): + """ + This container should be used when the experts are held in separate parameters + and need to be joined into a single group. + """ + + experts: ParamList("n_experts") # noqa: F821 + + def finalize(self) -> torch.Tensor: + stacked_experts = torch.stack([p for p in self.experts], dim=0) + return self.inference_model.transform_moe_mlp_1_param(stacked_experts) + + +class UnfusedMoEMLP2Parameter(ParameterBase): + """ + This container should be used when the experts are held in separate parameters + and need to be joined into a single group. + """ + + experts: ParamList("n_experts") # noqa: F821 + + def finalize(self) -> torch.Tensor: + stacked_experts = torch.stack([p for p in self.experts], dim=0) + return self.inference_model.transform_moe_mlp_2_param(stacked_experts) + + +class UnfusedMoEGatedMLPParameter(ParameterBase): + """ + MoE Parameter for a gated activation function in which the gating matrix is not + fused in the same parameter as the non-gating matrix. + + This is a stacked version of the ``GatedMLPParameter``. Please see that class for more + documentation on the layout of the parameters. + """ + + gating_experts: ParamList("n_experts") # noqa: F821 + + up_experts: ParamList("n_experts") # noqa: F821 + + def finalize(self) -> torch.Tensor: + transposed_experts = [] + for gate, up in zip(self.gating_experts, self.up_experts): + assert gate.shape[0] == up.shape[0], "Gated MLP parameters must have the same number of neurons." + total_neurons = gate.shape[0] + up.shape[0] + fused_expert = torch.cat([gate, up], dim=-1).reshape(total_neurons, -1) + transposed_experts.append(fused_expert) + + stacked_experts = torch.stack(transposed_experts, dim=0) + return self.inference_model.transform_moe_mlp_1_param(stacked_experts) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/norm_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/norm_parameters.py new file mode 100644 index 000000000000..81ffcc3221df --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/norm_parameters.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Common Attention Output Parameter Patterns +""" + + +class NormParameter(ParameterBase): + """ + Simple normalization container. + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_norm_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/qkv_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/qkv_parameters.py new file mode 100644 index 000000000000..e240137186fe --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/qkv_parameters.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Common QKV Parameter Patterns +""" + + +class FusedQKVParameter(ParameterBase): + """ + Traditional fused QKV parameters for QKV projection. This is functionally + a direct copy. + + src_qkv_w shape: [3 * out_features, in_features] + qkv_w shape: [3 * out_features, in_features] + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_qkv_param(self.params) + + +class UnfusedQKVParameter(ParameterBase): + """ + QKV parameter container for unfused QKV projection. + + src_param shapes: 3 x [out_features, in_features] + dst_param shape: [3 x out_features, in_features] + """ + + q_params: torch.Tensor + + k_params: torch.Tensor + + v_params: torch.Tensor + + def finalize(self): + fused_param = torch.cat([self.q_params, self.k_params, self.v_params], dim=0) + return self.inference_model.transform_qkv_param(fused_param) + + +def megatron_qkv_reshape(param: torch.Tensor, head_size: int, n_heads: int) -> torch.Tensor: + assert param.shape[0] == 3 * n_heads * head_size + + all_heads = torch.chunk(param, chunks=3 * n_heads, dim=0) + q_heads = all_heads[::3] + k_heads = all_heads[1::3] + v_heads = all_heads[2::3] + return torch.cat([q_heads, k_heads, v_heads], dim=0) + + +class MegatronQKVParameter(ParameterBase): + """ + QKV parameter container for Megatron-style QKV projection. Megatron stores the parameter + as [n_heads, 3, head_size, in_features] whereas our inference system is built around + [3, n_heads, head_size, in_features]. This container handles the conversion. + + Note: this container expects the model implementation to implement properties for + `head_size` and `n_heads`. + + src_qkv_w shape: [3 * out_features, in_features] + qkv_w shape: [3 * out_features, in_features] + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + head_size = self.inference_model.head_size + n_heads = self.inference_model.n_heads + + transposed_param = megatron_qkv_reshape(self.params, head_size, n_heads) + return self.inference_model.transform_qkv_param(transposed_param) + + +def transform_gqa_megatron(src_param: torch.Tensor, head_size: int, n_q_heads: int, n_kv_heads: int) -> torch.Tensor: + assert src_param.shape[0] == (2 * n_kv_heads + n_q_heads) * head_size + + head_ratio = n_q_heads // n_kv_heads + + # Reshape to get the groups as the leading dimension + groups_leading_view = src_param.reshape(n_kv_heads, 2 + head_ratio, head_size, -1) + q_heads = groups_leading_view[:, :head_ratio, :, :].reshape(-1, groups_leading_view.shape[-1]) + k_heads = groups_leading_view[:, head_ratio, :, :].reshape(-1, groups_leading_view.shape[-1]) + v_heads = groups_leading_view[:, head_ratio + 1, :, :].reshape(-1, groups_leading_view.shape[-1]) + # Squeeze will remove extra dimension for bias + return torch.cat([q_heads, k_heads, v_heads], dim=0).squeeze() + + +class GQAMegatronQKVParameter(ParameterBase): + """ + QKV parameter for Megatron-style QKV projection with GQA-style QKV projection. In this + storage format each of the groups is stored consecutively, so there will be multiple q_heads, + then one k head, and one v head. + + Note: this container expects the model implementation to implement properties for + `head_size`, `n_q_heads`, and `n_kv_heads`. + + src_qkv_w shape: [(2 * n_kv_heads + n_q_heads) * head_size, in_features] + qkv_w shape: [(2 * n_kv_heads + n_q_heads) * head_size, in_features] + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + head_size = self.inference_model.head_size + n_q_heads = self.inference_model.n_heads_q + n_kv_heads = self.inference_model.n_heads_kv + transposed_param = transform_gqa_megatron(self.params, head_size, n_q_heads, n_kv_heads) + return self.inference_model.transform_qkv_param(transposed_param) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/unembed_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/unembed_parameters.py new file mode 100644 index 000000000000..9f67c0ce3c27 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/unembed_parameters.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Unembedding containers. +""" + + +class UnembedParameter(ParameterBase): + """ + Unembedding parameter. This will likely be mapped to the same original weight in the model as the + embedding, but we have a different preferred sharding approach. + """ + + params: torch.Tensor + """ + Unembedding parameter of shape [vocab_size, model_dim]. + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_unembed_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/exaone4/__init__.py b/deepspeed/inference/v2/model_implementations/exaone4/__init__.py new file mode 100644 index 000000000000..bf98a6656e36 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/exaone4/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team + +from .policy import Exaone4Policy diff --git a/deepspeed/inference/v2/model_implementations/exaone4/container.py b/deepspeed/inference/v2/model_implementations/exaone4/container.py new file mode 100644 index 000000000000..ae5ac0b7c7f8 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/exaone4/container.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team + +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class Exaone4TransformerContainer(LayerContainer): + """ + Transformer layer container for the EXAONE 4.0 model. + """ + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + q_norm_gamma: NormParameter + k_norm_gamma: NormParameter + post_attn_norm_gamma: NormParameter + post_ff_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_proj.weight": "mlp_1_w.gate_params", + "mlp.up_proj.weight": "mlp_1_w.up_params", + "mlp.down_proj.weight": "mlp_2_w.params", + "self_attn.q_norm.weight": "q_norm_gamma.params", + "self_attn.k_norm.weight": "k_norm_gamma.params", + "post_attention_layernorm.weight": "post_attn_norm_gamma.params", + "post_feedforward_layernorm.weight": "post_ff_norm_gamma.params", + } + + +class Exaone4NonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the EXAONE 4.0 model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/exaone4/model.py b/deepspeed/inference/v2/model_implementations/exaone4/model.py new file mode 100644 index 000000000000..b835dc137fd8 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/exaone4/model.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper +from ...kernels.core_ops.cuda_rms_norm.rms_norm import CUDARMSNorm + +from .container import Exaone4NonTransformerContainer, Exaone4TransformerContainer + + +class Exaone4InferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for EXAONE 4.0 models. + + Key differences from Mistral/Llama: + - Post-norm architecture (norm after attn/mlp, not before) + - QK-Norm (RMSNorm on Q and K projections per head) + """ + + _non_transformer: Optional[Exaone4NonTransformerContainer] + _transformer: Optional[Iterable[Exaone4TransformerContainer]] + + @property + def max_sequence_length(self) -> int: + return self._config.max_position_embeddings + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return getattr(self._config, "head_dim", self.model_dim // self.n_heads) + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "silu": + return ActivationType.SiGLU + elif activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + rope_theta = getattr(self._config, "rope_theta", 1000000.0) + return RotateHalfConfig(theta_base=rope_theta) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._qk_norm = CUDARMSNorm( + channels=self.head_size, + fp_dtype=torch.float16 if self.activation_dtype == DtypeEnum.fp16 else torch.bfloat16, + epsilon=getattr(self._config, "rms_norm_eps", 1e-5), + ) + + def _apply_qk_norm(self, hidden_states: torch.Tensor, q_norm_gamma: torch.Tensor, + k_norm_gamma: torch.Tensor) -> torch.Tensor: + """ + Apply RMSNorm to Q and K projections independently per head. + hidden_states shape: [tokens, (n_q + n_kv + n_kv) * head_size] + """ + tokens = hidden_states.shape[0] + local_n_heads = self.n_heads_q_local + local_n_heads_kv = self.n_heads_kv_local + q_len = local_n_heads * self.head_size + kv_len = local_n_heads_kv * self.head_size + + q = hidden_states[:, :q_len].contiguous() + k = hidden_states[:, q_len:q_len + kv_len].contiguous() + v = hidden_states[:, q_len + kv_len:] + + # Reshape to [tokens * n_heads, head_size] for per-head RMSNorm + q = q.view(-1, self.head_size) + self._qk_norm(q, q, q_norm_gamma) + q = q.view(tokens, q_len) + + k = k.view(-1, self.head_size) + self._qk_norm(k, k, k_norm_gamma) + k = k.view(tokens, kv_len) + + hidden_states[:, :q_len] = q + hidden_states[:, q_len:q_len + kv_len] = k + + return hidden_states + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + EXAONE 4.0 uses post-norm architecture: + hidden = attn(hidden) + hidden = post_attn_norm(hidden) + residual = residual + hidden + hidden = mlp(residual) + hidden = post_ff_norm(hidden) + residual = residual + hidden + """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + # Attention block + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) + hidden_states = self._apply_qk_norm(hidden_states, cur_params.q_norm_gamma, cur_params.k_norm_gamma) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + # Post-attn norm + residual add + _, hidden_states = self.norm(hidden_states, None, cur_params.post_attn_norm_gamma, beta=None) + residual.add_(hidden_states) + + # MLP block + hidden_states = self.mlp_1(residual, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + # Post-ff norm + residual add + _, hidden_states = self.norm(hidden_states, None, cur_params.post_ff_norm_gamma, beta=None) + residual.add_(hidden_states) + + return residual, residual + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, residual, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/exaone4/policy.py b/deepspeed/inference/v2/model_implementations/exaone4/policy.py new file mode 100644 index 000000000000..891c5f1f321f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/exaone4/policy.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Exaone4NonTransformerContainer, Exaone4TransformerContainer +from .model import Exaone4InferenceModel + + +class Exaone4Policy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Exaone4InferenceModel: + return Exaone4InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Exaone4TransformerContainer(self.model) for _ in range(self.model.num_layers)] + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Exaone4NonTransformerContainer(self.model)) + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/falcon/__init__.py b/deepspeed/inference/v2/model_implementations/falcon/__init__.py new file mode 100644 index 000000000000..20f37538274c --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import FalconPolicy diff --git a/deepspeed/inference/v2/model_implementations/falcon/container.py b/deepspeed/inference/v2/model_implementations/falcon/container.py new file mode 100644 index 000000000000..caccfe1ecb00 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/container.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Falcon 7b model looks like this: + +FalconForCausalLM( + (transformer): FalconModel( + (word_embeddings): Embedding(65024, 4544) + (h): ModuleList( + (0-31): 32 x FalconDecoderLayer( + (self_attention): FalconAttention( + (maybe_rotary): FalconRotaryEmbedding() + (query_key_value): FalconLinear(in_features=4544, out_features=4672, bias=False) + (dense): FalconLinear(in_features=4544, out_features=4544, bias=False) + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (mlp): FalconMLP( + (dense_h_to_4h): FalconLinear(in_features=4544, out_features=18176, bias=False) + (act): GELU(approximate='none') + (dense_4h_to_h): FalconLinear(in_features=18176, out_features=4544, bias=False) + ) + (input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True) + ) + ) + (ln_f): LayerNorm((4544,), eps=1e-05, elementwise_affine=True) + ) + (lm_head): Linear(in_features=4544, out_features=65024, bias=False) +) +''' + + +class FalconTransformerContainer(LayerContainer): + """ + Transformer layer container for the Falcon model. + """ + qkv_w: FusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_2_w: MLP2Parameter + ln_attn_gamma: NormParameter + ln_attn_beta: NormParameter + + PARAM_MAPPING = { + "self_attention.query_key_value.weight": "qkv_w.params", + "self_attention.dense.weight": "attn_out_w.params", + "mlp.dense_h_to_4h.weight": "mlp_1_w.params", + "mlp.dense_4h_to_h.weight": "mlp_2_w.params", + "input_layernorm.weight": "ln_attn_gamma.params", + "input_layernorm.bias": "ln_attn_beta.params", + } + + +class FalconNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Falcon model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm_gamma: NormParameter + final_norm_beta: NormParameter + + PARAM_MAPPING = { + "transformer.word_embeddings.weight": "word_emb.params", + "transformer.ln_f.weight": "final_norm_gamma.params", + "transformer.ln_f.bias": "final_norm_beta.params", + "lm_head.weight": "word_unembed.params", + } + + +''' + # HF Falcon 40b model looks like this: + + FalconForCausalLM( + (transformer): FalconModel( + (word_embeddings): Embedding(65024, 8192) + (h): ModuleList( + (0-59): 60 x FalconDecoderLayer( + (self_attention): FalconAttention( + (maybe_rotary): FalconRotaryEmbedding() + (query_key_value): FalconLinear(in_features=8192, out_features=9216, bias=False) + (dense): FalconLinear(in_features=8192, out_features=8192, bias=False) + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (mlp): FalconMLP( + (dense_h_to_4h): FalconLinear(in_features=8192, out_features=32768, bias=False) + (act): GELU(approximate='none') + (dense_4h_to_h): FalconLinear(in_features=32768, out_features=8192, bias=False) + ) + (ln_attn): LayerNorm((8192,), eps=1e-05, elementwise_affine=True) + (ln_mlp): LayerNorm((8192,), eps=1e-05, elementwise_affine=True) + ) + ) + (ln_f): LayerNorm((8192,), eps=1e-05, elementwise_affine=True) + ) + (lm_head): Linear(in_features=8192, out_features=65024, bias=False) +) +''' + + +class FalconNewArchTransformerContainer(LayerContainer): + """ + Transformer layer container for the Falcon model. + """ + qkv_w: GQAMegatronQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_2_w: MLP2Parameter + ln_attn_gamma: NormParameter + ln_attn_beta: NormParameter + ln_mlp_gamma: NormParameter + ln_mlp_beta: NormParameter + + PARAM_MAPPING = { + "self_attention.query_key_value.weight": "qkv_w.params", + "self_attention.dense.weight": "attn_out_w.params", + "mlp.dense_h_to_4h.weight": "mlp_1_w.params", + "mlp.dense_4h_to_h.weight": "mlp_2_w.params", + "ln_attn.weight": "ln_attn_gamma.params", + "ln_attn.bias": "ln_attn_beta.params", + "ln_mlp.weight": "ln_mlp_gamma.params", + "ln_mlp.bias": "ln_mlp_beta.params", + } diff --git a/deepspeed/inference/v2/model_implementations/falcon/model.py b/deepspeed/inference/v2/model_implementations/falcon/model.py new file mode 100644 index 000000000000..b2830c80b562 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/model.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .container import FalconNonTransformerContainer, FalconTransformerContainer + + +class FalconInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[FalconNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[FalconTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return 4 * self._config.hidden_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_kv_heads if (self._config.new_decoder_architecture + or not self._config.multi_query) else 1 + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.GELU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> RotateHalfConfig: + """ + The positional embedding configuration for the model. + """ + return RotateHalfConfig() + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + assert self.config.parallel_attn, "Only parallel attention implementation is supported" + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + attn_ln_out = hidden_states + attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=None) + attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info) + attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=None) + + if self.config.new_decoder_architecture: + residual, mlp_ln_out = self.norm(residual, + None, + gamma=cur_params.ln_mlp_gamma, + beta=cur_params.ln_mlp_beta) + else: + mlp_ln_out = hidden_states + + mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=None) + mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=None) + + mlp_output.add_(attention_output) + + if self.tp_size > 1: + dist.all_reduce(mlp_output, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, mlp_output = self.norm(residual, + mlp_output, + next_params.ln_attn_gamma, + beta=next_params.ln_attn_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(mlp_output) + + return residual, mlp_output + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].ln_attn_gamma, + beta=self._transformer[0].ln_attn_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/falcon/policy.py b/deepspeed/inference/v2/model_implementations/falcon/policy.py new file mode 100644 index 000000000000..c6612090a0df --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/policy.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNewArchTransformerContainer +from .model import FalconInferenceModel + + +class FalconPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> FalconInferenceModel: + return FalconInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + trans_container_cls = FalconNewArchTransformerContainer if self._model_config.new_decoder_architecture else FalconTransformerContainer + transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['transformer.h'], transformer_containers) + + map.set_non_transformer_params(FalconNonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py new file mode 100644 index 000000000000..c5e02adaffc4 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -0,0 +1,282 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Dict, Iterable, Tuple, Optional +from os import path + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import RaggedUtilsBuilder +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .layer_container_base import LayerContainer +from ..inference_parameter import InferenceParameter, STR_TO_DTYPE +from ..inference_utils import elem_size + + +def pad_to_aligned_offset(offset: int, alignment: int = 256) -> int: + """ + Pad the provided offset to a well-aligned value. + """ + return ((offset + alignment - 1) // alignment) * alignment + + +class TensorMetadata(DeepSpeedConfigModel): + """ + A class to represent a tensor specification. + """ + dtype: Optional[str] = None + shape: Optional[Tuple[int, ...]] = None + strides: Optional[Tuple[int, ...]] = None + offset: int + + +class ParameterMetadata(DeepSpeedConfigModel): + """ + A class to represent a parameter specification. + """ + core_param: Optional[TensorMetadata] = None + aux_params: Dict[str, TensorMetadata] = {} + + +class LayerMetadata(DeepSpeedConfigModel): + """ + A class to represent a layer specification. + """ + params: Dict[str, ParameterMetadata] = {} + + +class ModelMetadata(DeepSpeedConfigModel): + """ + A class to represent a model specification. + """ + policy: str = "" + layers: Dict[str, LayerMetadata] = {} + + +def make_param_filename(base: str, rank: int, n_ranks: int) -> str: + """ + Make a filename for a parameter file. + + Arguments: + rank: Rank of the file. + n_ranks: Total number of ranks. + + Returns: + str: Filename. + """ + return path.join(base, f"params_rank_{rank}_of_{n_ranks}.pt") + + +def make_metadata_filename(base: str, rank: int, n_ranks: int) -> str: + """ + Make a filename for a metadata file. + + Arguments: + rank: Rank of the file. + n_ranks: Total number of ranks. + + Returns: + str: Filename. + """ + return path.join(base, f"metadata_rank_{rank}_of_{n_ranks}.json") + + +def make_model_config_filename(base: str) -> str: + """ + Make a filename for a model config file. + + Arguments: + base: Base directory. + + Returns: + str: Filename. + """ + return path.join(base, "ds_model_config.json") + + +def flatten_inference_model( + transformer_containers: Iterable[LayerContainer], + non_transformer_container: LayerContainer, + policy_name: str, +) -> Tuple[torch.Tensor, ModelMetadata]: + """ + Flatten the underlying parameters into + + Arguments: + transformer_containers: Iterable of layer containers corresponding to the transformer + parameters. + non_transformer_container: Layer container corresponding to the non-transformer parameters. + policy_name: The name of the policy class (typically accessed with `type(policy).__name__`). + + Returns: + Iterable[Any]: Flattened list of parameters. + """ + alloc_fn = RaggedUtilsBuilder().load().allocate_view_on + + total_size = 0 + metadata = ModelMetadata(policy=policy_name) + + def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) -> int: + """ + Iterate over the parameters of a single container and collect metadata for the final + flattened buffer. + + Arguments: + layer_container: The layer container to process. + l_name: The name of the layer container to key the metadata. + cur_offset: The current offset into the flattened buffer. + + Captured Variables: + metadata: The metadata object to populate. + + Returns: + int: The updated offset into the flattened buffer. + """ + try: + _ = layer_container.is_populated + except ValueError as e: + raise ValueError(f"Layer container {l_name} is not populated.") from e + + layer_metadata = LayerMetadata() + + for p_name in layer_container.annotation_attrs: + param = getattr(layer_container, p_name) + param_metadata = ParameterMetadata() + + if param is None: + param_metadata.core_param = TensorMetadata(offset=-1) + layer_metadata.params[p_name] = param_metadata + continue + + param_metadata.core_param = TensorMetadata(dtype=str(param.dtype), + shape=param.shape, + strides=param.stride(), + offset=cur_offset) + + cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) + + for t_name, tensor in param.aux_attrs.items(): + param_metadata.aux_params[t_name] = TensorMetadata(dtype=str(tensor.dtype), + shape=tensor.shape, + strides=tensor.stride(), + offset=cur_offset) + + cur_offset += pad_to_aligned_offset(elem_size(tensor.dtype) * tensor.numel()) + + layer_metadata.params[p_name] = param_metadata + + metadata.layers[l_name] = layer_metadata + return cur_offset + + for i, layer in enumerate(transformer_containers): + l_name = f"transformer_layer_{i}" + total_size = process_layer(layer, l_name, total_size) + + l_name = "non_transformer" + total_size = process_layer(non_transformer_container, l_name, total_size) + + buffer = torch.empty(total_size, dtype=torch.uint8, device=get_accelerator().current_device()) + + def copy_layer(layer_container: LayerContainer, l_name: str) -> None: + """ + Local method for copying from the layer container to the flattened buffer. + + Arguments: + layer_container: The layer container to copy from. + l_name: The name of the layer container to key the metadata. + + Captured Variables: + buffer: The flattened buffer to copy into. + metadata: The metadata object to populate. + """ + l_metadata = metadata.layers[l_name] + for p_name in layer_container.annotation_attrs: + p_metadata = l_metadata.params[p_name] + param = getattr(layer_container, p_name) + + if param is None: + continue + + core_param = alloc_fn(param, buffer, p_metadata.core_param.offset) + core_param.copy_(param) + + aux_params = {} + + for t_name, tensor in param.aux_attrs.items(): + t_view = alloc_fn(tensor, buffer, p_metadata.aux_params[t_name].offset) + aux_params[t_name] = t_view + t_view.copy_(tensor) + + setattr(layer_container, p_name, InferenceParameter.initialize(core_param, **aux_params)) + + for i, layer in enumerate(transformer_containers): + l_name = f"transformer_layer_{i}" + copy_layer(layer, l_name) + + l_name = "non_transformer" + copy_layer(non_transformer_container, l_name) + + return buffer, metadata + + +def restore_inference_model(buffer: torch.Tensor, metadata: ModelMetadata, + transformer_containers: Iterable[LayerContainer], + non_transformer_container: LayerContainer) -> None: + """ + Restore the model from the buffer and metadata. + + Arguments: + buffer: Buffer containing the model parameters. + metadata: Metadata for the model. + transformer_containers: Iterable of transformer layer containers. + non_transformer_container: Non-transformer layer container. + """ + alloc_fn = RaggedUtilsBuilder().load().allocate_view_like + + def restore_layer(layer_container: LayerContainer, l_name: str) -> None: + """ + Local method for restoring a layer container from a flattened buffer. This + only constructs views for the parameters onto the buffer. No data movement + is performed. + + Arguments: + layer_container: The layer container to restore. + l_name: The name of the layer container to key the metadata. + + Captured Variables: + buffer: The flattened buffer to reconstruct views on top of. + metadata: The metadata object describing the each parameter in the model. + """ + l_metadata = metadata.layers[l_name] + + for p_name in layer_container.annotation_attrs: + p_metadata = l_metadata.params[p_name] + + if p_metadata.core_param.offset == -1: + layer_container.direct_injection(p_name, None) + continue + + dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype]) + core_param = alloc_fn(p_metadata.core_param.shape, p_metadata.core_param.strides, dummy_tensor, buffer, + p_metadata.core_param.offset) + + aux_params = {} + + for t_name, t_metadata in p_metadata.aux_params.items(): + dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[t_metadata.dtype]) + t_view = alloc_fn(t_metadata.shape, t_metadata.strides, dummy_tensor, buffer, t_metadata.offset) + + aux_params[t_name] = t_view + + restored_param = InferenceParameter.initialize(core_param, **aux_params) + layer_container.direct_injection(p_name, restored_param) + + for i, layer in enumerate(transformer_containers): + l_name = f"transformer_layer_{i}" + restore_layer(layer, l_name) + + l_name = "non_transformer" + restore_layer(non_transformer_container, l_name) diff --git a/deepspeed/inference/v2/model_implementations/inference_model_base.py b/deepspeed/inference/v2/model_implementations/inference_model_base.py new file mode 100644 index 000000000000..894a4137407e --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/inference_model_base.py @@ -0,0 +1,272 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod +from typing import Iterable, Optional, Tuple, Type + +import torch + +import deepspeed.comm as dist +from ..ragged import DSStateManager, RaggedBatchWrapper +from ..ragged.manager_configs import KVCacheConfig +from ..ragged import DSSequenceDescriptor +from ..model_implementations.layer_container_base import LayerContainer +from ..config_v2 import RaggedInferenceEngineConfig +from .flat_model_helpers import ModelMetadata + +try: + from functools import cached_property +except ImportError: + + def cached_property(func): + return property(func) + + +""" +This abstract class defines the interfaces that a model implementation should implement +in order to include anything that may be called by the engine. Most models should be able +to inherit from `DSInferenceTransformerModelBase` to reduce implementation work so it is recommended +to begin there. +""" +""" +Placeholder for typing the model config, which can vary based on model implementation/ +""" +DSModelImplementationConfig = Type['DSModelImplementationConfig'] +""" +Placeholder for typing the distributed comm object. + +TODO(cmikeh2): Replace when we have a more defined API for the inference communication system. +""" +MPType = Type["MPType"] + + +class DSInferenceModelBase(torch.nn.Module, ABC): + """ + Implementation of a model for inference composable with ragged batching. + """ + + _config: DSModelImplementationConfig + """ + Model-specific configuration. No abstraction surrounds this yet. + """ + + _engine_config: RaggedInferenceEngineConfig + """ + Engine configuration. + """ + + _base_mp_group: MPType + """ + Base communication group for Tensor-parallel inference. + """ + + _non_transformer: Optional[LayerContainer] + """ + Abstract container for storing both embedding (pre-transformer) and unembedding (post-transformer) + parameters. This attribute should be None at model instantiation until the Policy sets + the model parameters. These parameters are grouped together since many model implementations + will tie the embedding and unembedding parameters together. + """ + + _transformer: Optional[Iterable[LayerContainer]] + """ + List of abstract containers (1 per layer) for storing transformer (transformer) + parameters. This attribute should be None at model instantiation until the Policy + sets the model parameters. + """ + + state_manager: Optional[DSStateManager] + """ + Since the state manager is lazy initialized, by the engine, it is not guaranteed to be present + until full initialization. + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Minimal initialization of the model. + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__() + self._config = config + self._engine_config = engine_config + self._base_mp_group = base_mp_group + + # Set to None until the Policy sets the model parameters + self._non_transformer = None + self._transformer = None + self._flattened_param_buffer = None + self._flattened_param_metadata = None + + @property + def config(self) -> DSModelImplementationConfig: + """ + The model config. + """ + return self._config + + def set_parameters(self, transformer: Iterable[LayerContainer], non_transformer: LayerContainer, + flattened_param_buffer: torch.Tensor, flattened_param_metadata: ModelMetadata): + """ + Set the model parameters for the embedding, transformer, and unembedding containers. + """ + self._transformer = transformer + self._non_transformer = non_transformer + self._flattened_param_buffer = flattened_param_buffer + self._flattened_param_metadata = flattened_param_metadata + + def set_state_manager(self, state_manager: DSStateManager): + """ + Sets the state manager attribute. This is called by the inference engine after + the model is fully initialized. + """ + self.state_manager = state_manager + + @cached_property + def tp_rank(self) -> int: + """ + The rank of the current process. + + # TODO(cmikeh2): Kind of a hack right now, but this is too verbose to use at + the frequency we need. + """ + return dist.get_rank(group=self._base_mp_group) + + @cached_property + def tp_size(self) -> int: + """ + The total number of processes. + + # TODO(cmikeh2): Kind of a hack right now, but this is too verbose to use at + the frequency we need. + """ + return dist.get_world_size(group=self._base_mp_group) + + @property + def model_config(self): + """ + The model config. + """ + return self._config + + @property + def engine_config(self): + """ + The engine config. + """ + return self._engine_config + + @property + def flattened_params(self) -> Optional[torch.Tensor]: + """ + The flattened parameter buffer. + """ + return self._flattened_param_buffer + + @property + def flattened_param_metadata(self) -> Optional[ModelMetadata]: + """ + The flattened parameter metadata. + """ + return self._flattened_param_metadata + + @abstractmethod + def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: int, + max_new_blocks: Tuple[int, ...]) -> Tuple[int, torch.Tensor]: + """ + Given a sequence and the number of new tokens in the sequence, determine the + number of new KV blocks needed to support the sequence. This method is + used to help the engine provide schedulability APIs and can be used as a helper + for ``maybe_allocate_kv``. + + Args: + sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage. + max_new_tokens (int): Maximum number of tokens to hypothetically schedule. + max_new_blocks (int): Maximum number of blocks to hypothetically allocate. + + Returns: + Tuple[int, torch.Tensor]: The tuple of number of tokens scheduled and number + of blocks allocated (per KV cache). In general, only one of these numbers will + match the corresponding input argument, but this is not guaranteed. + """ + raise NotImplementedError() + + @abstractmethod + def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int: + raise NotImplementedError() + + @abstractmethod + def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None: + """ + Given a sequence and the number of new tokens in the sequence, determine + whether or not additional KV-storage is needed and allocate it if so. + + Args: + sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage. + n_new_tokens (int): The number of new tokens in the sequence. + """ + raise NotImplementedError() + + @abstractmethod + def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]: + """ + Return the KV-cache configuration for this model. This should be a tuple of one or more + KVCacheConfig objects (one for each distinct cache group). + """ + raise NotImplementedError() + + @property + @abstractmethod + def max_sequence_length(self) -> int: + """ + The maximum sequence length supported by the model. + """ + ... + + def maybe_free_kv(self, sequence: DSSequenceDescriptor) -> None: + """ + After completing a forward pass, determine whether or not the there are any KV blocks + that maybe freed since they are no longer in use. + + Consider the following example: + + We have a block size of 4 and a local window size of 8. At the beginning of the forward + pass there 10 tokens had been seen and the new forward has a size of 4. This would lend + itself to the following cache structure prior to the forward: + [[0, 1, 2*, 3*] [4*, 5*, 6*, 7*] [8*, 9*, x, x] [x x x x]] + Where x's denote empty cache locations and * denote values that are needed for attention + of the next open slot. After the forward, the cache would look like the following: + [[0, 1, 2, 3] [4, 5, 6*, 7*] [8*, 9*, 10*, 11*] [12* 13* x x]] + In this case, the first block is no longer needed since it is not needed for any future + local attention windows. This function would be responsible for freeing that block. + + Default behavior assumes no local patterns that require freeing and in general should + be sufficient. + """ + pass + + @abstractmethod + def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None: + """ + This will be called before each forward with the intent of building forward-specific metadata + about a batch. The intent here is to build data structures like attention atoms without necessarily + needing to implement graphable kernels to do so. + + Abstract so as to force model implementations to opt out of doing anything here explicitly. + """ + raise NotImplementedError() + + def forward(wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Complete a forward pass of the model. This interface should be graphable, so it + should not rely on the ability to use python control flow. + """ + raise NotImplementedError() diff --git a/deepspeed/inference/v2/model_implementations/inference_policy_base.py b/deepspeed/inference/v2/model_implementations/inference_policy_base.py new file mode 100644 index 000000000000..2f4266a8cb88 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/inference_policy_base.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import json +from abc import ABC, ABCMeta, abstractmethod +from typing import Any, Iterable, List, Optional, Union + +import torch + +from ..config_v2 import RaggedInferenceEngineConfig +from ..checkpoint import CheckpointEngineBase +from ..logging import inference_logger +from .layer_container_base import LayerContainer +from .inference_model_base import DSInferenceModelBase +from .flat_model_helpers import ( + flatten_inference_model, + make_param_filename, + make_metadata_filename, + ModelMetadata, + restore_inference_model, +) + +POLICIES = {} + + +class ContainerMap: + + def __init__(self) -> None: + self._prefix_map = {} + self._transformer_params = None + self._non_transformer_params = None + + @property + def transformer_params(self) -> Iterable[LayerContainer]: + return self._transformer_params + + @property + def non_transformer_params(self) -> LayerContainer: + return self._non_transformer_params + + def set_transformer_params(self, prefixes: Union[str, Iterable[str]], containers: List[LayerContainer]) -> None: + if not isinstance(containers, list): + raise ValueError( + f"The transformer containers should be a list, of one container per layer, but got {type(containers)} instead." + ) + + self._transformer_prefixes = prefixes if isinstance(prefixes, list) else [prefixes] + self._transformer_params = containers + + def set_non_transformer_params(self, container: LayerContainer) -> None: + self._non_transformer_params = container + + def set_unmapped_params(self, prefixes: Union[str, Iterable[str]]) -> None: + self._unmapped_prefixes = prefixes + + def map_param(self, name, parameter) -> None: + for unmapped_prefix in self._unmapped_prefixes: + if name.startswith(unmapped_prefix): + inference_logger().debug(f"Ignoring: {name} for {unmapped_prefix}") + return + + for transformer_prefix in self._transformer_prefixes: + if name.startswith(transformer_prefix): + popped_name = name[len(transformer_prefix) + 1:] + layer_idx = popped_name.split(".")[0] + assert layer_idx.isdigit( + ), f"expected name to start w. list index but got {layer_idx} instead, name={name}" + layer_idx = int(layer_idx) + inference_logger().debug( + f"Setting: {'.'.join(popped_name.split('.')[1:])} for layer-idx={layer_idx} to {parameter.shape}") + self._transformer_params[layer_idx].set_dependency(".".join(popped_name.split(".")[1:]), parameter) + return + + try: + inference_logger().debug(f"Setting: {name} to {parameter.shape}") + self._non_transformer_params.set_dependency(name, parameter) + except ValueError: + # Catch the ValueError here from the non_transformer_params because we are knowingly + # calling it with something that may not match. This should allow us to raise a slightly more + # informative error message. + raise ValueError(f"Cannot find container for {name}, please double check the Containers/ContainerMap") + + def validate(self) -> None: + if not self._non_transformer_params.is_initialized: + raise RuntimeError("Non-transformer parameters not fully initialized after checkpoint load.") + + for layer_idx, container in enumerate(self._transformer_params): + if not container.is_initialized: + raise RuntimeError( + f"Transformer container at index {layer_idx} not fully initialized after checkpoint load.") + + +class PolicyMeta(ABCMeta): + + def __new__(cls, name, bases, dct): + new_obj = super().__new__(cls, name, bases, dct) + if name != "InferenceV2Policy": + POLICIES[name] = new_obj + return new_obj + + +class InferenceV2Policy(ABC, metaclass=PolicyMeta): + """ + The InferenceV2Policy is the base class for all inference policies. An inference policy + is responsible for instantiating the inference model and mapping the parameters from the + checkpoint engine to the model itself. + """ + + def __init__( + self, + model_config: Any, + checkpoint_engine: Optional[CheckpointEngineBase] = None, + inf_checkpoint_path: Optional[str] = None, + ) -> None: + """ + Create the Policy with sufficient context to build the model. There are two supported + model creation mechanisms. + + The first is the generalized ``checkpoint_engine`` which + will iterate over the parameters of the model and provide them to the policy. These in + turn will be sharded/transformed by the model implementation. + + The second is used to re-create a previously serialized DeepSpeed inference model. These + checkpoints should not be used across different model backend configurations. + + TODO(cmikeh2): Enforce this in code + """ + if checkpoint_engine is None and inf_checkpoint_path is None: + raise ValueError("Either checkpoint_engine or ds_checkpoint_path must be provided.") + + if checkpoint_engine is not None and inf_checkpoint_path is not None: + raise ValueError("Only one of checkpoint_engine or ds_checkpoint_path can be provided.") + + self._checkpoint_engine = checkpoint_engine + self._inf_checkpoint_path = inf_checkpoint_path + self._model_config = model_config + + def build_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> DSInferenceModelBase: + """ + Completely instantiate the inference model. This will both create the ops needed to run the + model, as well as load the model parameters via the checkpoint engine. For more context + on each of these components please see ``instantiate_model`` and ``populate_model_parameters``. + + Arguments: + engine_config: The config that has been used to instantiate the engine. This is used + to communicate to the model implementation the limits on batches (sequences/tokens) + and bound the size of intermediate buffers. + mp_group: Object to enable communication between tensor parallel ranks. + + Returns: + DSInferenceModelBase: An implementation of the inference model abstraction that will be + run by the engine. + """ + self.model = self.instantiate_model(engine_config, mp_group) + self.populate_model_parameters() + return self.model + + @abstractmethod + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig) -> DSInferenceModelBase: + """ + Instantiate the inference model. Depending on the engine/model config, this could be where + different model implementations could be selected. + + Arguments: + engine_config: The config that has been used to instantiate the engine. This is used + to communicate to the model implementation the limits on batches (sequences/tokens) + and bound the size of intermediate buffers. + + Returns: + DSInferenceModelBase: An implementation of the inference model abstraction that will be + run by the engine. + """ + ... + + @abstractmethod + def build_container_map(self) -> ContainerMap: + """ + Build a dictionary representing the structure of the string prefixes leading + to the parameters to be mapped to the container. + + Returns: + ContainerMap: An instantiated mapping describing how checkpoint prefixes map + to ``LayerContainer`` instances. + """ + raise NotImplementedError() + + def populate_model_parameters(self) -> None: + """ + This model will iterate over the parameters (as provided by the checkpoint engine) and + use the container map built by ``build_container_map`` to populate the model + """ + + container_map = self.build_container_map() + + if self._checkpoint_engine is not None: + for name, parameter in self._checkpoint_engine.parameters(): + container_map.map_param(name, parameter) + + buffer, metadata = flatten_inference_model(container_map.transformer_params, + container_map.non_transformer_params, self.__class__.__name__) + else: + + buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size) + metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size) + + buffer = torch.load(buffer_path, weights_only=False) + metadata = json.load(open(metadata_path, "r")) + metadata = ModelMetadata.parse_raw(metadata) + + restore_inference_model(buffer, metadata, container_map.transformer_params, + container_map.non_transformer_params) + + container_map.validate() + + self.model.set_parameters(transformer=container_map.transformer_params, + non_transformer=container_map.non_transformer_params, + flattened_param_buffer=buffer, + flattened_param_metadata=metadata) diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py new file mode 100644 index 000000000000..fae67dc8fc2a --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -0,0 +1,617 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ..config_v2 import RaggedInferenceEngineConfig +from ..inference_utils import ActivationType, ceil_div, is_gated +from ..model_implementations import * +from ..model_implementations.sharding import * +from ..modules.configs import ( + DSEmbeddingsConfig, + DSLinearConfig, + DSMoEConfig, + DSNormConfig, + DSSelfAttentionConfig, + DSUnembedConfig, + NormTypeEnum, + PositionalEmbeddingType, + RotateHalfConfig, +) +from ..modules import heuristics +from ..ragged import ( + DSSequenceDescriptor, + KVCacheConfig, + RaggedBatchWrapper, +) +from .inference_model_base import ( + DSInferenceModelBase, + DSModelImplementationConfig, + MPType, +) +from ..inference_parameter import InferenceParameter + +try: + from functools import cached_property +except ImportError: + + def cached_property(func): + return property(func) + + +class DSTransformerModelBase(DSInferenceModelBase): + """ + Dimensioning properties + """ + + @property + @abstractmethod + def num_layers(self) -> int: + """ + Number of the layers in the model + """ + ... + + @property + @abstractmethod + def model_dim(self) -> int: + """ + Size of embedding projection and residuals. + """ + ... + + @property + @abstractmethod + def vocab_size(self) -> int: + """ + Size of the vocabulary (including padding). + """ + ... + + @property + @abstractmethod + def head_size(self) -> int: + """ + Size of each attention head. + """ + ... + + @property + @abstractmethod + def n_heads(self) -> int: + """ + The number of query heads on the model. This should not take into account + any dimension reductions from model sharding. + """ + ... + + @property + def n_heads_q(self) -> int: + """ + Alias to n_heads. + """ + return self.n_heads + + @property + def n_heads_kv(self) -> int: + """ + The number of key and value heads on the model. For GQA or MQA, overload this attribute. + Otherwise it adopts MHA formulations and uses n_heads. This should not take into account + any dimension reductions from model sharding. + """ + return self.n_heads + + @property + @abstractmethod + def intermediate_dim(self) -> int: + """ + The size of the (unsharded) intermediate projection dim. For a gated activation function + this is the size of the input to the second MLP layer. This should not take into account + any dimension reductions from model sharding. + """ + ... + + @property + @abstractmethod + def positional_embedding_type(self) -> PositionalEmbeddingType: + """ + The type of positional embedding used by the model. + """ + ... + + """ + Architectural properties + """ + + @property + @abstractmethod + def activation_dtype(self) -> torch.dtype: + """ + The activation dtype of the model. + """ + ... + + @property + @abstractmethod + def mlp_activation_fn(self) -> ActivationType: + """ + The activation function used in the MLP. + """ + ... + + @property + @abstractmethod + def norm_type(self) -> NormTypeEnum: + """ + The type of normalization used in the model. + """ + ... + + @property + @abstractmethod + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + """ + The positional embedding configuration for the model. + """ + ... + + """ + Derived helpers + """ + + @cached_property + def n_heads_q_local(self) -> int: + """ + Number of local heads post sharding. + """ + return get_local_heads(self.tp_rank, self.tp_size, self.n_heads_q, self.n_heads_kv)[0] + + @cached_property + def n_heads_kv_local(self) -> int: + """ + Number of local heads post sharding. + """ + return get_local_heads(self.tp_rank, self.tp_size, self.n_heads_q, self.n_heads_kv)[1] + + @property + def gated_mlp(self) -> bool: + """ + Return a boolean to determine whether the model uses a gated activation function. + """ + return is_gated(self.mlp_activation_fn) + + """ + Method implementations + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_mlp_1_layer() + self.make_mlp_2_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + ######### Embedding ######### + def make_embedding_layer(self) -> None: + """ + Performs setup and creates embedding DSModule. This will set the `self.embed` attribute. + """ + + embed_config = DSEmbeddingsConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + residual_dtype=self.activation_dtype, + embedding_dim=self.model_dim, + ) + + self.embed = heuristics.instantiate_embed(embed_config, self._engine_config) + + def transform_embedding_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Performs embedding sharding along the channels dimension. + """ + # Until we can do non-contiguous all-gather, we won't shard the embedding parameters. + param = param.to(self.activation_dtype.value) + return InferenceParameter.initialize(param) + + ######### Unembedding ######### + def make_unembedding_layer(self) -> None: + """ + Performs setup and creates an unembedding layer. This implementation assumes + normalization prior to the LM head projection. If this does not match the model's + implementation, override this method. This will set the ``self.unembed`` attribute. + """ + unembed_dim = sharded_unembed_dim(self.vocab_size, self.tp_rank, self.tp_size) + + unembed_config = DSUnembedConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + dtype=self.activation_dtype, + model_dim=self.model_dim, + vocab_size=unembed_dim, + norm_type=self.norm_type, + ) + + self.unembed = heuristics.instantiate_unembed(unembed_config, self._engine_config) + + if self.tp_size > 1: + self._comm_logits = torch.empty(self.tp_size, + self._engine_config.state_manager.max_ragged_sequence_count, + unembed_dim, + device=get_accelerator().current_device(), + dtype=self.activation_dtype.value) + self._return_logits = torch.empty(self._engine_config.state_manager.max_ragged_sequence_count, + self.vocab_size, + device=get_accelerator().current_device(), + dtype=self.activation_dtype.value) + + def transform_unembed_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Performs sharding along the vocab dimension. + """ + param = shard_unembed_param(param, self.tp_rank, self.tp_size).to(self.activation_dtype.value) + return InferenceParameter.initialize(param) + + ######### QKV ######### + def make_qkv_layer(self) -> None: + """ + Instantiates the linear projection layer for the QKV linear layer. This sets the + `self.qkv` attribute. + """ + out_features = qkv_out_features(self.model_dim, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, + self.n_heads_kv) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=self.model_dim, + out_channels=out_features, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.qkv = heuristics.instantiate_linear(linear_config, self._engine_config) + + def transform_qkv_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Passes a QKV parameter to the underlying implementation for any necessary + transformations. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons) + """ + param = shard_qkv_param(param, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, self.n_heads_kv) + return self.qkv.transform_param(param) + + ######### Attention ######### + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=self.positional_embedding_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: int, + max_new_blocks: int) -> Tuple[int, int]: + """ + See ``DSInferenceModelBase.get_kv_requirements`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + total_tokens = sequence.seen_tokens + max_new_tokens + req_blocks = ceil_div(total_tokens, self.attn.kv_block_size) + block_lim = req_blocks - sequence.cur_allocated_blocks + + if block_lim <= max_new_blocks: + return max_new_tokens, block_lim + + token_capacity = (max_new_blocks + + sequence.cur_allocated_blocks) * self.attn.kv_block_size - sequence.seen_tokens + + return token_capacity, max_new_blocks + + def get_remaining_block_capacity(self, sequence: DSSequenceDescriptor) -> int: + return sequence.seen_tokens % self.attn.kv_block_size + + def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None: + """ + See ``DSInferenceModelBase.maybe_allocate_kv`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + free_block = self.state_manager.free_blocks[0] + _, n_needed_blocks = self.get_kv_requirements(sequence, n_new_tokens, free_block) + + if n_needed_blocks > 0: + new_blocks = self.state_manager.allocate_blocks(n_needed_blocks) + sequence.extend_kv_cache(new_blocks) + + def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]: + """ + See ``DSInferenceModelBase.kv_cache_config`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + if self._kv_cache_config is None: + cache_shape = (self.num_layers, self.n_heads_kv_local, self.head_size) + max_blocks = ceil_div(self.max_sequence_length, self.attn.kv_block_size) + self._kv_cache_config = KVCacheConfig(block_size=self.attn.kv_block_size, + cache_shape=cache_shape, + cache_dtype=self.activation_dtype, + max_blocks_per_allocation_group=max_blocks) + return (self._kv_cache_config, ) + + def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None: + """ + See ``DSInferenceModelBase.prepare_batch`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + self.attn.build_atoms(wrapped_batch) + + ######### Attention output ######### + def make_attn_out_layer(self) -> None: + """ + Instantiates the linear projection layer for the attention output linear layer. This sets the + `self.attn_out` attribute. + """ + in_features = attn_out_in_features(self.model_dim, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, + self.n_heads_kv) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=in_features, + out_channels=self.model_dim, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.attn_out = heuristics.instantiate_linear(linear_config, self._engine_config) + + def transform_attn_out_param(self, param: torch.Tensor) -> Optional[InferenceParameter]: + """ + Shards an attention output projection parameter and passes it to the underlying + implementation for any necessary transformations. This will return `None` for bias parameters + if they are not on TP rank 0. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons). + """ + param = shard_attn_out_param(param, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, + self.n_heads_kv) + + if param is not None: + param = self.attn_out.transform_param(param) + + return param + + ######### MLP ######### + def make_mlp_1_layer(self) -> None: + """ + Instantiates the linear projection layer for the first MLP in the feedforward network. + This sets the `self.mlp_1` attribute. + """ + shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=self.model_dim, + out_channels=shard_size, + activation=self.mlp_activation_fn, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.mlp_1 = heuristics.instantiate_linear(linear_config, self._engine_config) + + def transform_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Shards the first MLP parameter and passes it to the underlying implementation + for any necessary transformations. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons). + """ + param = shard_mlp_1_param(param, self.tp_rank, self.tp_size, gated=self.gated_mlp) + + return self.mlp_1.transform_param(param) + + def make_mlp_2_layer(self) -> None: + """ + Instantiates the linear projection layer for the second MLP in the feedforward network. + This sets the `self.mlp_2` attribute. + """ + shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=shard_size, + out_channels=self.model_dim, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.mlp_2 = heuristics.instantiate_linear(linear_config, self._engine_config) + + def transform_mlp_2_param(self, param: torch.Tensor) -> Optional[InferenceParameter]: + """ + Shards the second MLP parameter and passes it to the underlying implementation + for any necessary transformations. This will return `None` for bias parameters + if they are not on TP rank 0. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons). + """ + param = shard_mlp_2_param(param, self.tp_rank, self.tp_size) + + if param is not None: + param = self.mlp_2.transform_param(param) + + return param + + ######### Norm ######### + def make_norm_layer(self) -> None: + """ + Instantiates the normalization layer for the model. This sets the `self.norm` attribute. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + """ + norm_config = DSNormConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + type=self.norm_type, + channels=self.model_dim, + residual_dtype=self.activation_dtype, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config) + + def transform_norm_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Passes a normalization parameter to the underlying implementation for any + necessary transformations. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + shape (model_dim,) + """ + return self.norm.transform_param(param) + + +class DSMoETransformerModelBase(DSTransformerModelBase): + + @property + def n_experts(self) -> int: + """ + Return the number of experts in the model. + """ + raise NotImplementedError("Attempted to access an unimplemented number of experts") + + @property + def n_top_k(self) -> int: + """ + Number of experts per token. + """ + raise NotImplementedError("Attempted to access an unimplemented number of experts per token") + + @property + def normalize_expert_scores(self) -> bool: + """ + Whether to normalize expert scores. If true, sum(expert_scores) = 1. + """ + raise NotImplementedError("Attempted to access an unimplemented normalization flag") + + def make_moe_layer(self) -> None: + """ + Instantiates the MoE layer for the model. This sets the `self.moe` attribute. + """ + sharded_dim = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + moe_config = DSMoEConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + model_dim=self.model_dim, + intermediate_features=sharded_dim, + activation=self.mlp_activation_fn, + n_experts=self.n_experts, + top_k=self.n_top_k, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + normalize_scores=self.normalize_expert_scores, + ) + + self.moe = heuristics.instantiate_moe(moe_config, self._engine_config) + + def transform_moe_gate_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Passes a MoE gate parameter to the underlying implementation for any necessary transformations. + + TODO(cmikeh2): This will need to be updated/overridden for expert parallelism. + """ + return self.moe.transform_gate_param(param) + + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Shards the first MoE param and passes it to the underlying implementation. Since it's possible for an architecture + to have both MoE and non-MoE layers, this can't be overloaded on the MLP1 transform. Furthermore, since both + the MoE DSModule owns both MLP1 and MLP2, under certain sharding conditions it's not possible for the model implementation + to infer from the shape whether to perform a different transformation based on MLP1 or MLP2. This (and the below) + separations are intended to solve both these issues. + + Args: + param (torch.Tensor): The parameter to transform. This should have shape (n_experts, out_neurons, in_neurons). + """ + param = shard_mlp_1_param(param, self.tp_rank, self.tp_size, gated=self.gated_mlp, is_moe=True) + + return self.moe.transform_moe_mlp_1_param(param) + + def transform_moe_mlp_2_param(self, param: torch.Tensor) -> Optional[torch.Tensor]: + """ + Shards the second MoE param and passes it to the underlying implementation. See the above for context on why this API + exists. + + This will return `None` for expert bias params not on TP rank 0. NOTE(cmikeh2): Does it make sense to round-robin assign? + My intuition is that this will make debugging much more difficult for minimal memory reduction. + + Args: + param (torch.Tensor): The parameter to transform. This should have shape (n_experts, out_neurons, in_neurons). + """ + param = shard_mlp_2_param(param, self.tp_rank, self.tp_size, is_moe=True) + + if param is not None: + param = self.moe.transform_moe_mlp_2_param(param) + + return param diff --git a/deepspeed/inference/v2/model_implementations/layer_container_base.py b/deepspeed/inference/v2/model_implementations/layer_container_base.py new file mode 100644 index 000000000000..8357f076cc02 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/layer_container_base.py @@ -0,0 +1,356 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import re +from typing import Type + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.compat import get_annotations_from_namespace, get_annotations +from .parameter_base import ParameterBase, ParametrizedList +from ..inference_parameter import InferenceParameter + +# Currently have dependency loops for the type hints. +InferenceModel = Type["InferenceModel"] +LayerContainer = Type["LayerContainer"] # noqa: F811 + +MAPPING_KEY = "PARAM_MAPPING" +PLIST_HELPERS = "_ds_plist_strip_vals" + + +def make_finalization_callback(all_names: str): + """ + Helper method for building the finalization callback for a LayerContainer. This + is not client code and should not be used or called directly. + """ + + def finalization_callback(self, param: ParameterBase, finalized_param: torch.Tensor) -> None: + """ + Callback for when a parameter is finalized. + """ + self._finalized_params += 1 + + for name in all_names: + if getattr(self, name) is param: + setattr(self, name, finalized_param) + + return finalization_callback + + +class LayerMetaclass(type): + """ + MetaClass for the LayerContainer base class. This class will parse the annotations + of the class that correspond to `ParameterBase` and create None initializers for each + as well as a finalization callback that for when each `ParameterBase` is finalized + and should be replaced with a Tensor. + """ + + def __new__(cls, clsname, bases, attrs): + + annotations = get_annotations_from_namespace(attrs) + + for base in bases: + # We'll pick up all annotations on any base classes. This will allow us to + # to use inheritance to share common parameter groups in base classes. + annotations.update(get_annotations(base)) + + if hasattr(base, MAPPING_KEY): + if MAPPING_KEY not in attrs: + # This is likely a fail state. If a parent has MAPPING KEY but the child does + # not, then we're guaranteed only a subset of the parameters will be mapped. + attrs[MAPPING_KEY] = {} + attrs[MAPPING_KEY].update(getattr(base, MAPPING_KEY)) + + all_names = [name for name, annotation in annotations.items() if issubclass(annotation, ParameterBase)] + + if MAPPING_KEY in attrs: + # If we have a mapping key at all, then we will enter the validation mode for building + # helpers for mapping and ensuring we have complete mapping. + + # First we'll build a flat list of every dependency for this layer. + all_deps = set() + for name in all_names: + parameter_deps = [ + name for name, annotation in get_annotations(annotations[name]).items() + if issubclass(annotation, (torch.Tensor, ParametrizedList)) + ] + + all_deps.update([f"{name}.{dep}" for dep in parameter_deps]) + + # Create static helper for doing the string processing only once. + attrs[PLIST_HELPERS] = [] + + # Iterate over all the mappings + for src_name, target_or_targets in attrs[MAPPING_KEY].items(): + if isinstance(target_or_targets, str): + target_or_targets = [target_or_targets] + + actual_targets = [] + for target_name in target_or_targets: + base_dependency, dependency_attr = target_name.split(".") + + # Check for invalid mappings + if base_dependency not in all_names: + raise ValueError( + "Target parameter \"{}\" not found in this layer. Valid targets are {}".format( + base_dependency, all_names)) + if dependency_attr not in get_annotations(annotations[base_dependency]): + # This check is not universal (see below) if a single dependency is being + # mapped to by a single row. + raise ValueError( + "Target dependency \"{}\" not found on parameter \"{}\". Valid targets are {}".format( + dependency_attr, base_dependency, + get_annotations(annotations[base_dependency]).keys())) + if target_name not in all_deps: + raise ValueError( + "Target dependency \"{}\" was targeted with multiple mapping rules.".format(target_name)) + + # If we've made it this far, the dependency definitely exists. + actual_targets.append(get_annotations(annotations[base_dependency])[dependency_attr]) + + all_deps.remove(target_name) + + are_plists = [issubclass(target, ParametrizedList) for target in actual_targets] + if all(are_plists): + # We can do direct sets on everything but ParametrizedLists, so we'll only explicitly + # handle these here. + # TODO(cmikeh2): SPLIT, error if more than 1 + glob_count = src_name.count("*") + if glob_count > 1: + raise ValueError( + "ParametrizedList index inference can only work with a single glob: {}".format(src_name)) + elif glob_count == 0: + raise ValueError( + "Must have wildcard (*) in source name for ParametrizedList mapping: {}".format(src_name)) + + wildcard_idx = src_name.find("*") + prefix = src_name[:wildcard_idx] + suffix = src_name[wildcard_idx + 1:] + attrs[PLIST_HELPERS].append((prefix, suffix, target_or_targets)) + elif any(are_plists): + raise ValueError("Cannot mix ParametrizedLists and Tensors in a single mapping rule.") + + if len(all_deps) > 0: + raise ValueError( + "A parameter mapping was provided for {}, but the following dependencies were not mapped: {}". + format(clsname, all_deps)) + + attrs["finalization_callback"] = make_finalization_callback(all_names) + + new_obj = super().__new__(cls, clsname, bases, attrs) + + setattr(new_obj, "_n_params", len(all_names)) + setattr(new_obj, "_annotation_attrs", all_names) + + return new_obj + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls, *args, **kwargs) + instance.__init__(*args, **kwargs) + + for name, annotation in get_annotations(instance).items(): + if issubclass(annotation, ParameterBase): + # TODO(cmikeh2): Do we want to make this a property + # It might also make sense to do this in the base class __init__ + # but since it is tied with the changes made in __new__ it feels + # to me like it should be here. + setattr(instance, name, annotation(instance.inference_model, instance)) + + return instance + + +class LayerContainer(metaclass=LayerMetaclass): # noqa: F811 + """ + Abstract base class for containing model parameters. + + This is primarily a guidance abstraction since we do not put any restrictions + on how the parameters are stored. + + To use this class, annotate the class with `ParameterBase` subclasses and give them + names. As a checkpoint is loaded into this container, the `ParameterBase` instances + will be replaced with realized Tensors as soon as each of their dependencies are met. + + To enable automatic mapping, add a static attribute `PARAM_MAPPING` to the class + definition. This should be a dictionary mapping from a source string to one or + more dependencies. + + ```python + class MyLayer(LayerContainer): + PARAM_MAPPING = { + "path.to.param.dependency", "container_param_1.dependency", + "path.to.param2.dependency", "container_param_2.dependency", + "path.to.param3.*.dependency", "container_param_3.list_dependency" + } + + ... + ``` + """ + + def __init__(self, model: InferenceModel) -> None: + """ + Initialization of the LayerContainer. This method does not need to be overridden + for any children classes. + + Args: + model (InferenceModel): Inference model that will be used to shard and transform + parameters correctly, as well as provide specific information about the model + for `ParameterizedList`s that may be part of one of the member `ParameterBase`s. + """ + self.inference_model = model + self._finalized_params = 0 + + def _initialization_checker(self, check_device: bool = True) -> bool: + """ + Returns whether or not all parameters have been initialized and transformed by + the model. Once this returns True, all the `ParameterBase` instances will be + torch.Tensors. + """ + if self._finalized_params != self.n_params: + return False + + for name in self._annotation_attrs: + tensor = getattr(self, name) + if tensor is None: + continue + elif not isinstance(tensor, InferenceParameter): + raise ValueError("Layer should be finalized, but {} ({}) is neither InferenceParameter or None".format( + name, type(tensor))) + elif check_device and tensor.device != torch.device(get_accelerator().current_device()): + raise RuntimeError("Layer should be finalized, but {} is not on device {}".format( + name, + get_accelerator().current_device())) + return True + + @property + def is_populated(self) -> bool: + """ + Returns whether or not all parameters have been populated by the checkpoint engine, but + does not validat the parameters are on the correct device. + """ + return self._initialization_checker(check_device=False) + + @property + def is_initialized(self) -> bool: + """ + Returns whether or not all parameters have been initialized and transformed by + the model and are located on the appropriate device. Once this returns True, all + the `ParameterBase` instances ``InferenceParameter``s or explicitly set to ``None``. + """ + return self._initialization_checker() + + @property + def n_params(self) -> int: + """ + The number of parameters this container holds. This is a read-only value + that is set by the metaclass. + """ + return self._n_params + + @property + def annotation_attrs(self) -> list: + return self._annotation_attrs + + @property + def mapping_params(self) -> dict: + return getattr(self.__class__, MAPPING_KEY, {}) + + @property + def plist_helpers(self) -> list: + return getattr(self.__class__, PLIST_HELPERS, []) + + def direct_injection(self, name: str, tensor: InferenceParameter) -> None: + + if name not in self._annotation_attrs: + raise ValueError(f"Cannot directly inject {name}, not a valid parameter.") + + setattr(self, name, tensor) + self._finalized_params += 1 + + def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None: + """ + Set dependency can be used for managing dependencies when a mapping is provided + in the class definition for the layer. The dep_name here should have any prefix + for transformer layers removed (such as model.layers.*.attn.qkv.weight -> attn.qkv.weight). + + Args: + dep_name (str): The name of the dependency to set. + dep_value (torch.Tensor): The value to set the dependency to. + """ + + def get_dep_name_target(dep_name: str) -> str: + """ + Helper method for getting the target name for a dependency from the + mapping params. Tries to match exact string first, then looks for + wildcards and attempts regex matching. Will return empty string if + no match found. + """ + if dep_name in self.mapping_params: + # If we have an exact match, it's a direct mapping and we can + # immediately set the value. + return self.mapping_params[dep_name] + + matched_targets = [] + for key, target in self.mapping_params.items(): + regex_key = key.replace("*", ".*") + if re.match(regex_key, dep_name): + matched_targets.append(target) + if len(matched_targets) > 1: + raise ValueError(f"Multiple targets matched for dependency {dep_name}: {matched_targets}") + if matched_targets: + return matched_targets[0] + return "" + + if dep_name in self.mapping_params: + # If we have an exact match, it's a direct mapping and we can immediately set + # the value. + target = self.mapping_params[dep_name] + + # Convert single targets to a list for consistency + if isinstance(target, str): + target = [target] + + for target_name in target: + # Double setting doesn't set the attribute correctly, so we do a getattr then setattr + target_param_name, target_dependency_name = target_name.split(".") + target_param = getattr(self, target_param_name) + setattr(target_param, target_dependency_name, dep_value) + return + + # Otherwise we need to map to one of the parameter lists. + for prefix, suffix, dests in self.plist_helpers: + if dep_name.startswith(prefix) and dep_name.endswith(suffix): + # We have a match, so we can set the value. + target_idx = int(dep_name[len(prefix):-len(suffix)]) + + # Convert single targets to a list for consistency + if isinstance(dests, str): + dests = [dests] + + for dest in dests: + target_param_name, target_dependency_name = dest.split(".") + target_param = getattr(self, target_param_name) + target_dependency = getattr(target_param, target_dependency_name) + target_dependency[target_idx] = dep_value + return + + # TODO: Refactor this with the help of cmikeh2 + # We should be able to combine this with the wildcard matching above. + target = get_dep_name_target(dep_name) + if target: + # Convert single targets to a list for consistency + if isinstance(target, str): + target = [target] + + for target_name in target: + # Double setting doesn't set the attribute correctly, so we do a getattr then setattr + target_param_name, target_dependency_name = target_name.split(".") + target_param = getattr(self, target_param_name) + setattr(target_param, target_dependency_name, dep_value) + return + + raise ValueError( + "Could not find a mapping for dependency \"{}\". Check that it is included in the ``MAPPING_PARAMS``. See docstring for more on ``MAPPING_PARAMS``" + .format(dep_name)) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py new file mode 100644 index 000000000000..79605a76a4c2 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import Llama2Policy diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/container.py b/deepspeed/inference/v2/model_implementations/llama_v2/container.py new file mode 100644 index 000000000000..9de9bdb34574 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/container.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Llama model looks like this: + +LlamaForCausalLM( + (model): LlamaModel( + (embed_tokens): Embedding(32000, 4096, padding_idx=0) + (layers): ModuleList( + (0-31): 32 x LlamaDecoderLayer( + (self_attn): LlamaAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=4096, bias=False) + (v_proj): Linear(in_features=4096, out_features=4096, bias=False) + (o_proj): Linear(in_features=4096, out_features=4096, bias=False) + (rotary_emb): LlamaRotaryEmbedding() + ) + (mlp): LlamaMLP( + (gate_proj): Linear(in_features=4096, out_features=11008, bias=False) + (up_proj): Linear(in_features=4096, out_features=11008, bias=False) + (down_proj): Linear(in_features=11008, out_features=4096, bias=False) + (act_fn): SiLUActivation() + ) + (input_layernorm): LlamaRMSNorm() + (post_attention_layernorm): LlamaRMSNorm() + ) + ) + (norm): LlamaRMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=32000, bias=False) +) +''' + + +class Llama2TransformerContainer(LayerContainer): + """ + Transformer layer container for the Llama-2 model. + """ + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_proj.weight": "mlp_1_w.gate_params", + "mlp.up_proj.weight": "mlp_1_w.up_params", + "mlp.down_proj.weight": "mlp_2_w.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + } + + +class Llama2NonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Llama-2 model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/model.py b/deepspeed/inference/v2/model_implementations/llama_v2/model.py new file mode 100644 index 000000000000..a0c81f4d749e --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/model.py @@ -0,0 +1,209 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer + + +class Llama2InferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[Llama2NonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Llama2TransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + # llama model family is special and is always gated so force gated versions of relu, gelu, silu + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/policy.py b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py new file mode 100644 index 000000000000..bb13ab6d5bf4 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer +from .model import Llama2InferenceModel + + +class Llama2Policy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Llama2InferenceModel: + return Llama2InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Llama2TransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Llama2NonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/mistral/__init__.py b/deepspeed/inference/v2/model_implementations/mistral/__init__.py new file mode 100644 index 000000000000..60d636693ef3 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import MistralPolicy diff --git a/deepspeed/inference/v2/model_implementations/mistral/container.py b/deepspeed/inference/v2/model_implementations/mistral/container.py new file mode 100644 index 000000000000..b4c0956f4049 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/container.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer +''' + # HF Mistral model (mistralai/Mistral-7B-v0.1) looks like this: +MistralForCausalLM( + (model): MistralModel( + (embed_tokens): Embedding(32000, 4096) + (layers): ModuleList( + (0-31): 32 x MistralDecoderLayer( + (self_attn): MistralAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=1024, bias=False) + (v_proj): Linear(in_features=4096, out_features=1024, bias=False) + (o_proj): Linear(in_features=4096, out_features=4096, bias=False) + (rotary_emb): MistralRotaryEmbedding() + ) + (mlp): MistralMLP( + (gate_proj): Linear(in_features=4096, out_features=14336, bias=False) + (up_proj): Linear(in_features=4096, out_features=14336, bias=False) + (down_proj): Linear(in_features=14336, out_features=4096, bias=False) + (act_fn): SiLUActivation() + ) + (input_layernorm): MistralRMSNorm() + (post_attention_layernorm): MistralRMSNorm() + ) + ) + (norm): MistralRMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=32000, bias=False) +) +''' + + +class MistralTransformerContainer(LayerContainer): + """ + Transformer layer container for the Mistral model. + """ + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_proj.weight": "mlp_1_w.gate_params", + "mlp.up_proj.weight": "mlp_1_w.up_params", + "mlp.down_proj.weight": "mlp_2_w.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + } + + +class MistralNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Mistral model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/mistral/model.py b/deepspeed/inference/v2/model_implementations/mistral/model.py new file mode 100644 index 000000000000..318d362f1a64 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/model.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .container import MistralNonTransformerContainer, MistralTransformerContainer + + +class MistralInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Mistral models. + """ + + _non_transformer: Optional[MistralNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[MistralTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/mistral/policy.py b/deepspeed/inference/v2/model_implementations/mistral/policy.py new file mode 100644 index 000000000000..b67ec311c952 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MistralNonTransformerContainer, MistralTransformerContainer +from .model import MistralInferenceModel + + +class MistralPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> MistralInferenceModel: + return MistralInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [MistralTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(MistralNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/mixtral/__init__.py b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py new file mode 100644 index 000000000000..2cb1aa889291 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import MixtralPolicy diff --git a/deepspeed/inference/v2/model_implementations/mixtral/container.py b/deepspeed/inference/v2/model_implementations/mixtral/container.py new file mode 100644 index 000000000000..6ec4a0552b8f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/container.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MixtralTransformerContainer(LayerContainer): + + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + moe_gate: MoEGatingWeightParameter + moe_mlp_1: UnfusedMoEGatedMLPParameter + moe_mlp_2: UnfusedMoEMLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "block_sparse_moe.gate.weight": "moe_gate.params", + "block_sparse_moe.experts.*.w1.weight": "moe_mlp_1.gating_experts", + "block_sparse_moe.experts.*.w3.weight": "moe_mlp_1.up_experts", + "block_sparse_moe.experts.*.w2.weight": "moe_mlp_2.experts", + } + + +class MixtralNonTransformerContainer(LayerContainer): + + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "lm_head.weight": "word_unembed.params", + "model.norm.weight": "final_norm.params", + } diff --git a/deepspeed/inference/v2/model_implementations/mixtral/model.py b/deepspeed/inference/v2/model_implementations/mixtral/model.py new file mode 100644 index 000000000000..878cd8e31cec --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/model.py @@ -0,0 +1,261 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .container import MixtralNonTransformerContainer, MixtralTransformerContainer + + +class MixtralInferenceModel(DSMoETransformerModelBase): + """ + Inference model implementation for Mixtral models. + """ + + _non_transformer: Optional[MixtralNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[MixtralTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_position_embeddings + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + """ + The positional embedding configuration for the model. + """ + return RotateHalfConfig(theta_base=self._config.rope_theta) + + """ + Inherited from `DSMoETransformerModelBase` + """ + + @property + def n_experts(self) -> int: + return self._config.num_local_experts + + @property + def n_top_k(self) -> int: + return self._config.num_experts_per_tok + + @property + def normalize_expert_scores(self) -> bool: + return True + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_moe_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma) + + hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1, + cur_params.moe_mlp_2) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/mixtral/policy.py b/deepspeed/inference/v2/model_implementations/mixtral/policy.py new file mode 100644 index 000000000000..2f0087919720 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/policy.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MixtralTransformerContainer, MixtralNonTransformerContainer +from .model import MixtralInferenceModel + + +class MixtralPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> MixtralInferenceModel: + return MixtralInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + + map = ContainerMap() + + transformer_containers = [MixtralTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(MixtralNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/opt/__init__.py b/deepspeed/inference/v2/model_implementations/opt/__init__.py new file mode 100644 index 000000000000..c0f24d5243b8 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import OPTPolicy diff --git a/deepspeed/inference/v2/model_implementations/opt/container.py b/deepspeed/inference/v2/model_implementations/opt/container.py new file mode 100644 index 000000000000..e97599ef8e50 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/container.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF OPT model looks like this: + +OPTForCausalLM( + (model): OPTModel( + (decoder): OPTDecoder( + (embed_tokens): Embedding(50272, 768, padding_idx=1) + (embed_positions): OPTLearnedPositionalEmbedding(2050, 768) + (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (layers): ModuleList( + (0-11): 12 x OPTDecoderLayer( + (self_attn): OPTAttention( + (k_proj): Linear(in_features=768, out_features=768, bias=True) + (v_proj): Linear(in_features=768, out_features=768, bias=True) + (q_proj): Linear(in_features=768, out_features=768, bias=True) + (out_proj): Linear(in_features=768, out_features=768, bias=True) + ) + (activation_fn): ReLU() + (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (fc1): Linear(in_features=768, out_features=3072, bias=True) + (fc2): Linear(in_features=3072, out_features=768, bias=True) + (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + ) + ) + ) + ) + (lm_head): Linear(in_features=768, out_features=50272, bias=False) +) + +''' + + +class OPTTransformerContainer(LayerContainer): + """ + Transformer layer container for the OPT model. + """ + qkv_w: UnfusedQKVParameter + qkv_b: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + attn_out_b: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_1_b: MLP1Parameter + mlp_2_w: MLP2Parameter + mlp_2_b: MLP2Parameter + attn_norm_beta: NormParameter + attn_norm_gamma: NormParameter + mlp_norm_beta: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.q_proj.bias": "qkv_b.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.k_proj.bias": "qkv_b.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.v_proj.bias": "qkv_b.v_params", + "self_attn.out_proj.weight": "attn_out_w.params", + "self_attn.out_proj.bias": "attn_out_b.params", + "fc1.weight": "mlp_1_w.params", + "fc1.bias": "mlp_1_b.params", + "fc2.weight": "mlp_2_w.params", + "fc2.bias": "mlp_2_b.params", + "self_attn_layer_norm.weight": "attn_norm_gamma.params", + "self_attn_layer_norm.bias": "attn_norm_beta.params", + "final_layer_norm.weight": "mlp_norm_gamma.params", + "final_layer_norm.bias": "mlp_norm_beta.params", + } + + +class OPTNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the OPT model. + """ + word_emb: EmbeddingParameter + word_emb_pos: EmbeddingParameter + word_unembed: UnembedParameter + final_norm_w: NormParameter + final_norm_b: NormParameter + + PARAM_MAPPING = { + "*decoder.embed_tokens.weight": ["word_emb.params", "word_unembed.params"], + "*decoder.embed_positions.weight": "word_emb_pos.params", + "*decoder.final_layer_norm.weight": "final_norm_w.params", + "*decoder.final_layer_norm.bias": "final_norm_b.params", + } diff --git a/deepspeed/inference/v2/model_implementations/opt/model.py b/deepspeed/inference/v2/model_implementations/opt/model.py new file mode 100644 index 000000000000..adf011d8f1a7 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/model.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...ragged import RaggedBatchWrapper +from .container import OPTNonTransformerContainer, OPTTransformerContainer + +from ...modules.heuristics import instantiate_embed + + +class OPTInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for OPT models. + """ + + _non_transformer: Optional[OPTNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[OPTTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.ffn_dim + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.RELU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.none + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return None + + """ + Overrides of ``DSTransformerModelBase`` methods + """ + + def make_embedding_layer(self) -> None: + """ + Performs setup and creates embedding DSModule. Since OPT includes trained + positional embeddings, we will override the base model implementation. + """ + + embed_config = DSEmbeddingsConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + residual_dtype=self.activation_dtype, + embedding_dim=self.model_dim, + positional_embedding=True, + positional_offset=2) + + self.embed = instantiate_embed(embed_config, self._engine_config) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + embed = self.embed(ragged_batch, self._non_transformer.word_emb, self._non_transformer.word_emb_pos) + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, + hidden_states, + cur_params.mlp_norm_gamma, + beta=cur_params.mlp_norm_beta) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=cur_params.mlp_1_b) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=cur_params.mlp_2_b) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, + hidden_states, + next_params.attn_norm_gamma, + beta=next_params.attn_norm_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm_w, + beta=self._non_transformer.final_norm_b) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + self._transformer[0].attn_norm_gamma, + beta=self._transformer[0].attn_norm_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/opt/policy.py b/deepspeed/inference/v2/model_implementations/opt/policy.py new file mode 100644 index 000000000000..d57d5beb48d5 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import OPTNonTransformerContainer, OPTTransformerContainer +from .model import OPTInferenceModel + + +class OPTPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> OPTInferenceModel: + return OPTInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [OPTTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.decoder.layers', 'decoder.layers'], transformer_containers) + + map.set_non_transformer_params(OPTNonTransformerContainer(self.model)) + + map.set_unmapped_params(['lm_head.weight']) + + return map diff --git a/deepspeed/inference/v2/model_implementations/parameter_base.py b/deepspeed/inference/v2/model_implementations/parameter_base.py new file mode 100644 index 000000000000..c480f93d51bd --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/parameter_base.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import weakref +from abc import abstractmethod +from typing import Type + +import torch + +from deepspeed.compat import get_annotations_from_namespace, get_annotations + +# Currently have dependency loops for the type hints. +InferenceModel = Type["InferenceModel"] +LayerContainer = Type["LayerContainer"] + +MAPPING_KEY = "PARAM_MAPPING" + + +def make_param_getter(clsname, param): + """ + Normal getter implementation for a property. + """ + + def param_getter(self): + return getattr(self, f"__{clsname}__{param}") + + return param_getter + + +def make_param_setter(clsname, param): + """ + Setter implementation that will call complete component to potentially + finalize the parameter. + """ + + def param_setter(self, value): + setattr(self, f"__{clsname}__{param}", value) + self.dtype = value.dtype + self.complete_component() + + return param_setter + + +def make_readonly_setter(): + """ + Setter implementation that will raise an error if called. + """ + + def paramlist_setter(self, value): + raise ValueError("Cannot set a ParametrizedList directly.") + + return paramlist_setter + + +class ParameterMetaclass(type): + """ + MetaClass for the ParameterBase base class. This class will parse the `src_params` + attribute and create properties for each of the dependencies. A dependency can either + be represented as a string, which is interpreted as a named Tensor, or a `ParametrizedList` + subclass. + """ + + def __new__(cls, clsname, bases, attrs): + + annotations = get_annotations_from_namespace(attrs) + dependencies = { + name: annotation + for name, annotation in annotations.items() if issubclass(annotation, (torch.Tensor, ParametrizedList)) + } + n_dependencies = len(dependencies) + + # Create properties for each of our dependencies + for d_name, d_type in dependencies.items(): + if issubclass(d_type, ParametrizedList): + assert hasattr( + d_type, "count_attr" + ), "ParametrizedList must have a count_attr attribute to access on the inference module." + attrs[d_name] = property(make_param_getter(clsname, d_name), make_readonly_setter()) + else: # torch.Tensor + attrs[d_name] = property(make_param_getter(clsname, d_name), make_param_setter(clsname, d_name)) + + new_cls = super().__new__(cls, clsname, bases, attrs) + new_cls.n_dependencies = n_dependencies + + return new_cls + + def __call__(cls, *args, **kwargs): + new_obj = super().__call__(*args, **kwargs) + new_obj.__init__(*args, **kwargs) + + setattr(new_obj, "dest_param", None) + + # Initialize our dependences to None/empty `ParametrizedList`s + for name, annotation in get_annotations(new_obj).items(): + if issubclass(annotation, ParametrizedList): + #TODO(jeff): update assert with this, model implementation attribute does not align or missing wrt the ParametrizedList attributes + assert hasattr( + new_obj.inference_model, annotation.count_attr + ), f"new_obj={new_obj.__class__.__name__}, name={name}, annotation.count_attr={annotation.count_attr}" + param_list = annotation(new_obj, getattr(new_obj.inference_model, annotation.count_attr)) + setattr(new_obj, f"__{new_obj.__class__.__name__}__{name}", param_list) + else: # torch.Tensor + setattr(new_obj, f"__{new_obj.__class__.__name__}__{name}", None) + + return new_obj + + +class ParameterBase(metaclass=ParameterMetaclass): + """ + A ParameterBase allows us to consolidate tracking the dependencies of loading a parameter from + a checkpoint into a single object. This class should not be used directly, but rather subclassed + and the `src_params` attribute set to a list of strings and/or `ParametrizedList`s. + """ + + # inference_model: InferenceModel + """ + Inference model that will provide context on how to shard and transform the parameter. + """ + + #completed_components: int + """ + How many of the layer dependencies have been met. This is used to determine when the parameter + is ready to be finalized. A ParametrizedList counts as a single dependency for the purposes + of this counter. + """ + + def __init__(self, model: InferenceModel, parent_container: LayerContainer) -> None: + """ + Direct constructor. This should not be called from client code. + + Args: + model (InferenceModel): Inference model that will be used to shard and transform the + parameter in `finalize`. + parent_container (LayerContainer): The parent container that this parameter is a member + of. We will build a weakref to this container to call the finalization callback. + """ + self.inference_model = model + self.completed_components = 0 + self.parent_container = weakref.ref(parent_container) + + @abstractmethod + def finalize(self) -> torch.Tensor: + """ + Finalize the parameter after all of its source parameters have been set. This method + will be automatically called when all inputs have been set. It should return the Tensor + with all transformations performed on it. + """ + pass + + def complete_component(self) -> None: + """ + Mark a component as completed. This should be called by the relevant setter of a direct + property or a ParametrizedList. This method will automatically call `finalize` when all + dependencies have been met and then call the finalization callback on the parent container. + + Once the finalization callback has been called, the parameter will be replaced with the + `dst_param` attribute on the parent container, and this instance will be destroyed. + """ + self.completed_components += 1 + + if self.completed_components != self.n_dependencies: + return + + finalized_param = self.finalize() + self.parent_container().finalization_callback(self, finalized_param) + + +class ParametrizedList: + """ + A ParametrizedList is a list of parameters that are dependencies + of a `ParameterBase` but may vary in length depending on the model + configuration (rather than architecture). For example, a MoE layer + may have different number of experts depending on the size of the model. + + This class is used to manage these lists and provide integer indexing + of a single component rather than accessing names directly. For example, + it tends to be more natural to access the 8th expert with `experts[8]` + rather than a name like `expert_8`, especially as an attribute. + + To inherit from this class, set static variables `name` and `count_attr`. + + ```python + class MyParametrizedList(ParametrizedList): + count_attr: str = "my_list_count" + ``` + + In the above example, `my_list_count` should be an accessible attribute + of the inference model (i.e. via `self.inference_model.my_list_count`). + + NOTE: There are some APIs in which this type cannot be used as if it is + just a list of Tensors. For example, `torch.cat(param_list)` will not work. + However, you can make it compatible with a tuple wrapper: + `torch.cat(tuple(param_list))` + """ + + n_params: int + """ + Number of params this list contains. + """ + + param: ParameterBase + """ + WeakRef to the owning parameter. + """ + + def __init__(self, param: ParameterBase, n_params: int) -> None: + """ + Constructor. Should not be called from client code. + + Args: + param (ParameterBase): The owning parameter. + n_params (int): The number of parameters this list contains. This should be + """ + self.n_params = n_params + self.set_params = 0 + self.param = weakref.ref(param) + self._params = [None] * n_params + + def __getitem__(self, index): + return self._params[index] + + def __setitem__(self, index, value): + if self._params[index] is not None: + raise ValueError("Cannot set a parameter twice.") + + self._params[index] = value + self.set_params += 1 + + if self.set_params != self.n_params: + return + + self.param().complete_component() + + def __iter__(self): + return iter(self._params) + + +def ParamList(attr: str): + """ + Helper to create a subclass of ParametrizedList with the desired `count_attr`. + + In this manner, we can annotate the type of a Parameter dependency with the + following: + + ```python + class CustomParameter(ParameterBase): + dependency_list: ParamList("dependencies_count_name") + ``` + + where "dependencies_count_name" is the name of the attribute on the inference model. + """ + + class ParametrizedListInstance(ParametrizedList): + count_attr: str = attr + + return ParametrizedListInstance diff --git a/deepspeed/inference/v2/model_implementations/phi/__init__.py b/deepspeed/inference/v2/model_implementations/phi/__init__.py new file mode 100644 index 000000000000..3ab107e75a91 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import PhiPolicy diff --git a/deepspeed/inference/v2/model_implementations/phi/containers.py b/deepspeed/inference/v2/model_implementations/phi/containers.py new file mode 100644 index 000000000000..21f07eb8c99a --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/containers.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Phi-2 model looks like this: + +PhiForCausalLM( + (model): PhiModel( + (embed_tokens): Embedding(51200, 2560) + (embed_dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-31): 32 x PhiDecoderLayer( + (self_attn): PhiAttention( + (q_proj): Linear(in_features=2560, out_features=2560, bias=True) + (k_proj): Linear(in_features=2560, out_features=2560, bias=True) + (v_proj): Linear(in_features=2560, out_features=2560, bias=True) + (dense): Linear(in_features=2560, out_features=2560, bias=True) + (rotary_emb): PhiRotaryEmbedding() + ) + (mlp): PhiMLP( + (activation_fn): NewGELUActivation() + (fc1): Linear(in_features=2560, out_features=10240, bias=True) + (fc2): Linear(in_features=10240, out_features=2560, bias=True) + ) + (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + (resid_dropout): Dropout(p=0.1, inplace=False) + ) + ) + (final_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + ) + (lm_head): Linear(in_features=2560, out_features=51200, bias=True) +) +''' + + +class PhiTransformerContainer(LayerContainer): + """ + Transformer layer container for the Phi model. + """ + qkv_w: UnfusedQKVParameter + qkv_b: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + attn_out_b: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_1_b: MLP1Parameter + mlp_2_w: MLP2Parameter + mlp_2_b: MLP2Parameter + ln_gamma: NormParameter + ln_beta: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.q_proj.bias": "qkv_b.q_params", + "self_attn.k_proj.bias": "qkv_b.k_params", + "self_attn.v_proj.bias": "qkv_b.v_params", + "self_attn.dense.weight": "attn_out_w.params", + "self_attn.dense.bias": "attn_out_b.params", + "mlp.fc1.weight": "mlp_1_w.params", + "mlp.fc1.bias": "mlp_1_b.params", + "mlp.fc2.weight": "mlp_2_w.params", + "mlp.fc2.bias": "mlp_2_b.params", + "input_layernorm.weight": "ln_gamma.params", + "input_layernorm.bias": "ln_beta.params", + } + + +class PhiNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Phi model. + """ + word_emb: EmbeddingParameter + word_unembed_w: UnembedParameter + word_unembed_b: UnembedParameter + final_norm_gamma: NormParameter + final_norm_beta: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.final_layernorm.weight": "final_norm_gamma.params", + "model.final_layernorm.bias": "final_norm_beta.params", + "lm_head.weight": "word_unembed_w.params", + "lm_head.bias": "word_unembed_b.params", + } diff --git a/deepspeed/inference/v2/model_implementations/phi/model.py b/deepspeed/inference/v2/model_implementations/phi/model.py new file mode 100644 index 000000000000..2d5826810cb5 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/model.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .containers import PhiNonTransformerContainer, PhiTransformerContainer + + +class PhiInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[PhiNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[PhiTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.GELU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + rotary_dim = int(self._config.partial_rotary_factor * self.head_size) + return RotateHalfConfig(rotate_dim=rotary_dim, theta_base=self._config.rope_theta) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + attn_ln_out = hidden_states + attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=cur_params.qkv_b) + attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info) + attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=cur_params.attn_out_b) + + mlp_ln_out = hidden_states + mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=cur_params.mlp_1_b) + mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=cur_params.mlp_2_b) + + mlp_output.add_(attention_output) + + if self.tp_size > 1: + dist.all_reduce(mlp_output, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, mlp_output = self.norm(residual, mlp_output, next_params.ln_gamma, beta=next_params.ln_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(mlp_output) + + return residual, mlp_output + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed_w, + ragged_batch_info, + bias=self._non_transformer.word_unembed_b, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].ln_gamma, + beta=self._transformer[0].ln_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/phi/policy.py b/deepspeed/inference/v2/model_implementations/phi/policy.py new file mode 100644 index 000000000000..4b081a8e61bd --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/policy.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .containers import PhiNonTransformerContainer, PhiTransformerContainer +from .model import PhiInferenceModel + + +class PhiPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> PhiInferenceModel: + return PhiInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + trans_container_cls = PhiTransformerContainer + transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(PhiNonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/phi3/__init__.py b/deepspeed/inference/v2/model_implementations/phi3/__init__.py new file mode 100644 index 000000000000..1a4b756d210c --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import Phi3Policy diff --git a/deepspeed/inference/v2/model_implementations/phi3/containers.py b/deepspeed/inference/v2/model_implementations/phi3/containers.py new file mode 100644 index 000000000000..1cb52a75ae0b --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3/containers.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Phi-3 model looks like this: + +Phi3ForCausalLM( + (model): Phi3Model( + (embed_tokens): Embedding(32064, 3072) + (embed_dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-31): 32 x Phi3DecoderLayer( + (self_attn): Phi3Attention( + (o_proj): Linear(in_features=3072, out_features=3072, bias=False) + (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False) + (rotary_emb): Phi3RotaryEmbedding() + ) + (mlp): PhiMLP( + (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False) + (down_proj): Linear(in_features=16384, out_features=3072, bias=False) + (activation_fn): SiLU() + ) + (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05) + (resid_attn_dropout): Dropout(p=0.0) + (resid_mlp_dropout): Dropout(p=0.0) + (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05) + ) + ) + (final_layernorm): Phi3RMSNorm((3072,), eps=1e-05) + ) + (lm_head): Linear(in_features=3072, out_features=32064, bias=False) +) +''' + + +class Phi3TransformerContainer(LayerContainer): + """ + Transformer layer container for the Phi model. + """ + qkv_w: FusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: FusedGatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.qkv_proj.weight": "qkv_w.params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_up_proj.weight": "mlp_1_w.params", + "mlp.down_proj.weight": "mlp_2_w.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + } + + +class Phi3NonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Phi model. + """ + word_emb: EmbeddingParameter + word_unembed_w: UnembedParameter + final_norm_gamma: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm_gamma.params", + "lm_head.weight": "word_unembed_w.params", + } diff --git a/deepspeed/inference/v2/model_implementations/phi3/model.py b/deepspeed/inference/v2/model_implementations/phi3/model.py new file mode 100644 index 000000000000..507bb4fc9af1 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3/model.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer + + +class Phi3InferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[Phi3NonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Phi3TransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed_w, + ragged_batch_info, + gamma=self._non_transformer.final_norm_gamma) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/phi3/policy.py b/deepspeed/inference/v2/model_implementations/phi3/policy.py new file mode 100644 index 000000000000..a1b445929053 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer +from .model import Phi3InferenceModel + + +class Phi3Policy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3InferenceModel: + return Phi3InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Phi3TransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Phi3NonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/qwen/__init__.py b/deepspeed/inference/v2/model_implementations/qwen/__init__.py new file mode 100644 index 000000000000..18206048fa29 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import QwenPolicy diff --git a/deepspeed/inference/v2/model_implementations/qwen/container.py b/deepspeed/inference/v2/model_implementations/qwen/container.py new file mode 100644 index 000000000000..313de68555b9 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen/container.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Qwen model looks like this: + +QWenLMHeadModel( + (transformer): QWenModel( + (wte): Embedding(151936, 4096) + (drop): Dropout(p=0.0, inplace=False) + (rotary_emb): RotaryEmbedding() + (h): ModuleList( + (0-31): 32 x QWenBlock( + (ln_1): RMSNorm() + (attn): QWenAttention( + (c_attn): Linear(in_features=4096, out_features=12288, bias=True) + (c_proj): Linear(in_features=4096, out_features=4096, bias=False) + (attn_dropout): Dropout(p=0.0, inplace=False) + ) + (ln_2): RMSNorm() + (mlp): QWenMLP( + (w1): Linear(in_features=4096, out_features=11008, bias=False) + (w2): Linear(in_features=4096, out_features=11008, bias=False) + (c_proj): Linear(in_features=11008, out_features=4096, bias=False) + ) + ) + ) + (ln_f): RMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=151936, bias=False) +) +''' + + +class QwenTransformerContainer(LayerContainer): + """ + Transformer layer container for the Qwen model. + """ + qkv_w: FusedQKVParameter + qkv_b: FusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "attn.c_attn.weight": "qkv_w.params", + "attn.c_attn.bias": "qkv_b.params", + "attn.c_proj.weight": "attn_out_w.params", + "mlp.w1.weight": "mlp_1_w.up_params", + "mlp.w2.weight": "mlp_1_w.gate_params", + "mlp.c_proj.weight": "mlp_2_w.params", + "ln_1.weight": "attn_norm_gamma.params", + "ln_2.weight": "mlp_norm_gamma.params", + } + + +class QwenNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Qwen model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "transformer.wte.weight": "word_emb.params", + "transformer.ln_f.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/qwen/model.py b/deepspeed/inference/v2/model_implementations/qwen/model.py new file mode 100644 index 000000000000..e867e4be6713 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen/model.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper + +from .container import QwenNonTransformerContainer, QwenTransformerContainer + + +class QwenInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[QwenNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[QwenTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size // 2 + + @property + def n_heads_kv(self) -> int: + return self._config.hidden_size // self._config.kv_channels + + @property + def activation_dtype(self) -> DtypeEnum: + autoset_precision = self._config.bf16 + self._config.fp16 == 0 + if autoset_precision: + return DtypeEnum.fp16 + if self._config.fp16: + return DtypeEnum.fp16 + elif self._config.bf16: + # TODO(ZonePG): bf16 inference results may be different from huggingface bf16, + # because in rms_norm, Qwen still use float() instead of bf16 + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.SiGLU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rotary_emb_base) + + def make_norm_layer(self) -> None: + """ + Instantiates the normalization layer for the model. This sets the `self.norm` attribute. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + """ + norm_config = DSNormConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + type=self.norm_type, + channels=self.model_dim, + residual_dtype=self.activation_dtype, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + eps=self._config.layer_norm_epsilon, + ) + + self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/qwen/policy.py b/deepspeed/inference/v2/model_implementations/qwen/policy.py new file mode 100644 index 000000000000..a9263f553621 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import QwenNonTransformerContainer, QwenTransformerContainer +from .model import QwenInferenceModel + + +class QwenPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> QwenInferenceModel: + return QwenInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [QwenTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['transformer.h'], transformer_containers) + + map.set_non_transformer_params(QwenNonTransformerContainer(self.model)) + + map.set_unmapped_params(['transformer.rotary_emb.inv_freq']) + + return map diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py b/deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py new file mode 100644 index 000000000000..80b09757c74d --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import Qwen2Policy diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2/container.py new file mode 100644 index 000000000000..6556d87d6afb --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2/container.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Qwen2 model looks like this: + +Qwen2ForCausalLM( + (model): Qwen2Model( + (embed_tokens): Embedding(151936, 1024) + (layers): ModuleList( + (0-23): 24 x Qwen2DecoderLayer( + (self_attn): Qwen2SdpaAttention( + (q_proj): Linear(in_features=1024, out_features=1024, bias=True) + (k_proj): Linear(in_features=1024, out_features=1024, bias=True) + (v_proj): Linear(in_features=1024, out_features=1024, bias=True) + (o_proj): Linear(in_features=1024, out_features=1024, bias=False) + (rotary_emb): Qwen2RotaryEmbedding() + ) + (mlp): Qwen2MLP( + (gate_proj): Linear(in_features=1024, out_features=2816, bias=False) + (up_proj): Linear(in_features=1024, out_features=2816, bias=False) + (down_proj): Linear(in_features=2816, out_features=1024, bias=False) + (act_fn): SiLU() + ) + (input_layernorm): Qwen2RMSNorm() + (post_attention_layernorm): Qwen2RMSNorm() + ) + ) + (norm): Qwen2RMSNorm() + ) + (lm_head): Linear(in_features=1024, out_features=151936, bias=False) +) +''' + + +class Qwen2TransformerContainer(LayerContainer): + """ + Transformer layer container for the Qwen2 model. + """ + qkv_w: UnfusedQKVParameter + qkv_b: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.q_proj.bias": "qkv_b.q_params", + "self_attn.k_proj.bias": "qkv_b.k_params", + "self_attn.v_proj.bias": "qkv_b.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_proj.weight": "mlp_1_w.gate_params", + "mlp.up_proj.weight": "mlp_1_w.up_params", + "mlp.down_proj.weight": "mlp_2_w.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + } + + +class Qwen2NonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Qwen2 model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2/model.py new file mode 100644 index 000000000000..d535462a954d --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2/model.py @@ -0,0 +1,221 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper + +from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer + + +class Qwen2InferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[Qwen2NonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Qwen2TransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + # TODO(ZonePG): bf16 inference results may be different from huggingface bf16, + # because in rms_norm, Qwen still use float() instead of bf16 + # if self._config.torch_dtype == torch.float16: + # return DtypeEnum.fp16 + # elif self._config.torch_dtype == torch.bfloat16: + # return DtypeEnum.bf16 + # else: + # raise NotImplementedError("Only fp16 and bf16 are supported") + return DtypeEnum.fp16 + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.SiGLU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) + + def make_norm_layer(self) -> None: + """ + Instantiates the normalization layer for the model. This sets the `self.norm` attribute. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + """ + norm_config = DSNormConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + type=self.norm_type, + channels=self.model_dim, + residual_dtype=self.activation_dtype, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + eps=self._config.rms_norm_eps, + ) + + self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2/policy.py b/deepspeed/inference/v2/model_implementations/qwen_v2/policy.py new file mode 100644 index 000000000000..9c5db2ba0065 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2/policy.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer +from .model import Qwen2InferenceModel + + +class Qwen2Policy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2InferenceModel: + return Qwen2InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Qwen2TransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Qwen2NonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py new file mode 100644 index 000000000000..23e06a770023 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import Qwen2MoePolicy diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py new file mode 100644 index 000000000000..e499379da7e3 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Qwen2-57B-A14B model looks like this: + +Qwen2MoeForCausalLM( + (model): Qwen2MoeModel( + (embed_tokens): Embedding(151936, 3584) + (layers): ModuleList( + (0-27): 28 x Qwen2MoeDecoderLayer( + (self_attn): Qwen2MoeSdpaAttention( + (q_proj): Linear(in_features=3584, out_features=3584, bias=True) + (k_proj): Linear(in_features=3584, out_features=512, bias=True) + (v_proj): Linear(in_features=3584, out_features=512, bias=True) + (o_proj): Linear(in_features=3584, out_features=3584, bias=False) + (rotary_emb): Qwen2MoeRotaryEmbedding() + ) + (mlp): Qwen2MoeSparseMoeBlock( + (gate): Linear(in_features=3584, out_features=64, bias=False) + (experts): ModuleList( + (0-63): 64 x Qwen2MoeMLP( + (gate_proj): Linear(in_features=3584, out_features=2560, bias=False) + (up_proj): Linear(in_features=3584, out_features=2560, bias=False) + (down_proj): Linear(in_features=2560, out_features=3584, bias=False) + (act_fn): SiLU() + ) + ) + (shared_expert): Qwen2MoeMLP( + (gate_proj): Linear(in_features=3584, out_features=20480, bias=False) + (up_proj): Linear(in_features=3584, out_features=20480, bias=False) + (down_proj): Linear(in_features=20480, out_features=3584, bias=False) + (act_fn): SiLU() + ) + (shared_expert_gate): Linear(in_features=3584, out_features=1, bias=False) + ) + (input_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06) + (post_attention_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06) + ) + ) + (norm): Qwen2MoeRMSNorm((3584,), eps=1e-06) + ) + (lm_head): Linear(in_features=3584, out_features=151936, bias=False) +) +''' + + +class Qwen2MoeTransformerContainer(LayerContainer): + """ + Transformer layer container for the Qwen2Moe model. + """ + qkv_w: UnfusedQKVParameter + qkv_b: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + moe_gate: MoEGatingWeightParameter + moe_mlp_1: UnfusedMoEGatedMLPParameter + moe_mlp_2: UnfusedMoEMLP2Parameter + shared_moe_mlp_1: GatedMLPParameter + shared_moe_mlp_2: MLP2Parameter + shared_moe_gate: MoEGatingWeightParameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.q_proj.bias": "qkv_b.q_params", + "self_attn.k_proj.bias": "qkv_b.k_params", + "self_attn.v_proj.bias": "qkv_b.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate.weight": "moe_gate.params", + "mlp.experts.*.gate_proj.weight": "moe_mlp_1.gating_experts", + "mlp.experts.*.up_proj.weight": "moe_mlp_1.up_experts", + "mlp.experts.*.down_proj.weight": "moe_mlp_2.experts", + "mlp.shared_expert.gate_proj.weight": "shared_moe_mlp_1.gate_params", + "mlp.shared_expert.up_proj.weight": "shared_moe_mlp_1.up_params", + "mlp.shared_expert.down_proj.weight": "shared_moe_mlp_2.params", + "mlp.shared_expert_gate.weight": "shared_moe_gate.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + } + + +class Qwen2MoeNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Qwen2Moe model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py new file mode 100644 index 000000000000..c7841b24e5fc --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py @@ -0,0 +1,359 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer + + +class Qwen2MoeInferenceModel(DSMoETransformerModelBase): + """ + Inference model implementation for Qwen2MoE models. + """ + + _non_transformer: Optional[Qwen2MoeNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Qwen2MoeTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_position_embeddings + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.shared_expert_intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + # TODO(ZonePG): bf16 inference results may be different from huggingface bf16, + # because in rms_norm, Qwen still use float() instead of bf16 + # if self._config.torch_dtype == torch.float16: + # return DtypeEnum.fp16 + # elif self._config.torch_dtype == torch.bfloat16: + # return DtypeEnum.bf16 + # else: + # raise NotImplementedError("Only fp16 and bf16 are supported") + return DtypeEnum.fp16 + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.SiGLU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) + + """ + Inherited from `DSMoETransformerModelBase` + """ + + @property + def n_experts(self) -> int: + return self._config.num_experts + + @property + def n_top_k(self) -> int: + return self._config.num_experts_per_tok + + @property + def normalize_expert_scores(self) -> bool: + return self._config.norm_topk_prob + + def make_moe_layer(self) -> None: + """ + Instantiates the MoE layer for the model. This sets the `self.moe` attribute. + """ + sharded_dim = sharded_intermediate_dim(self.intermediate_dim // self.n_top_k, self.tp_size, self.tp_rank) + + moe_config = DSMoEConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + model_dim=self.model_dim, + intermediate_features=sharded_dim, + activation=self.mlp_activation_fn, + n_experts=self.n_experts, + top_k=self.n_top_k, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + normalize_scores=self.normalize_expert_scores, + ) + + self.moe = heuristics.instantiate_moe(moe_config, self._engine_config) + + ######### MLP 1 ######### + def make_shared_expert_mlp_1_layer(self) -> None: + """ + Instantiates the linear projection layer for the first MLP in the feedforward network. + This sets the `self.mlp_1` attribute. + """ + shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=self.model_dim, + out_channels=shard_size, + activation=self.mlp_activation_fn, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.shared_expert_mlp_1 = heuristics.instantiate_linear(linear_config, self._engine_config) + + ######### MLP 2 ######### + def make_shared_expert_mlp_2_layer(self) -> None: + """ + Instantiates the linear projection layer for the second MLP in the feedforward network. + This sets the `self.mlp_2` attribute. + """ + shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=shard_size, + out_channels=self.model_dim, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.shared_expert_mlp_2 = heuristics.instantiate_linear(linear_config, self._engine_config) + + ######### MLP 2 ######### + def make_shared_expert_gate_layer(self) -> None: + """ + Instantiates the linear projection layer for the second MLP in the feedforward network. + This sets the `self.mlp_2` attribute. + """ + shard_size = sharded_intermediate_dim(self.model_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=shard_size, + out_channels=8, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.shared_expert_gate = heuristics.instantiate_linear(linear_config, self._engine_config) + + def make_norm_layer(self) -> None: + """ + Instantiates the normalization layer for the model. This sets the `self.norm` attribute. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + """ + norm_config = DSNormConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + type=self.norm_type, + channels=self.model_dim, + residual_dtype=self.activation_dtype, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + eps=self._config.rms_norm_eps, + ) + + self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config) + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_moe_layer() + self.make_shared_expert_mlp_1_layer() + self.make_shared_expert_mlp_2_layer() + self.make_shared_expert_gate_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + shared_expert_output = self.shared_expert_mlp_1(hidden_states, cur_params.shared_moe_mlp_1, b=None) + shared_expert_output = self.shared_expert_mlp_2(shared_expert_output, cur_params.shared_moe_mlp_2, b=None) + shared_expert_gate_output = self.shared_expert_gate(hidden_states, cur_params.shared_moe_gate, b=None)[..., :1] + # shared_expert_gate_output shape[-1] is 1 + shared_expert_output.mul_(torch.sigmoid(shared_expert_gate_output)) + hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1, + cur_params.moe_mlp_2) + hidden_states.add_(shared_expert_output) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/policy.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/policy.py new file mode 100644 index 000000000000..630bafe993a8 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer +from .model import Qwen2MoeInferenceModel + + +class Qwen2MoePolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2MoeInferenceModel: + return Qwen2MoeInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Qwen2MoeTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Qwen2MoeNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/sharding/__init__.py b/deepspeed/inference/v2/model_implementations/sharding/__init__.py new file mode 100644 index 000000000000..63421bc1c622 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attn import * +from .attn_out import * +from .embedding import * +from .mlp import * +from .qkv import * +from .types import * +from .unembed import * diff --git a/deepspeed/inference/v2/model_implementations/sharding/attn.py b/deepspeed/inference/v2/model_implementations/sharding/attn.py new file mode 100644 index 000000000000..de8d6f6ac4c5 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/attn.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + + +def get_local_heads(shard_rank: int, + num_shards: int, + n_heads_q: int, + n_heads_kv: Optional[int] = None) -> Tuple[int, int]: + """ + Helper to determine the number of local heads of a given shard. + + Args: + shard_rank (int): The rank of the shard. + num_shards (int): The total number of shards that attention is distributed over. + n_heads_q (int): The number of query heads. + n_heads_kv (int): The number of key/value heads. If not passed, it is assumed that + the number of query and key/value heads are the same. + """ + if n_heads_q < num_shards: + raise ValueError("There must be at least as many attention heads as there are shards.") + + if n_heads_kv is None or n_heads_kv == n_heads_q: + # MHA attention + base_heads = n_heads_q // num_shards + extra_heads = n_heads_q % num_shards + + if shard_rank < extra_heads: + return (base_heads + 1), (base_heads + 1) + else: + return base_heads, base_heads + else: + # GQA attention + if n_heads_q % n_heads_kv != 0: + raise ValueError("Must be an even ratio between query and key/value heads.") + + if n_heads_kv < num_shards and num_shards % n_heads_kv != 0: + raise ValueError( + "If splitting a group across multiple shards, we must be able to distribute the groups evenly.") + + if n_heads_kv >= num_shards and n_heads_kv % num_shards != 0: + raise ValueError("If parallelizing groups, must be able to evenly distribute them.") + + q_ratio = n_heads_q // n_heads_kv + + if n_heads_kv >= num_shards: + local_kv_heads = n_heads_kv // num_shards + local_q_heads = local_kv_heads * q_ratio + return local_q_heads, local_kv_heads + else: + group_sharding_size = num_shards // n_heads_kv + group_rank_idx = shard_rank % group_sharding_size + + base_heads = q_ratio // group_sharding_size + extra_heads = q_ratio % group_sharding_size + + if group_rank_idx < extra_heads: + return (base_heads + 1), 1 + else: + return base_heads, 1 diff --git a/deepspeed/inference/v2/model_implementations/sharding/attn_out.py b/deepspeed/inference/v2/model_implementations/sharding/attn_out.py new file mode 100644 index 000000000000..ce7c105531ea --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/attn_out.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_attn_out_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> Optional[torch.Tensor]: + """ + Utility method for sharding an attention output parameter. + """ + if len(param.shape) == 1: + # We will do the bias addition on the 0th rank only rather than scale the parameter and + # implicitly reconstruct this in the distributed reduce. + return param if shard_rank == 0 else None + + assert n_heads_kv is None or (n_heads_q is not None + and n_heads_kv is not None), "n_heads_kv should not be passed without n_heads_q" + + mha_sharding = n_heads_kv is None or n_heads_q == n_heads_kv + + if mha_sharding: + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards, granularity=head_size) + else: + assert param.shape[0] == head_size * n_heads_q, "GQA param shape is not correct" + + # 32 KV heads, 16 shards for example + even_kv_sharding = n_heads_kv % num_shards == 0 + + # 8 KV heads, 16 shards for example + even_kv_distribution = num_shards % n_heads_kv == 0 + + assert even_kv_sharding or even_kv_distribution, "No partitioning algorithm for this yet." + + if even_kv_sharding: + # Same as original sharding scenario + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards, granularity=head_size) + else: + # We will first do a sharding on the KV and Q to map to the one KV shard per group of Q. + q_sharding_degree = num_shards // n_heads_kv + + kv_head = shard_rank // q_sharding_degree + + q_sharding_rank = shard_rank % q_sharding_degree + q_factor = n_heads_q // n_heads_kv + + q_chunk = param[..., q_factor * kv_head * head_size:q_factor * (kv_head + 1) * head_size] + + return shard_param(q_chunk, + ShardingType.INNER_DIMENSION, + q_sharding_rank, + q_sharding_degree, + granularity=head_size) + + +def attn_out_in_features(out_features: int, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> int: + """ + Helper to calculate the expected output projection dimension of a QKV projection matrix. + + Args: + in_features (int): The model dimension. + shard_rank (int): Which rank to return the corresponding size for. + num_shards (int): The total number of shards the parameter is distributed across. + head_size (int): The size of each attention head. + n_heads_q (int): The number of query heads on the model. This only needs to be passed if the number + of query and key/value heads are different. If passed without n_heads_kv, default + MHA partitioning will be used. + n_heads_kv (int): The number of key and value heads on the model. This only needs to be passed + if the number of query and key/value heads are different. This argument cannot be passed without + also passing n_heads_q (we want to explicitly opt into GQA sharding). + """ + assert n_heads_kv is None or (n_heads_q is not None + and n_heads_kv is not None), "n_heads_kv should not be passed without n_heads_q" + + mha_sharding = n_heads_kv is None or n_heads_q == n_heads_kv + + if mha_sharding: + endpoints = get_shard_endpoints(out_features, shard_rank, num_shards, granularity=head_size) + return endpoints[1] - endpoints[0] + else: + if n_heads_kv >= num_shards: + assert n_heads_kv % num_shards == 0, "No partitioning algorithm for this yet." + n_local_groups = n_heads_kv // num_shards + group_size = n_heads_q // n_heads_kv + + return n_local_groups * head_size * group_size + else: + assert num_shards % n_heads_kv == 0, "No partitioning algorithm for this yet." + q_split_degree = num_shards // n_heads_kv + q_split_rank = shard_rank % q_split_degree + split_granularity = (n_heads_q // n_heads_kv) * head_size + + q_endpoints = get_shard_endpoints(split_granularity, q_split_rank, q_split_degree, granularity=head_size) + + return q_endpoints[1] - q_endpoints[0] diff --git a/deepspeed/inference/v2/model_implementations/sharding/embedding.py b/deepspeed/inference/v2/model_implementations/sharding/embedding.py new file mode 100644 index 000000000000..00d335768ae6 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/embedding.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_embedding_param(param: torch.Tensor, shard_rank: int, num_shards: int) -> torch.Tensor: + """ + Utility method for sharding an embedding parameter. + + Args: + param (torch.Tensor): The parameter to shard. Should be of shape [vocab_size, model_dim] + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + """ + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards) + + +def sharded_embedding_dim(embedding_size: int, shard_rank: int, num_shards: int) -> int: + """ + Utility method for getting the size of the embedding dimension of a sharded embedding. + + Args: + embedding_size (int): The size of the embedding. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + """ + start_idx, end_idx = get_shard_endpoints(embedding_size, shard_rank, num_shards) + return end_idx - start_idx diff --git a/deepspeed/inference/v2/model_implementations/sharding/mlp.py b/deepspeed/inference/v2/model_implementations/sharding/mlp.py new file mode 100644 index 000000000000..8abd0ff8622d --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/mlp.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from .types import ShardingType, DEFAULT_SHARD_GRANULARITY +from .utils import shard_param, get_shard_endpoints + + +def shard_mlp_1_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + gated: bool = False, + is_moe: bool = False) -> torch.Tensor: + """ + Utility method for sharding an MLP 1 parameter. Both biases and weights are supported, as well + as for fused weights for MoE. + + Args: + param (torch.Tensor): The parameter to shard. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + gated (bool): Whether or not the parameter is from a gated MLP. + """ + bias_dims = 2 if is_moe else 1 + + if gated: + return shard_param(param, + ShardingType.OUTER_DIMENSION, + shard_rank, + num_shards, + granularity=DEFAULT_SHARD_GRANULARITY * 2, + bias_dims=bias_dims) + else: + return shard_param(param, ShardingType.OUTER_DIMENSION, shard_rank, num_shards, bias_dims=bias_dims) + + +def shard_mlp_2_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + is_moe: bool = False) -> Optional[torch.Tensor]: + """ + Utility method for sharding an MLP 2 parameter. + + Args: + param (torch.Tensor): The parameter to shard. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + is_moe (bool): Whether or not the parameter is from a MoE model. + """ + bias_dim_size = 2 if is_moe else 1 + + if len(param.shape) == bias_dim_size: + # We will do the bias addition on the 0th rank only rather than scale the parameter and + # implicitly reconstruct this in the distributed reduce. + return param if shard_rank == 0 else None + + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards) + + +def sharded_intermediate_dim(intermediate_size: int, num_shards: int, shard_rank: int) -> int: + """ + Utility method for getting the size of the intermediate dimension of a sharded MLP. + + Args: + intermediate_size (int): The size of the intermediate dimension. + num_shards (int): The total number of shards the parameter is distributed across. + shard_rank (int): Which shard of the partitioned tensor to return. + """ + endpoints = get_shard_endpoints(intermediate_size, shard_rank, num_shards) + return endpoints[1] - endpoints[0] diff --git a/deepspeed/inference/v2/model_implementations/sharding/qkv.py b/deepspeed/inference/v2/model_implementations/sharding/qkv.py new file mode 100644 index 000000000000..19dff1436de5 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/qkv.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_qkv_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> Optional[torch.Tensor]: + """ + Utility method for sharding a QKV parameter. Both biases and weights are supported. It is assumed + that the layout of the parameter is such that all Q heads, all K heads, and all V heads + are contiguous with respect to each other. + + Args: + param (torch.Tensor): The parameter to shard. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + head_size (int): The size of each head. + n_heads_q (int): The number of query heads. This only needs to be passed if the number + of query and key/value heads are different. If passed without n_heads_kv, default + MHA partitioning will be used. + n_heads_kv (int): The number of key/value heads. This only needs to be passed if the number + of query and key/value heads are different. This argument should not be passed without + n_heads_q (we want to explicitly opt into GQA sharding). + """ + if n_heads_kv is not None and n_heads_q is None: + raise ValueError("n_heads_kv should not be passed without n_heads_q") + + if param is None: + raise ValueError("param should not be None") + if n_heads_q is None: + # Guaranteed to be in MHA + if param.shape[0] // 3 % head_size != 0: + raise ValueError("MHA param shape is not correct") + n_heads_q = param.shape[0] // head_size // 3 + mha_sharding = True + elif n_heads_kv is None: + mha_sharding = True + else: + mha_sharding = n_heads_q == n_heads_kv + + if n_heads_q < num_shards: + raise ValueError("There must be at least as many query heads as there are shards.") + + if mha_sharding: + return shard_param(param, + ShardingType.OUTER_DIMENSION, + shard_rank, + num_shards, + num_concatenated_matrices=3, + granularity=head_size) + else: + if n_heads_q % n_heads_kv != 0: + raise ValueError("Must be an even ratio between query and key/value heads.") + + if param.shape[0] != head_size * (n_heads_q + 2 * n_heads_kv): + raise ValueError("GQA param shape is not correct") + + # 32 KV heads, 16 shards for example + if n_heads_kv >= num_shards and n_heads_kv % num_shards != 0: + raise ValueError("Currently do not support uneven partitioning of KV heads for GQA.") + + # 8 KV heads, 16 shards for example + if n_heads_kv < num_shards and num_shards % n_heads_kv != 0: + raise ValueError("Currently do not support distributing KV heads across different numbers of shards.") + else: + even_kv_sharding = n_heads_kv >= num_shards + + q_param = param[:head_size * n_heads_q] + kv_param = param[head_size * n_heads_q:] + + if even_kv_sharding: + # This is equivalent to the original sharding algorithm since n_heads_q = C * n_heads_kv. + # If n_heads_kv % num_shards == 0, then n_heads_q % num_shards == 0. + q_param = shard_param(q_param, ShardingType.OUTER_DIMENSION, shard_rank, num_shards, granularity=head_size) + kv_param = shard_param(kv_param, + ShardingType.OUTER_DIMENSION, + shard_rank, + num_shards, + num_concatenated_matrices=2, + granularity=head_size) + return torch.cat([q_param, kv_param], dim=0) + else: + # We will first do a sharding on the KV and Q to map to the one KV shard per group of Q. + q_sharding_degree = num_shards // n_heads_kv + + kv_head = shard_rank // q_sharding_degree + k_param = kv_param[kv_head * head_size:(kv_head + 1) * head_size] + v_param = kv_param[(n_heads_kv + kv_head) * head_size:(n_heads_kv + kv_head + 1) * head_size] + + q_sharding_rank = shard_rank % q_sharding_degree + q_factor = n_heads_q // n_heads_kv + + q_chunk = q_param[q_factor * kv_head * head_size:q_factor * (kv_head + 1) * head_size] + + q_param = shard_param(q_chunk, + ShardingType.OUTER_DIMENSION, + q_sharding_rank, + q_sharding_degree, + granularity=head_size) + + return torch.cat([q_param, k_param, v_param], dim=0) + + +def qkv_out_features(in_features: int, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> int: + """ + Helper to calculate the expected output projection dimension of a QKV projection matrix. + + Args: + in_features (int): The model dimension. + shard_rank (int): Which rank to return the corresponding size for. + num_shards (int): The total number of shards the parameter is distributed across. + head_size (int): The size of each head. + n_heads_q (int): The number of query heads. This only needs to be passed if the number + of query and key/value heads are different. If passed without n_heads_kv, default + MHA partitioning will be used. + n_heads_kv (int): The number of key/value heads. This only needs to be passed if the number + of query and key/value heads are different. This argument cannot be passed without also + passing n_heads_q (we want to explicitly opt into GQA sharding). + """ + if n_heads_kv is not None and n_heads_q is None: + raise ValueError("n_heads_kv should not be passed without n_heads_q") + + mha_sharding = n_heads_kv is None or n_heads_q == n_heads_kv + + if n_heads_q is not None and in_features != head_size * n_heads_q: + raise ValueError("in_features is not consistent with n_heads_q and head_size") + + if mha_sharding: + endpoints = get_shard_endpoints(in_features, shard_rank, num_shards, granularity=head_size) + return (endpoints[1] - endpoints[0]) * 3 + else: + if n_heads_kv >= num_shards: + if n_heads_kv % num_shards != 0: + raise ValueError("The KV heads must be evenly distributed across the shards.") + + n_local_groups = n_heads_kv // num_shards + group_size = n_heads_q // n_heads_kv + + return n_local_groups * head_size * (2 + group_size) + else: + if num_shards % n_heads_kv != 0: + raise ValueError("A shared KV head must always partition across the same number of shards.") + + q_split_degree = num_shards // n_heads_kv + q_split_rank = shard_rank % q_split_degree + split_granularity = (n_heads_q // n_heads_kv) * head_size + + q_endpoints = get_shard_endpoints(split_granularity, q_split_rank, q_split_degree, granularity=head_size) + + return (q_endpoints[1] - q_endpoints[0]) + 2 * head_size diff --git a/deepspeed/inference/v2/model_implementations/sharding/types.py b/deepspeed/inference/v2/model_implementations/sharding/types.py new file mode 100644 index 000000000000..01dce0db523a --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/types.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum + +DEFAULT_SHARD_GRANULARITY = 32 + + +class ShardingType(Enum): + # Inner dimension sharding corresponds to splitting the Tensor along the K-dimension + # of a matrix multiplication. This would be used for attention_output or MLP2. + INNER_DIMENSION = 1 + + # Outer dimension sharding corresponds to splitting the Tensor along the N-dimension + # of a matrix multiplication. This would be used for the QKV and MLP1 projections. + OUTER_DIMENSION = 0 diff --git a/deepspeed/inference/v2/model_implementations/sharding/unembed.py b/deepspeed/inference/v2/model_implementations/sharding/unembed.py new file mode 100644 index 000000000000..6cc771969ad9 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/unembed.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_unembed_param(param: torch.Tensor, shard_rank: int, num_shards: int) -> torch.Tensor: + """ + Utility method for sharding an unembed parameter. We shard unembeddings on the vocab dimension + with the expectation of an all-gather to produce the full results. + + TODO(cmikeh2): Really ideal would be if MII could have access to the comm and we would do + an A2A and sharded sampling. + + Args: + param (torch.Tensor): The parameter to shard. Should be of shape [vocab_size, model_dim] + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + + Returns: + torch.Tensor: The sharded parameter of shape [sharded_vocab_size, model_dim] + """ + return shard_param(param, ShardingType.OUTER_DIMENSION, shard_rank, num_shards, granularity=1) + + +def sharded_unembed_dim(vocab_size: int, shard_rank: int, num_shards: int) -> int: + """ + Utility method for determining the sharded vocab size of a sharded unembed parameter. + + Args: + vocab_size (int): The size of the vocabulary. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + """ + start_idx, end_idx = get_shard_endpoints(vocab_size, shard_rank, num_shards, granularity=1) + return end_idx - start_idx diff --git a/deepspeed/inference/v2/model_implementations/sharding/utils.py b/deepspeed/inference/v2/model_implementations/sharding/utils.py new file mode 100644 index 000000000000..fd0eb51873f8 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import torch + +from .types import ShardingType, DEFAULT_SHARD_GRANULARITY + + +def get_shard_endpoints(dim_size: int, + shard_rank: int, + num_shards: int, + granularity: int = DEFAULT_SHARD_GRANULARITY) -> Tuple[int, int]: + """ + Given a dimension to shard with size dim_size, return the start and end indices of the slice + that belong to the given rank. + + The typical use of this is as an internal helper function, so see if there is a higher level + API that better suits the application. + + Args: + dim_size (int): The size of the dimension to shard. + shard_rank (int): The rank of the shard to return. + num_shards (int): Total number of shards the dimension will be distributed across. + granularity (int): The minimum alignment of the shard endpoints. This is used to support + non-even head counts as well as align dimensions to cleaner GEMM boundaries. + """ + assert dim_size % granularity == 0, "Dimension size must be divisible by granularity" + + total_chunks = dim_size // granularity + base_chunks_per_rank = total_chunks // num_shards + remainder_chunks = total_chunks % num_shards + + start_chunk_id = shard_rank * base_chunks_per_rank + min(shard_rank, remainder_chunks) + end_chunk_id = start_chunk_id + base_chunks_per_rank + (1 if shard_rank < remainder_chunks else 0) + + return start_chunk_id * granularity, end_chunk_id * granularity + + +def shard_param(param: Optional[torch.Tensor], + shard_mode: ShardingType, + shard_rank: int, + num_shards: int, + num_concatenated_matrices: int = 1, + granularity: int = 32, + bias_dims: int = 1) -> torch.Tensor: + """ + Utility for sharding a parameter. This will return the slice of the parameter that should + exist on the given shard_rank given the sharding configuration. The workflow here is + to find the minimum bounded Tensor to shard, get the slicing endpoints, and then concatenate + as needed. + + The typical use of this is as an internal helper function, so see if there is a higher level + API that better suits the application. + + Args: + param (torch.Tensor): The parameter to shard. + shard_mode (ShardingType): The type of sharding to apply. See ShardingType for more context. + shard_rank (int): The rank of the shard to return. + num_shards (int): Total number of shards the parameter will be distrbuted across. + num_concatenated_matrices (int): The number of matrices that have been concatenated together in the original + parameter. An example of this is a fused QKV projection matrix, where the `num_concatenated_matrices` + argument would be 3. + granularity (int): The minimum alignment of the shard endpoints. For attention projection matrices, this + should be set to the head size to support non-even sharding. + bias_dims (int): The number of dimensions that are considered bias dimensions. This is used to support + sharding of MoE and non-MoE biases on the same codepath. + """ + assert shard_rank < num_shards, "Shard rank must be less than num_shards" + + # Easier to hide this inside of the sharding logic than to add checks in every model + # implementation. + if param is None: + return None + + if num_shards == 1: + # Trivial case of no sharding. + return param + + if shard_mode == ShardingType.OUTER_DIMENSION: + + def get_matrices(dim_idx: int) -> torch.Tensor: + dim_size = param.size(dim_idx) // num_concatenated_matrices + start_channel_id, end_channel_id = get_shard_endpoints(dim_size, shard_rank, num_shards, granularity) + return torch.chunk(param, num_concatenated_matrices, dim=dim_idx), start_channel_id, end_channel_id + + if param.ndim == bias_dims: + # Special case for bias parameters. + matrices, start_channel_id, end_channel_id = get_matrices(dim_idx=-1) + return torch.cat([mat[..., start_channel_id:end_channel_id] for mat in matrices], dim=-1) + else: + # General case for weight parameters. This assumes MoE parameters are stored in the format of + # [num_experts, out_features, in_features] + matrices, start_channel_id, end_channel_id = get_matrices(dim_idx=-2) + return torch.cat([mat[..., start_channel_id:end_channel_id, :] for mat in matrices], dim=-2) + + elif shard_mode == ShardingType.INNER_DIMENSION: + dim_size = param.size(-1) // num_concatenated_matrices + start_channel_id, end_channel_id = get_shard_endpoints(dim_size, shard_rank, num_shards, granularity) + matrices = torch.chunk(param, num_concatenated_matrices, dim=-1) + return torch.cat([mat[..., start_channel_id:end_channel_id] for mat in matrices], dim=-1) diff --git a/deepspeed/inference/v2/modules/__init__.py b/deepspeed/inference/v2/modules/__init__.py new file mode 100644 index 000000000000..917c1599de2e --- /dev/null +++ b/deepspeed/inference/v2/modules/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from . import implementations +from . import interfaces +from .module_registry import ConfigBundle diff --git a/deepspeed/inference/v2/modules/configs/__init__.py b/deepspeed/inference/v2/modules/configs/__init__.py new file mode 100644 index 000000000000..3429e69b47de --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attention_configs import ( + DSSelfAttentionConfig, + PositionalEmbeddingType, + MaskingType, + RotateHalfConfig, +) +from .embedding_config import DSEmbeddingsConfig +from .linear_config import DSLinearConfig +from .moe_config import DSMoEConfig +from .norm_config import DSNormConfig, NormTypeEnum +from .unembed_config import DSUnembedConfig diff --git a/deepspeed/inference/v2/modules/configs/attention_configs.py b/deepspeed/inference/v2/modules/configs/attention_configs.py new file mode 100644 index 000000000000..be6a3535024c --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/attention_configs.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum +from typing import Dict, Optional + +from ...inference_utils import DtypeEnum +from ...modules.ds_module import DSModuleConfig +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + + +class PositionalEmbeddingType(Enum): + + # No positional embeddings + none = "none" + + # Rotary positional embeddings - every half + rotate_half = "rotate_half" + + # Rotary positional embeddings - every other + rotate_every_other = "rotate_every_other" + + # Alibi + alibi = "alibi" + + +class RotateHalfConfig(DeepSpeedConfigModel): + + use_trained_freqs: bool = False + """ + Whether to use a passed `trained_freqs` tensor for the attention implementation + or to use default synthesized frequencies. + """ + + theta_base: float = 10_000.0 + """ + Base for theta. This will only be used if `use_trained_freqs` is False. + """ + + rotate_dim: Optional[int] = None + """ + How many neurons to rotate. If None, then all neurons will be rotated. Many external configs + will set this number to half the head dimension and then internally multiply by 2. To make it + more clear to understand what is happening (rotate_dim < head_dim -> then only partial rotation), + we do not do this multiplication internally. + """ + + +class MaskingType(Enum): + + # No masking + none = "none" + + # Causal masking + causal = "causal" + + # Local masking + local = "local" + + # Symmetric masking (this is a 1D tensor mask) + symmetric = "symmetric" + + # Arbitrary masking (this would correspond to a 2D tensor mask) + asymmetric = "asymmetric" + + +class DSSelfAttentionConfig(DSModuleConfig): + """ + Config class for attention. + """ + + # Number of query attention heads on this shard + n_heads_q: int + + # Number of KV attention heads on this shard + n_heads_kv: int + + # Size of each attention head + head_size: int + + # Max number of sequences that may compose a ragged batch + max_sequences: int + + # Scale factor for attention scores + scale_factor: float = 1.0 + + # Input data type + input_dtype: DtypeEnum = DtypeEnum.fp16 + + # Output data type + output_dtype: DtypeEnum = DtypeEnum.fp16 + + # Masking type + masking_type: MaskingType = MaskingType.causal + + # Masking args + masking_args: Dict = {} + + # Positional embedding type + positional_embedding_type: PositionalEmbeddingType = PositionalEmbeddingType.none + + # Positional embedding args + positional_embedding_config: Optional[RotateHalfConfig] = None + """ + To extend this for the other positional embedding types, we would need to add + new configs for each type (as necessary) and annotate this with the + Union[RotateHalfConfig, OtherConfig, ...] type. + """ diff --git a/deepspeed/inference/v2/modules/configs/embedding_config.py b/deepspeed/inference/v2/modules/configs/embedding_config.py new file mode 100644 index 000000000000..2486c5986e95 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/embedding_config.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +from ...inference_utils import DtypeEnum, NormTypeEnum +from ...modules.ds_module import DSModuleConfig +""" +Trying to define the space we need to support here right now: + +Types of embeddings I've found so far: + 1. Token embedding + 2. Position embedding + 3. Token type embedding + 4. LN + +GPTNeo: 1, 2, 3 (shared with 1) +GPTNeoX: 1 +GPTJ: 1, 3 +LLaMA: 1 +BERT: 1, 2, 3, 4 +GPT2: 1, 2, 3 (shared with 1) + +Sidebar for OPT: +OPT: 1, 2 +1 may not actually project to the actual hidden dimension according to the raw +code, but for the model configs we care about it does. +2 has a weird offset associated with it that the others do not. +""" + + +class DSEmbeddingsConfig(DSModuleConfig): + """ + Config class for DSEmbeddings. + """ + + residual_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type the module should use for its output. + """ + + embedding_dim: int + """ + Dimensionality of the embedding projections. + """ + + positional_embedding: bool = False + """ + Whether the module should expect a positional embedding matrix. The shape of this + matrix should be of shape [max_seq_len + positional_offset, embedding_dim] + """ + + positional_offset: int = 0 + """ + Whether the linearized token IDs should be offset by a certain amount. For an example + of this, see the OPT model implementation. + """ + + use_token_type: bool = False + """ + Whether the module should expect a token type embedding matrix. + """ + + output_normalization: Optional[NormTypeEnum] = None + """ + If a the output of the embedding module should be normalized, specify here. See + ``inference.inference_utils.NormTypeEnum`` for supported values. + """ diff --git a/deepspeed/inference/v2/modules/configs/linear_config.py b/deepspeed/inference/v2/modules/configs/linear_config.py new file mode 100644 index 000000000000..40fe0773aeee --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/linear_config.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import ActivationType, DtypeEnum +from ...modules.ds_module import DSModuleConfig + + +class DSLinearConfig(DSModuleConfig): + """ + Config class for DSLinearBase. + """ + + in_channels: int + """ + Number of input channels + """ + + out_channels: int + """ + Number of output channels. NOTE: If this linear layer is using a gated activation function, + the value for ``out_channels`` passed here should refer to the number of channels after + gating (i.e., the expected weight shape before transformations will be ``[out_channels * 2, in_channels]``). + """ + + activation: ActivationType = ActivationType.IDENTITY + """ + The activation function for this layer. See :class:`deepspeed.inference.inference_utils.ActivationType` for + supported activation functions. + """ + + input_dtype: DtypeEnum = DtypeEnum.fp16 + """ + The data type of the input tensor. See :class:`deepspeed.inference.inference_utils.DtypeEnum` for supported + data types. + """ + + output_dtype: DtypeEnum = DtypeEnum.fp16 + """ + The data type of the output tensor. See :class:`deepspeed.inference.inference_utils.DtypeEnum` for supported + data types. + """ diff --git a/deepspeed/inference/v2/modules/configs/moe_config.py b/deepspeed/inference/v2/modules/configs/moe_config.py new file mode 100644 index 000000000000..7bc944f55e17 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/moe_config.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import ActivationType, DtypeEnum +from ...modules.ds_module import DSModuleConfig + + +class DSMoEConfig(DSModuleConfig): + """ + Config class for DSMoEBase + """ + + model_dim: int + """ + Size of input activation. + """ + + intermediate_features: int + """ + Size of intermediate activation. Specifically, this is the number of input features + in the second linear layer. Depending on the activation function, the output of the first + linear layer may have increased dimensionality. + """ + + n_experts: int + """ + Number of experts. + """ + + top_k: int = 1 + """ + top-k gating function (like top-1 or top-2) + """ + + input_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type for the input activations. + """ + + output_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type for the output activations. + """ + + activation: ActivationType = ActivationType.IDENTITY + """ + Activation function of the first MLP1 + """ + + normalize_scores: bool = False + """ + Whether normalization is applied to the selected scores. If true, the module + should rescale the scores such that their sum is 1.0. + """ diff --git a/deepspeed/inference/v2/modules/configs/norm_config.py b/deepspeed/inference/v2/modules/configs/norm_config.py new file mode 100644 index 000000000000..358982253756 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/norm_config.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import DtypeEnum, NormTypeEnum +from ...modules.ds_module import DSModuleConfig + + +class DSNormConfig(DSModuleConfig): + """ + Config class for both DSPreLN and DSPostLN. + """ + + # Type of normalization + type: NormTypeEnum + + # Number of channels in the model embedding + channels: int + + # Data type of the residual input/outputs (we assume the residual must + # be the same data type for the entire model). + residual_dtype: DtypeEnum = DtypeEnum.fp16 + + # Data type of the hidden states input + input_dtype: DtypeEnum = DtypeEnum.fp16 + + # Data type of the hidden states output + output_dtype: DtypeEnum = DtypeEnum.fp16 + + # Epsilon value for numerical stability + eps: float = 1e-5 diff --git a/deepspeed/inference/v2/modules/configs/unembed_config.py b/deepspeed/inference/v2/modules/configs/unembed_config.py new file mode 100644 index 000000000000..ea4cc3cc99c1 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/unembed_config.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import DtypeEnum, NormTypeEnum +from ...modules.ds_module import DSModuleConfig +from typing import Optional + + +class DSUnembedConfig(DSModuleConfig): + """ + Config class for DSUnembed + """ + + dtype: DtypeEnum = DtypeEnum.fp16 + """ + Expected data type. + """ + + norm_type: Optional[NormTypeEnum] = None + """ + Whether the input to the unembed is normalized prior to the unembedding projection. + """ + + model_dim: int + """ + Model embedding size. + """ + + max_sequences: int + """ + Max sequences composing the ragged batch. + """ + + vocab_size: int + """ + Local vocab size (the full vocab size may have been sharded across model parallel ranks) + """ diff --git a/deepspeed/inference/v2/modules/ds_module.py b/deepspeed/inference/v2/modules/ds_module.py new file mode 100644 index 000000000000..2a6d294f3266 --- /dev/null +++ b/deepspeed/inference/v2/modules/ds_module.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractstaticmethod +from typing import Any, Dict, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + + +class DSModuleConfig(DeepSpeedConfigModel): + + max_tokens: int + + +class DSModuleBase(torch.nn.Module, ABC): + """ + Base class for all DeepSpeed Inference modules. This class establishes + the basic attributes of a DSModule. Only abstract functionality modules should inherit + directly from this class, not specific implementations. + """ + + @abstractstaticmethod + def name() -> str: + """ + Return a memorable, human-readable name for this module. + + This will be used as a key in custom inference configurations and should only + be implemented by the children of functionality modules. + """ + ... + + @abstractstaticmethod + def config_class() -> Type[DSModuleConfig]: + """ + Return the associated config class for this module. + + This should be implemented (along with the config class) by an abstract functionality + module. + """ + ... + + @abstractstaticmethod + def supports_config(config: DSModuleConfig) -> bool: + """ + Return whether or not this module supports the given config. + + This should be implemented by the children of functionality modules and should report + whether it would be feasible to instantiate this module with the given config. + """ + ... + + def __init__(self, config: DSModuleConfig, implementation_config: Dict[str, Any] = {}) -> None: + """ + Initialize the module with the given config. + """ + super().__init__() + self._config = config + self._implementation_config = implementation_config diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py new file mode 100644 index 000000000000..f719e299a4b2 --- /dev/null +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..config_v2 import RaggedInferenceEngineConfig +from ..inference_utils import NormTypeEnum + +from .module_registry import ConfigBundle +from ..modules.configs import ( + DSEmbeddingsConfig, + DSLinearConfig, + DSMoEConfig, + DSNormConfig, + DSSelfAttentionConfig, + DSUnembedConfig, +) +from ..modules.interfaces import ( + DSEmbeddingBase, + DSEmbeddingRegistry, + DSLinearBase, + DSLinearRegistry, + DSMoEBase, + DSMoERegistry, + DSPostNormBase, + DSPostNormRegistry, + DSPreNormBase, + DSPreNormRegistry, + DSSelfAttentionBase, + DSSelfAttentionRegistry, + DSUnembedBase, + DSUnembedRegistry, +) + + +def instantiate_attention(attention_config: DSSelfAttentionConfig, + engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: + """ + Choose an appropriate attention implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + attention_config (DSSelfAttentionConfig): Configuration for the attention module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + An attention module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="dense_blocked_attention", config=attention_config) + return DSSelfAttentionRegistry.instantiate_config(config) + + +def instantiate_embed(embed_config: DSEmbeddingsConfig, engine_config: RaggedInferenceEngineConfig) -> DSEmbeddingBase: + """ + Choose an appropriate embedding implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + embed_config (DSEmbeddingsConfig): Configuration for the embedding module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + An embedding module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="ragged_embedding", config=embed_config) + return DSEmbeddingRegistry.instantiate_config(config) + + +def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInferenceEngineConfig) -> DSLinearBase: + """ + Choose an appropriate linear implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + linear_config (DSLinearConfig): Configuration for the linear module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A linear module implementing the given configuration. + """ + + quantization_mode = engine_config.quantization.quantization_mode + if quantization_mode is None: + config = ConfigBundle(name="blas_fp_linear", config=linear_config) + else: + # Currently, we only support ``quantized_wf6af16_linear`` on NVIDIA Ampere GPUs. + if quantization_mode == "wf6af16": + import torch + if not torch.cuda.is_available(): #ignore-cuda + raise ValueError("WF6AF16 quantization is only supported on CUDA") + else: + is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None + if is_rocm_pytorch: + raise ValueError("WF6AF16 quantization is only supported on NVIDIA GPUs") + elif torch.cuda.get_device_properties(0).major != 8: #ignore-cuda + raise ValueError("WF6AF16 quantization is only supported on Ampere architectures") + config = ConfigBundle(name="quantized_wf6af16_linear", config=linear_config) + else: + raise ValueError(f"Unsupported quantization mode: {quantization_mode}") + return DSLinearRegistry.instantiate_config(config) + + +def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngineConfig) -> DSMoEBase: + """ + Choose an appropriate MoE implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + moe_config (DSMoEConfig): Configuration for the MoE module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A MoE module implementing the given configuration. + """ + + moe_type = "cutlass_multi_gemm_moe" + + if moe_type == "cutlass_multi_gemm_moe": + # TODO: Get this off an engine config + implementation_config = { + "weight_dtype": moe_config.input_dtype, + } + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="cutlass_multi_gemm_moe", + config=moe_config, + implementation_config=implementation_config) + return DSMoERegistry.instantiate_config(config) + + +def instantiate_post_norm(norm_config: DSNormConfig, engine_config: RaggedInferenceEngineConfig) -> DSPostNormBase: + """ + Choose an appropriate post-norm implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + norm_config (DSNormConfig): Configuration for the post-norm module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A post-norm module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="cuda_post_ln", config=norm_config) + return DSPostNormRegistry.instantiate_config(config) + + +def instantiate_pre_norm(norm_config: DSNormConfig, engine_config: RaggedInferenceEngineConfig) -> DSPreNormBase: + """ + Choose an appropriate pre-norm implementation based on the given configurations. Currently, + this will select between two CUDA implementations, one for LayerNorm and one for RMSNorm. + + Arguments: + norm_config (DSNormConfig): Configuration for the pre-norm module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A pre-norm module implementing the given configuration. + """ + if NormTypeEnum(norm_config.type) == NormTypeEnum.LayerNorm: + module_name = "cuda_pre_ln" + elif NormTypeEnum(norm_config.type) == NormTypeEnum.RMSNorm: + module_name = "cuda_pre_rms" + + config = ConfigBundle(name=module_name, config=norm_config) + return DSPreNormRegistry.instantiate_config(config) + + +def instantiate_unembed(unembed_config: DSUnembedConfig, engine_config: RaggedInferenceEngineConfig) -> DSUnembedBase: + """ + Choose an appropriate unembedding implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + unembed_config (DSUnembedConfig): Configuration for the unembed module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + An unembed module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="ragged_unembed", config=unembed_config) + return DSUnembedRegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/__init__.py b/deepspeed/inference/v2/modules/implementations/__init__.py new file mode 100644 index 000000000000..1b500a9a0b5a --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Imports for registering ops +from .attention import * +from .linear import * +from .post_norm import * +from .pre_norm import * +from .embedding import * +from .unembed import * +from .moe import * diff --git a/deepspeed/inference/v2/modules/implementations/attention/__init__.py b/deepspeed/inference/v2/modules/implementations/attention/__init__.py new file mode 100644 index 000000000000..1ff593ac1368 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/attention/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .dense_blocked_attention import DSDenseBlockedAttention diff --git a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py new file mode 100644 index 000000000000..3515b3c2b690 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import DtypeEnum +from ....kernels.ragged_ops import ( + AtomBuilder, + BlockedFlashAttn, + BlockedRotaryEmbeddings, + BlockedTrainedRotaryEmbeddings, + get_q_block_size, + get_kv_block_size, + LinearBlockedKVCopy, +) +from ....ragged import RaggedBatchWrapper, split_kv +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from ...interfaces import DSSelfAttentionBase, DSSelfAttentionRegistry +from ...configs import DSSelfAttentionConfig, PositionalEmbeddingType, MaskingType + +try: + from functools import cached_property +except ImportError: + + def cached_property(func): + return property(func) + + +@DSSelfAttentionRegistry.register_module +class DSDenseBlockedAttention(DSSelfAttentionBase): + """ + Self attention implementation for dense, blocked self attention. + """ + + @staticmethod + def name() -> str: + return 'dense_blocked_attention' + + @staticmethod + def supports_config(config: DSSelfAttentionConfig) -> bool: + + if config.input_dtype != config.output_dtype: + return False + + if DtypeEnum(config.input_dtype) not in (DtypeEnum.fp16, DtypeEnum.bf16): + return False + + if PositionalEmbeddingType(config.positional_embedding_type) not in [ + PositionalEmbeddingType.none, PositionalEmbeddingType.rotate_half + ]: + return False + + if MaskingType(config.masking_type) != MaskingType.causal: + return False + + return True + + def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[str, Any]) -> None: + """ + Create the Attention DSModule. + + Args: + config (DSSelfAttentionConfig): The self attention config for all attention DSModules. + implementation_config (Dict[str, Any]): + There are two (dependent) potential components in the implementtion config. + + 1. `trained_freqs` - If the embedding weights for RoPE are trained, the implementation + config should contain {'trained_freqs': True}. This will mean the implementation will + expect a `trained_freqs` tensor in the `forward` method and will not synthesize the + values internally. + + 2. `theta_base` - The base value for synthesized frequencies in the rotary embeddings. + This will only be used if `trained_freqs` is False or not present in the `implementation_config`. If this is not included, the default value of 10000.0 will be used. + """ + super().__init__(config, implementation_config) + + embed_type = PositionalEmbeddingType(config.positional_embedding_type) + if embed_type == PositionalEmbeddingType.none: + self._kv_copy = LinearBlockedKVCopy(self._config.head_size, self._config.n_heads_q, + self._config.n_heads_kv, self._config.input_dtype) + elif embed_type == PositionalEmbeddingType.rotate_half: + rotary_config = config.positional_embedding_config + assert rotary_config is not None, "Rotary config must be provided if using rotate_half as Positional Embedding Type." + + if rotary_config.use_trained_freqs: + # Theta and rotary dim are effectively embedded into either the values (theta) or the shape (rotary_dim) + # of the trained_freqs tensor. + self._kv_copy = BlockedTrainedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, + self._config.n_heads_kv, self._config.input_dtype) + else: + theta_base = rotary_config.theta_base + rotary_dim = rotary_config.rotate_dim if rotary_config.rotate_dim is not None else self._config.head_size + self._kv_copy = BlockedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, + self._config.n_heads_kv, self._config.input_dtype, rotary_dim, + theta_base) + + self._softmax_scale = self._config.scale_factor + + # TODO(cmikeh2): Attention kernel gets created here. + self._attn_kernel = BlockedFlashAttn(self._config.head_size, self._config.input_dtype) + self._atom_builder = AtomBuilder() + + self.model_dim = self._config.head_size * self._config.n_heads_q + self._output = torch.empty((self._config.max_tokens, self._config.head_size * self._config.n_heads_q), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + # TODO(cmikeh2): Pre-allocate storage buffer for the attention atoms. + self._max_atoms = self._config.max_sequences + self._atoms = torch.empty((self._max_atoms, 8), dtype=torch.int32, device=get_accelerator().current_device()) + + alloc_func = RaggedUtilsBuilder().load().allocate_fast_host_buffer + self._atoms_shadow = alloc_func(self._atoms) + self._cur_atoms = 0 + + @cached_property + def kv_block_size(self) -> int: + """ + Return preferred granulatity for blocked KV-cache implementation. + """ + return get_kv_block_size(self._config.head_size) + + @cached_property + def q_block_size(self) -> int: + """ + Property to calculate blocking granularity for the query dimension. + This has no impact on the KV-cache structure, but will affect the + number of attention atoms associated with a batch. + """ + return get_q_block_size(self._config.head_size) + + def build_atoms(self, ragged_batch: RaggedBatchWrapper) -> None: + """ + Build the atoms for the attention kernel. + + Args: + ragged_batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata. + """ + host_atoms, n_atoms = self._atom_builder(self._atoms_shadow, ragged_batch, self.q_block_size, + self.kv_block_size) + + self._cur_atoms = n_atoms + self._atoms[:n_atoms].copy_(host_atoms[:n_atoms], non_blocking=True) + + def forward(self, + q_k_v: torch.Tensor, + kv_cache: torch.Tensor, + batch: RaggedBatchWrapper, + inv_freqs: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward implementation. + + Args: + q_k_v (torch.Tensor): Query/Key/Value projection Tensor of shape + [n_heads, (n_heads_q + 2 * n_heads_kv) * head_size]. + kv_cache (torch.Tensor): Blocked persistent cache of shape + [2, batch, block_size, n_heads_kv, head_size]. + batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata. + inv_freqs (Optional[torch.Tensor]): The inverse frequencies for the rotary embeddings if they + have been modified from synthesizable values. + """ + if inv_freqs is not None: + self._kv_copy(kv_cache, q_k_v, batch, inv_freqs) + else: + self._kv_copy(kv_cache, q_k_v, batch) + + q = q_k_v[:, :self._config.head_size * self._config.n_heads_q] + output = empty_from(self._output, q.shape) + k_cache, v_cache = split_kv(kv_cache) + + self._attn_kernel(output, q, k_cache, v_cache, self._atoms[:self._cur_atoms], self._softmax_scale) + + return output diff --git a/deepspeed/inference/v2/modules/implementations/embedding/__init__.py b/deepspeed/inference/v2/modules/implementations/embedding/__init__.py new file mode 100644 index 000000000000..5458a855abf4 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/embedding/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .ragged_embedding import DSRaggedEmbedding diff --git a/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py b/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py new file mode 100644 index 000000000000..90cdd39d1be7 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import DtypeEnum +from ....kernels.ragged_ops import RaggedEmbeddingKernel +from ....ragged import RaggedBatchWrapper +from ...interfaces import DSEmbeddingBase, DSEmbeddingRegistry +from ...configs import DSEmbeddingsConfig + + +@DSEmbeddingRegistry.register_module +class DSRaggedEmbedding(DSEmbeddingBase): + + @staticmethod + def name(): + return 'ragged_embedding' + + @staticmethod + def supports_config(config: DSEmbeddingsConfig) -> bool: + + if DtypeEnum(config.residual_dtype) not in [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]: + return False + + if config.use_token_type: + return False + + if config.output_normalization is not None: + return False + + try: + _ = RaggedEmbeddingKernel(config.residual_dtype, torch.int32, config.embedding_dim) + except ValueError: + return False + + return True + + def __init__(self, config: DSEmbeddingsConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self.embed_offset = self._config.positional_offset + + # TODO(cmikeh2): How do we want to avoid the int32 vs int64 issue? + self._ragged_embed = RaggedEmbeddingKernel(self._config.residual_dtype, torch.int32, + self._config.embedding_dim) + + self._output = torch.empty((self._config.max_tokens, self._config.embedding_dim), + dtype=self._config.residual_dtype, + device=get_accelerator().current_device()) + + @property + def output(self) -> torch.Tensor: + return self._output + + def forward(self, + ragged_batch: RaggedBatchWrapper, + word_embeddings: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + ragged_batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata. + word_embeddings (torch.Tensor): The word embedding table + """ + output = empty_from(self._output, (ragged_batch.tensor_toks, self._config.embedding_dim)) + self._ragged_embed(output, + ragged_batch, + word_embeddings, + position_embed_weight=position_embeddings, + position_embed_offset=self.embed_offset) + return output diff --git a/deepspeed/inference/v2/modules/implementations/linear/__init__.py b/deepspeed/inference/v2/modules/implementations/linear/__init__.py new file mode 100644 index 000000000000..0501af54c4e6 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blas_fp_linear import BlasFPLinear +from .quantized_linear import QuantizedWf6Af16Linear, fp_quantize diff --git a/deepspeed/inference/v2/modules/implementations/linear/blas_fp_linear.py b/deepspeed/inference/v2/modules/implementations/linear/blas_fp_linear.py new file mode 100644 index 000000000000..c58dab0b826b --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/blas_fp_linear.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import is_gated +from ....kernels.core_ops import ( + BlasLibLinear, + CUDABiasActivation, + CUDAGatedActivation, +) + +from ...interfaces import DSLinearBase, DSLinearRegistry +from ...configs import DSLinearConfig +from ....inference_parameter import InferenceParameter + + +@DSLinearRegistry.register_module +class BlasFPLinear(DSLinearBase): + """ + Linear DSModule based on BLAS library and standalone bias + activation kernel implementation. + """ + + @staticmethod + def name(): + return 'blas_fp_linear' + + @staticmethod + def supports_config(config: DSLinearConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: + return False + + if is_gated(config.activation): + try: + _ = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + else: + try: + _ = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + + return True + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._linear_impl = BlasLibLinear(self._config.input_dtype) + + if is_gated(config.activation): + self._is_gated = True + self._act_fn = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + else: + self._is_gated = False + self._act_fn = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + + self._output = torch.empty((config.max_tokens, config.out_channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + param = param.to(self._config.output_dtype) + return InferenceParameter.initialize(param) + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + + output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) + + if self._is_gated: + staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self._config.out_channels * 2)) + self._linear_impl(staging_output, hidden_states, w) + self._act_fn(output, staging_output, b) + else: + self._linear_impl(output, hidden_states, w) + self._act_fn(output, b) + + return output + + @property + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + return self._output diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py new file mode 100644 index 000000000000..933cf55b2391 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ....allocator import empty_from +from ....inference_utils import is_gated +from ....kernels.core_ops import ( + CUDAWf6Af16Linear, + CUDABiasActivation, + CUDAGatedActivation, +) + +from ...interfaces import DSLinearBase, DSLinearRegistry +from ...configs import DSLinearConfig +from ....inference_parameter import InferenceParameter + + +def fp_quantize(input: torch.FloatTensor, + num_bits: int = 6, + exp_bits: int = 3, + min_value: torch.FloatTensor = None, + max_value: torch.FloatTensor = None, + group_size: int = -1): + """ + Args: + inputs (`torch.FloatTensor`) + The input which needs to be quantized + num_bits (int, >=4) + Number of bits to use for quantization + exp_bits: + fp exp_bits + min_value/max_vlue (torch.FloatTensor) + Used for static activation quantization + group_size (int) N + The quantization block size, each N numbers has its own scaling + factor and off-site. -1 means use the last dim as the group_size + Returns: + quantized_fake_fp6 + The quantized weights, in fp16 format and contains fp6 value. + scales + Quantization scales + """ + + try: + from qtorch.quant import float_quantize + except ImportError: + raise ImportError("Please install qtorch to use this function") + + assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None) + + assert input.dtype == torch.float16 + + orig_device = input.device + input = input.to(torch.float32).to(get_accelerator().current_device()) + if num_bits == 6 and exp_bits == 3: # this is default + q_range = 28 + else: + raise NotImplementedError + + man_bits = num_bits - exp_bits - 1 + input_shape = input.shape + + if group_size == -1: + group_size = input_shape[-1] + else: + # Only support per-channel quantization + raise NotImplementedError + num_groups = input.numel() // group_size + input = input.reshape(num_groups, -1) + + if min_value is None: + max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1) + else: + max_input = torch.max(min_value.abs(), max_value) # .view(-1) + scales = max_input / q_range # q_range + 1 + scales[scales == 0] = 1 # avoid zero scales + scaled_input = input / scales + + quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest") + + quantized_fake_fp6 = quantized_fake_fp6.reshape(input_shape).contiguous().to(torch.float16).to(orig_device) + scales = scales.to(torch.float16).to(orig_device) + # Now the dequantized value is quantized_fake_fp6 * scales + + return quantized_fake_fp6, scales + + +@DSLinearRegistry.register_module +class QuantizedWf6Af16Linear(DSLinearBase): + """ + Linear DSModule for FP6 weight-only quantization kernel, where weight is FP6 + and activation is FP16. + """ + + @staticmethod + def name(): + return 'quantized_wf6af16_linear' + + @staticmethod + def supports_config(config: DSLinearConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + # As for fp6 data items, they are packed and stored in a set of fp16 + # tensors. E.g., 8 fp6 data items are stored in 3 fp16 tensor. + if config.input_dtype != torch.float16: + return False + + if is_gated(config.activation): + try: + _ = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + else: + try: + _ = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + + return True + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._linear_impl = CUDAWf6Af16Linear() + + if is_gated(config.activation): + # In the FP6 kernel implementation, the MatMul is W * A, where W is + # the weight and A is activation. M is the output channel size. + self.out_channels = self._config.out_channels * 2 + self.in_channels = self._config.in_channels + self._is_gated = True + self._act_fn = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + else: + self.out_channels = self._config.out_channels + self.in_channels = self._config.in_channels + self._is_gated = False + self._act_fn = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + + self._output = torch.empty((config.max_tokens, config.out_channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.preprocess_weight = self.inf_module.preprocess_weight + + self.quantizer = fp_quantize + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + # It expects that the quantization scales are store in the attribute `scales`. + + if param.ndim == 1: # bias, do nothing + return InferenceParameter.initialize(param) + + quantized_fake_fp6, scales = self.quantizer(param, num_bits=6, exp_bits=3) + + # This is for debugging, will delete before release. + assert (quantized_fake_fp6.dtype == torch.float16) + assert quantized_fake_fp6.shape[0] == self.out_channels + assert scales.numel() == self.out_channels + + weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) + + return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + weights_2bit = w + weights_4bit = w.weights_4bit + scales = w.scales + output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) + if self._is_gated: + staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self.out_channels)) + self._linear_impl(staging_output, hidden_states, weights_2bit, weights_4bit, scales, self.out_channels, + hidden_states.shape[0], self.in_channels) + self._act_fn(output, staging_output, b) + else: + self._linear_impl(output, hidden_states, weights_2bit, weights_4bit, scales, self.out_channels, + hidden_states.shape[0], self.in_channels) + self._act_fn(output, b) + + return output + + @property + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + return self._output diff --git a/deepspeed/inference/v2/modules/implementations/moe/__init__.py b/deepspeed/inference/v2/modules/implementations/moe/__init__.py new file mode 100644 index 000000000000..053ad5da7746 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cutlass_multi_gemm import DSMultiGemmMoE diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py new file mode 100644 index 000000000000..a9b01d1233cd --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import ActivationType, is_gated +from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation +from ....kernels.ragged_ops import ( + MoEGather, + MoEScatter, + RaggedTopKGating, +) +from ....ragged import RaggedBatchWrapper + +from ...interfaces import DSMoEBase, DSMoERegistry +from ...configs import DSMoEConfig +from ....kernels.cutlass_ops import MoEGEMM +from ....inference_parameter import InferenceParameter + + +@DSMoERegistry.register_module +class DSMultiGemmMoE(DSMoEBase): + """ + MoE implementation based on the CUTLASS multi-GEMM. + """ + + @staticmethod + def name(): + return 'cutlass_multi_gemm_moe' + + @staticmethod + def supports_config(config: DSMoEConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: + return False + + if config.top_k != 1 and config.top_k != 2 and config.top_k != 4 and config.top_k != 8: + return False + + return True + + def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + # Convenience variables for frequently accessed items. + self.max_tokens = self._config.max_tokens + self.n_experts = self._config.n_experts + self.n_top_k = self._config.top_k + self.intermediate_dim = self._config.intermediate_features + + moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation + + self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn) + self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY) + + if is_gated(self._config.activation): + self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype, + self._config.activation) + else: + self._activation = None + + self._gate_proj = BlasLibLinear(self._config.input_dtype) + self._top_1_gate = RaggedTopKGating(config.input_dtype) + self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim) + self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores) + + self._create_buffers() + + def _create_buffers(self): + + # Gating buffers + self._logits = torch.empty((self._config.max_tokens, self.n_experts), + dtype=self._config.input_dtype, + device=get_accelerator().current_device()) + self._expert_counts = torch.empty((self.n_experts, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._scores = torch.empty((self._config.max_tokens, self.n_top_k), + dtype=torch.float32, + device=get_accelerator().current_device()) + self._assignments = torch.empty((self._config.max_tokens, self.n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._offsets = torch.empty((self._config.max_tokens, self.n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + + # Scatter buffers + self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), + dtype=self._config.input_dtype, + device=get_accelerator().current_device()) + self._expert_cumsum = torch.empty((self._config.n_experts, ), + dtype=torch.int64, + device=get_accelerator().current_device()) + self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + + # GEMM Buffers + self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + if self._activation is not None: + self._gated_intermediate = torch.empty( + (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + # Gather buffer + self._output = torch.empty((self._config.max_tokens, self._config.model_dim), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Ensures gate param is going to match the activation data type. + """ + param = param.to(self._config.input_dtype) + return InferenceParameter.initialize(param) + + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + param = param.to(self._config.input_dtype) + + if len(param.shape) == 3: + param = param.permute(0, 2, 1).contiguous() + return InferenceParameter.initialize(param) + + def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + param = param.to(self._config.input_dtype) + + if len(param.shape) == 3: + param = param.permute(0, 2, 1).contiguous() + return InferenceParameter.initialize(param) + + @property + def output(self) -> torch.Tensor: + return self._output + + def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper, + gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Helper function to isolate the logit for gating. This will take the hidden states + and produce the metadata + tensors for the CUTLASS ragged GEMMs. If the input has + been padded for CG, this will strip the padding for MoE. + + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim]. + batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states. + gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the original order of the tokens) + """ + + # Get views on the buffers for gating + logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1])) + scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k)) + assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k)) + offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k)) + mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k)) + moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1])) + + self._gate_proj(logits, hidden_states, gate_w) + self._expert_counts.zero_() + self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata) + self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts, + assignments, offsets) + + return moe_input, self._expert_cumsum, scores, mapped_slots + + def forward(self, + hidden_states: torch.Tensor, + batch_metadata: RaggedBatchWrapper, + gate_w: torch.Tensor, + mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, + mlp_1_b: Optional[torch.Tensor] = None, + mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + MoE forward pass built on top of CUTLASS multi-GEMM. + + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim]. + gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. + """ + + moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w) + + # Get views on the buffers for GEMM + intermediate = empty_from(self._intermediate, + (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1])) + output_unordered = empty_from(self._output_unordered, + (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1])) + output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1])) + + if self._activation is not None: + gated_intermediate = empty_from( + self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1])) + self._mlp_1( + gated_intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) + self._activation(intermediate, gated_intermediate) + else: + self._mlp_1( + intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) + + self._mlp_2( + output_unordered, + intermediate, + mlp_2_w, + expert_cumsum, + mlp_2_b, + ) + + self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts) + return output diff --git a/deepspeed/inference/v2/modules/implementations/post_norm/__init__.py b/deepspeed/inference/v2/modules/implementations/post_norm/__init__.py new file mode 100644 index 000000000000..653a2fe4fb5b --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/post_norm/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_post_ln import DSPostLNCUDAModule diff --git a/deepspeed/inference/v2/modules/implementations/post_norm/cuda_post_ln.py b/deepspeed/inference/v2/modules/implementations/post_norm/cuda_post_ln.py new file mode 100644 index 000000000000..9b2af4bb9023 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/post_norm/cuda_post_ln.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ...interfaces import DSPostNormBase, DSPostNormRegistry +from ...configs import DSNormConfig +from ....kernels.core_ops.cuda_layer_norm.cuda_post_ln import CUDAFPPostLN +from ....allocator import empty_from +from ....inference_parameter import InferenceParameter + + +@DSPostNormRegistry.register_module +class DSPostLNCUDAModule(DSPostNormBase): + + @staticmethod + def name(): + return 'cuda_post_ln' + + @staticmethod + def supports_config(config: DSNormConfig): + if len(set([config.residual_dtype, config.input_dtype, config.output_dtype])) != 1: + return False + + try: + _ = CUDAFPPostLN(config.channels, config.residual_dtype) + except ValueError: + return False + return True + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + self._fp_post_ln = CUDAFPPostLN(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + + self._output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + param = param.to(self._config.input_dtype) + return InferenceParameter.initialize(param) + + def forward(self, residual: torch.Tensor, hidden_in: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Since the CUDA FP only supports all data types being the same, we will alias the residual + with our output. + """ + self._residual_output = empty_from(self._output, residual.shape) + self._fp_post_ln(residual, residual, hidden_in, gamma, beta) + return residual, residual diff --git a/deepspeed/inference/v2/modules/implementations/pre_norm/__init__.py b/deepspeed/inference/v2/modules/implementations/pre_norm/__init__.py new file mode 100644 index 000000000000..12605f13f955 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/pre_norm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_pre_ln import DSPreLNCUDAModule +from .cuda_pre_rms import DSPreRMSCUDAModule diff --git a/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_ln.py b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_ln.py new file mode 100644 index 000000000000..90783ce8c9a6 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_ln.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ...interfaces import DSPreNormBase, DSPreNormRegistry +from ...configs import DSNormConfig, NormTypeEnum +from ....kernels.core_ops.cuda_layer_norm.cuda_pre_ln import CUDAFPPreLN +from ....kernels.core_ops.cuda_layer_norm.cuda_ln import CUDAFPLN +from ....allocator import empty_from +from ....inference_parameter import InferenceParameter + + +@DSPreNormRegistry.register_module +class DSPreLNCUDAModule(DSPreNormBase): + + @staticmethod + def name(): + return 'cuda_pre_ln' + + @staticmethod + def supports_config(config: DSNormConfig): + type = NormTypeEnum(config.type) + if type != NormTypeEnum.LayerNorm: + return False + + if len(set([config.residual_dtype, config.input_dtype, config.output_dtype])) != 1: + return False + + try: + _ = CUDAFPPreLN(config.channels, config.residual_dtype) + except ValueError: + return False + return True + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + self._fp_pre_ln = CUDAFPPreLN(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + self._fp_ln = CUDAFPLN(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + + # Buffers for the hidden output (residual is updated in-place) + self._hidden_output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + param = param.to(self._config.input_dtype) + return InferenceParameter.initialize(param) + + def forward(self, residual: torch.Tensor, hidden_in: Optional[torch.Tensor], gamma: torch.Tensor, + beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Since the CUDA FP only supports all data types being the same, we will alias the residual + with our output. + + If hidden_in is None, that means we do not need to perform the residual add and will + only return the hidden output modified. + """ + hidden_out = empty_from(self._hidden_output, residual.shape) + if hidden_in is None: + self._fp_ln(hidden_out, residual, gamma, beta) + else: + self._fp_pre_ln(residual, hidden_out, residual, hidden_in, gamma, beta) + return residual, hidden_out diff --git a/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_rms.py b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_rms.py new file mode 100644 index 000000000000..986262b31b1f --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_rms.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ...interfaces import DSPreNormBase, DSPreNormRegistry +from ...configs import DSNormConfig, NormTypeEnum +from ....kernels.core_ops import CUDARMSNorm, CUDARMSPreNorm +from ....allocator import empty_from +from ....inference_parameter import InferenceParameter + + +@DSPreNormRegistry.register_module +class DSPreRMSCUDAModule(DSPreNormBase): + + @staticmethod + def name(): + return 'cuda_pre_rms' + + @staticmethod + def supports_config(config: DSNormConfig): + type = NormTypeEnum(config.type) + if type != NormTypeEnum.RMSNorm: + return False + + if len(set([config.residual_dtype, config.input_dtype, config.output_dtype])) != 1: + return False + + try: + # Only need to check one since the support matrix for the two rms kernels is the same + _ = CUDARMSPreNorm(config.channels, config.residual_dtype) + except ValueError: + return False + return True + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + self._fp_rms = CUDARMSNorm(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + self._fp_rms_pre = CUDARMSPreNorm(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + + # Buffers for both the hidden and residual outputs + self._hidden_output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + self._residual_output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + param = param.to(self._config.input_dtype) + return InferenceParameter.initialize(param) + + def forward(self, + residual: torch.Tensor, + hidden_in: Optional[torch.Tensor], + gamma: torch.Tensor, + beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Since the CUDA FP only supports all data types being the same, we will alias the residual + with our output. + + If hidden_in is None, that means we do not need to perform the residual add and will + only return the hidden output modified. + """ + assert beta is None, "Beta is not supported for RMSNorm" + + hidden_out = empty_from(self._hidden_output, residual.shape) + if hidden_in is None: + self._fp_rms(hidden_out, residual, gamma) + residual_out = residual + else: + residual_out = empty_from(self._residual_output, residual.shape) + self._fp_rms_pre(residual_out, hidden_out, residual, hidden_in, gamma) + return residual_out, hidden_out diff --git a/deepspeed/inference/v2/modules/implementations/unembed/__init__.py b/deepspeed/inference/v2/modules/implementations/unembed/__init__.py new file mode 100644 index 000000000000..4a5fd24d518b --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/unembed/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .ragged_unembed import DSRaggedUnembed diff --git a/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py new file mode 100644 index 000000000000..36130902c665 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import DtypeEnum, ActivationType +from ....kernels.core_ops import CUDAFPLN, BlasLibLinear, CUDARMSNorm, CUDABiasActivation +from ....kernels.ragged_ops import RaggedLogitsGather +from ....ragged import RaggedBatchWrapper +from ...interfaces import DSUnembedBase, DSUnembedRegistry +from ...configs import DSUnembedConfig + + +@DSUnembedRegistry.register_module +class DSRaggedUnembed(DSUnembedBase): + """ + Ragged unembedding implementation. This implementation will gather only the last token + of each sequence in the ragged inflight batch and calculate the logits only for those rows. + """ + + @staticmethod + def name(): + return 'ragged_unembed' + + @staticmethod + def supports_config(config: DSUnembedConfig): + + if DtypeEnum(config.dtype) not in [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]: + return False + + try: + _ = RaggedLogitsGather(config.model_dim, config.dtype) + except ValueError: + return False + + if config.norm_type == 'rms_norm': + try: + _ = CUDARMSNorm(config.model_dim, config.dtype) + except ValueError: + return False + elif config.norm_type == 'layer_norm': + try: + _ = CUDAFPLN(config.model_dim, config.dtype) + except ValueError: + return False + + return True + + def __init__(self, config: DSUnembedConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._logits_gather = RaggedLogitsGather(config.model_dim, self._config.dtype) + + if self._config.norm_type == 'layer_norm': + self._norm = CUDAFPLN(self._config.model_dim, self._config.dtype) + elif self._config.norm_type == 'rms_norm': + self._norm = CUDARMSNorm(self._config.model_dim, self._config.dtype) + else: + self._norm = None + + self._linear = BlasLibLinear(self._config.dtype) + # Here the activation kernel is being used to apply bias, hence the identity activation type! + self._act_fn = CUDABiasActivation(self._config.vocab_size, self._config.dtype, ActivationType.IDENTITY) + + self._intermediate = torch.empty((self._config.max_sequences, self._config.model_dim), + dtype=self._config.dtype, + device=get_accelerator().current_device()) + + self._output = torch.empty((self._config.max_sequences, self._config.vocab_size), + dtype=self._config.dtype, + device=get_accelerator().current_device()) + + @property + def output(self) -> torch.Tensor: + return self._output + + def forward(self, + hidden_states: torch.Tensor, + vocab_embedding: torch.Tensor, + ragged_metadata: RaggedBatchWrapper, + bias: Optional[torch.Tensor] = None, + gamma: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Return final model logits. + + Args: + hidden_states (torch.Tensor): The hidden states from the model. This is the output of the + final layer of the model. + vocab_embedding (torch.Tensor): The vocab embedding table. + raged_metadata (RaggedBatchWrapper): The ragged batch metadata. + gamma (Optional[torch.Tensor]): The gamma tensor for normalization. + beta (Optional[torch.Tensor]): The beta tensor for normalization. + """ + + cut_down_hidden_states = empty_from(self._intermediate, + (ragged_metadata.current_sequences, self._config.model_dim)) + self._logits_gather(cut_down_hidden_states, hidden_states, ragged_metadata) + + if self._config.norm_type == 'rms_norm': + if gamma is None: + raise ValueError('RMS Normalization enabled but gamma not provided.') + self._norm(cut_down_hidden_states, cut_down_hidden_states, gamma) + elif self._config.norm_type == 'layer_norm': + if gamma is None or beta is None: + raise ValueError('Normalization enabled but gamma and/or beta not provided.') + self._norm(cut_down_hidden_states, cut_down_hidden_states, gamma, beta) + + output = empty_from(self._output, (ragged_metadata.current_sequences, self._config.vocab_size)) + self._linear(output, cut_down_hidden_states, vocab_embedding) + if bias is not None: + self._act_fn(output, bias) + + return output diff --git a/deepspeed/inference/v2/modules/interfaces/__init__.py b/deepspeed/inference/v2/modules/interfaces/__init__.py new file mode 100644 index 000000000000..13b556789e4e --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attention_base import DSSelfAttentionRegistry, DSSelfAttentionBase +from .embedding_base import DSEmbeddingRegistry, DSEmbeddingBase +from .linear_base import DSLinearRegistry, DSLinearBase +from .moe_base import DSMoERegistry, DSMoEBase +from .post_norm_base import DSPostNormRegistry, DSPostNormBase +from .pre_norm_base import DSPreNormRegistry, DSPreNormBase +from .unembed_base import DSUnembedRegistry, DSUnembedBase diff --git a/deepspeed/inference/v2/modules/interfaces/attention_base.py b/deepspeed/inference/v2/modules/interfaces/attention_base.py new file mode 100644 index 000000000000..c67dc033f92a --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/attention_base.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple, Type + +import torch + +from ...ragged import RaggedBatchWrapper +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSSelfAttentionConfig + + +class DSSelfAttentionBase(DSModuleBase): + """ + Base mixin for all attention modules. The interface represented by this module + is broadly: + + output = attention(query_key_value, + Optional[kv_cache], + Optional[attention_mask], + Optional[attention_bias]) + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSSelfAttentionConfig + + def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @property + def kv_block_size(self) -> int: + """ + Return preferred granulatity for blocked KV-cache implementation. + """ + raise NotImplementedError() + + @property + def q_block_size(self) -> int: + """ + Property to calculate blocking granularity for the query dimension. + This has no impact on the KV-cache structure, but will affect the + number of attention atoms associated with a batch. + """ + raise NotImplementedError() + + def build_atoms(self, ragged_batch: RaggedBatchWrapper) -> None: + """ + Build the atoms for this module. This is not a strict requirement for the class, + so this method is a no-op by default rather than abstract. + """ + pass + + def forward(self, + q_k_v: torch.Tensor, + kv_cache: torch.Tensor, + batch: RaggedBatchWrapper, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + inv_freqs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters: + q_k_v (torch.Tensor): Query, key, and value tensors. Expected shape is: + [ + batch, + seq_len, + 2 * self._config.n_heads_kv + self._config.n_heads_q, + self._config.head_size + ]. + kv_cache (Optional[torch.Tensor]): Key and value cache tensor. Expected shape is + [ + 2, + batch, + kv_cache_len, + self._config.n_heads_kv, + self._config.head_size + ]. If None, cache is disabled. The `kv_cache_len` dimension does not need to + be contiguous (it should expand stride by `max_out_tokens`). + batch (RaggedBatchWrapper): Ragged batch metadata. + attention_mask (Optional[torch.Tensor]): Attention mask tensor. If None, masking is + disabled. This will defer to the config in the case of conflicting information. + This means if the config class is implying causal attention, the mask will be ignored. + attention_bias (Optional[torch.Tensor]): Attention bias tensor. If None, bias is disabled. + """ + raise NotImplementedError() + + +class DSSelfAttentionRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSSelfAttentionBase diff --git a/deepspeed/inference/v2/modules/interfaces/embedding_base.py b/deepspeed/inference/v2/modules/interfaces/embedding_base.py new file mode 100644 index 000000000000..1ab7e5f0b7a2 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/embedding_base.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ...ragged import RaggedBatchWrapper +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSEmbeddingsConfig +from ...inference_parameter import InferenceParameter + + +class DSEmbeddingBase(DSModuleBase): + """ + Base mixin for embedding modules. The interface represented by this module is: + + hidden_out = embedding(input_ids) + + position_embedding(position_ids) + + token_type_embedding(token_type_ids) + with optional normalization. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSEmbeddingsConfig + + def __init__(self, config: DSEmbeddingsConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + def transform_param(self, embed_param: torch.Tensor) -> InferenceParameter: + """ + Perform any necessary transformations on an embedding parameter. This module assumes + that all embedding parameters would require the same set of transformations. + + Parameters: + embed_param (torch.Tensor): Embedding parameter. Shape is of [vocab_size, hidden_size] + """ + raise NotImplementedError() + + @property + @abstractmethod + def output(self) -> torch.Tensor: + """ + Pre-allocated output Tensor. This currently needs to be exposed for gather operations + on the output. + + TODO(cmikeh2): This is not ideal. We need a better abstraction for this, such as giving + access to the inference comm object to the DSModule. + """ + raise NotImplementedError() + + def forward(self, + ragged_batch: RaggedBatchWrapper, + word_embeddings: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + token_type_embeddings: Optional[torch.Tensor] = None) -> InferenceParameter: + """ + Parameters: + ragged_batch (torch.Tensor): Ragged batch of token ids + associated metadata. + word_embeddings (torch.Tensor): Word embeddings. + position_embeddings (torch.Tensor): Position embeddings. If passed, IDs will be + inferred from the ragged batch itself. + token_type_ids (torch.Tensor): Token type ids. + token_type_embeddings (torch.Tensor): Token type embeddings. + + Returns: + torch.Tensor: Hidden states. This should be the sum of the relevant + encodings for the model. + """ + raise NotImplementedError() + + +class DSEmbeddingRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSEmbeddingBase diff --git a/deepspeed/inference/v2/modules/interfaces/linear_base.py b/deepspeed/inference/v2/modules/interfaces/linear_base.py new file mode 100644 index 000000000000..fe6ccbcd9344 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/linear_base.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSLinearConfig +from ...inference_parameter import InferenceParameter + + +class DSLinearBase(DSModuleBase): + """ + Base mixin for all Linear modules. The interface represented by this module + is: + + hidden_out = activation(hidden_in * weight + bias) + + The format and dtype of the weight and bias tensors are not defined and implementations + may compress as necessary. Must support a bias. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSLinearConfig + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @abstractmethod + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Perform any necessary transformations of the parameters of this module. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + ... + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is either + [batch, seq_len, in_channels] or [batch, in_channels]. + + Returns: + torch.Tensor: Output tensor. Tensor should have same number of dimensions as + input tensor. + """ + raise NotImplementedError() + + @property + @abstractmethod + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + ... + + +class DSLinearRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSLinearBase diff --git a/deepspeed/inference/v2/modules/interfaces/moe_base.py b/deepspeed/inference/v2/modules/interfaces/moe_base.py new file mode 100644 index 000000000000..78bdc0700f63 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/moe_base.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSMoEConfig +from ...inference_parameter import InferenceParameter + + +class DSMoEBase(DSModuleBase): + """ + Base mixing for MoE modules. The interface represented by this module is: + + expert_assignments = gate(hidden_states) + intermediate = ragged_linear(hidden_states, expert_assignments) + output = ragged_linear(intermediate, expert_assignments) + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSMoEConfig + + def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @abstractmethod + def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Perform any necessary transformations of the gate parameter. + + Args: + param (torch.Tensor): gate_w (shape: [num_experts, model_dim]) + """ + ... + + @abstractmethod + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Perform any necessary transformations of the parameter. The specific component + being transformed should be inferred from the shape of the parameter. + + Args: + param (torch.Tensor): One of either mlp_1_w, mlp_1_b + """ + ... + + @abstractmethod + def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Perform any necessary transformations of the parameter. The specified component being + transformed should be inferred from the shape of the parameter. This interface is + separate from transform_moe_1_param because the two components may have identical + shapes. + + Args: + param (torch.Tensor): One of either mlp_2_w or mlp_2_b + """ + ... + + def forward(self, + hidden_states: torch.Tensor, + gate_w: torch.Tensor, + mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, + mlp_1_b: Optional[torch.Tensor] = None, + mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError() + + @property + @abstractmethod + def output(self) -> torch.Tensor: + """ + Returns the pre-allocated, padded output Tensor. + """ + ... + + +class DSMoERegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSMoEBase diff --git a/deepspeed/inference/v2/modules/interfaces/post_norm_base.py b/deepspeed/inference/v2/modules/interfaces/post_norm_base.py new file mode 100644 index 000000000000..cc80e5c94bf7 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/post_norm_base.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Tuple, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..configs.norm_config import DSNormConfig +from ..module_registry import DSModuleRegistryBase +from ...inference_parameter import InferenceParameter + + +class DSPostNormBase(DSModuleBase): + """ + Base MixIn for all Post-Normalization modules. The interface represented by this + module is: + + residual, hidden_out = norm(residual + hidden_in) + + If residual and hidden_out are the same data type, then they may alias each other. + Furthermore, residual should be updated in-place. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSNormConfig + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @abstractmethod + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Transform a gamma/beta parameter. It is assumed that both transformations are + the same. + + Parameters: + param (torch.Tensor): Gamma or beta parameter. + """ + ... + + def forward(self, + residual: torch.Tensor, + hidden_states: torch.Tensor, + gamma: torch.Tensor, + beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters: + residual (torch.Tensor): Residual tensor. + hidden_states (torch.Tensor): Hidden states tensor. + + Returns: + (torch.Tensor, torch.Tensor): Tuple of residual and hidden states. + Hidden states may alias with residual. + """ + raise NotImplementedError() + + +class DSPostNormRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSPostNormBase diff --git a/deepspeed/inference/v2/modules/interfaces/pre_norm_base.py b/deepspeed/inference/v2/modules/interfaces/pre_norm_base.py new file mode 100644 index 000000000000..84f51cff6947 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/pre_norm_base.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Tuple, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..configs.norm_config import DSNormConfig +from ..module_registry import DSModuleRegistryBase +from ...inference_parameter import InferenceParameter + + +class DSPreNormBase(DSModuleBase): + """ + Base mixin for all Pre-Normalization modules. The interface represented by this module + is: + + if hidden_in is not None: + residual_out = residual + hidden_in + else: + residual_out = residual + + hidden_out = normalize(residual_out) + return residual_out, hidden_out + + Residual should be updated in-place. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSNormConfig + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + + @abstractmethod + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Transform a gamma/beta parameter. It is assumed that both transformations are + the same. + + Parameters: + param (torch.Tensor): Gamma or beta parameter. + """ + ... + + def forward(self, + residual: torch.Tensor, + hidden_states: Optional[torch.Tensor], + gamma: torch.Tensor, + beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters: + residual (torch.Tensor): Residual tensor. + hidden_states (torch.Tensor): Hidden states tensor. + + Returns: + (torch.Tensor, torch.Tensor): Tuple of residual and hidden states. + """ + raise NotImplementedError() + + +class DSPreNormRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSPreNormBase diff --git a/deepspeed/inference/v2/modules/interfaces/unembed_base.py b/deepspeed/inference/v2/modules/interfaces/unembed_base.py new file mode 100644 index 000000000000..9eca6fcde768 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/unembed_base.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ...ragged import RaggedBatchWrapper +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSUnembedConfig + + +class DSUnembedBase(DSModuleBase): + """ + Base mixin for unmebedding modules. The interface represented by this module is: + + if config.do_normalization + hidden = layer_norm(hidden) + logits = hidden @ projection + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSUnembedConfig + + def __init__(self, config: DSUnembedConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + def forward(self, + hidden_states: torch.Tensor, + vocab_embedding: torch.Tensor, + ragged_metadata: RaggedBatchWrapper, + gamma: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward interface. Gamma and beta are optional parameters passed depending on + `self.config.do_normalization`. + + Args: + hidden_states (torch.Tensor): Hidden states of shape [tokens, model_dim] + vocab_embedding (torch.Tensor): Embedding matrix of shape [vocab_size, model_dim] + ragged_metadata (RaggedBatchWrapper): Metadata for the ragged batch. + gamma (Optional[torch.Tensor]): Gamma parameter for layer norm. + beta (Optional[torch.Tensor]): Beta parameter for layer norm. + + Returns: + torch.Tensor: Unembedded hidden states of shape [n_seqs, model_dim] + """ + raise NotImplementedError() + + +class DSUnembedRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSUnembedBase diff --git a/deepspeed/inference/v2/modules/module_registry.py b/deepspeed/inference/v2/modules/module_registry.py new file mode 100644 index 000000000000..e04b8d734518 --- /dev/null +++ b/deepspeed/inference/v2/modules/module_registry.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractstaticmethod +from typing import Any, Dict, Type + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .ds_module import DSModuleBase + + +class ConfigBundle(DeepSpeedConfigModel): + """ + A config bundle is a collection of configs that are used to instantiate a model implementation. + """ + name: str + config: DeepSpeedConfigModel + implementation_config: Dict[str, Any] = {} + + +class DSModuleRegistryBase(ABC): + """ + Class holding logic for tracking the DSModule implementations of a given interface. + """ + + @classmethod + def instantiate_config(cls, config_bundle: ConfigBundle) -> DSModuleBase: + """ + Given a DSModule key, attempt to instantiate + """ + if config_bundle.name not in cls.registry: + raise KeyError(f"Unknown DSModule: {config_bundle.name}, cls.registry={cls.registry}") + + target_implementation = cls.registry[config_bundle.name] + if not target_implementation.supports_config(config_bundle.config): + raise ValueError(f"Config {config_bundle.config} is not supported by {target_implementation}") + + return cls.registry[config_bundle.name](config_bundle.config, config_bundle.implementation_config) + + @abstractstaticmethod + def associated_class() -> Type[DSModuleBase]: + """ + Return the class associated with this registry. + """ + raise NotImplementedError("Must associated a DSModule class with its registry.") + + @classmethod + def register_module(cls, child_class: DSModuleBase) -> None: + """ + Register a module with this registry. + """ + if not issubclass(child_class, cls.associated_class()): + raise TypeError( + f"Can only register subclasses of {cls.associated_class()}, {child_class} does not inherit from {cls.associated_class()}" + ) + cls.registry[child_class.name()] = child_class + return child_class diff --git a/deepspeed/inference/v2/ragged/__init__.py b/deepspeed/inference/v2/ragged/__init__.py new file mode 100644 index 000000000000..3af09cff4be5 --- /dev/null +++ b/deepspeed/inference/v2/ragged/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .kv_cache import split_kv +from .manager_configs import ( + AllocationMode, + DSStateManagerConfig, + KVCacheConfig, + MemoryConfig, +) +from .ragged_manager import DSStateManager +from .ragged_wrapper import RaggedBatchWrapper +from .sequence_descriptor import DSSequenceDescriptor, PlaceholderSequenceDescriptor diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py new file mode 100644 index 000000000000..7884d8cccb47 --- /dev/null +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Union + +import torch + + +class BlockedAllocator: + """ + Allocator class for managing which blocks are free/used in the + blocked KV-cache. This is a simple allocator that uses a linked list + to keep track of which blocks are free/used. The cost of allocation/deallocation + is O(blocks), where blocks is the number of blocks to allocate/deallocate. + + TODO(cmikeh2): Evaluate performance of this allocator and migrate + to C++ if necessary. + """ + # Number of blocks in the KV-cache(s). + _num_blocks: int + + # Array of blocks, where each element is the next block in the linked list. + _blocks: torch.Tensor + + # Index of the head of the linked list. + _head: int + + # Number of free blocks in the KV-cache. + _free_blocks: int + + def __init__(self, num_blocks: int) -> None: + """ + Initialize an allocator with `num_blocks` blocks. This requires at least + `num_blocks` * 4 bytes of host memory. + + Parameters: + num_blocks (int): The number of blocks to allocate. + """ + + if num_blocks < 1: + raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}') + + self._num_blocks = num_blocks + self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) + self._head = 0 + self._free_blocks = num_blocks + + def allocate(self, num_blocks: int) -> torch.Tensor: + """ + Allocate a list of blocks from the associated KV-caches. This will + return `num_blocks` blocks from the KV-cache if they are available, + or raise an exception if there are not enough free blocks. + + Parameters: + num_blocks (int): The number of blocks to allocate. + + Returns: + List[int]: The list of blocks allocated. + """ + if num_blocks > self._free_blocks: + raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks') + + allocated_blocks = torch.zeros(num_blocks, dtype=torch.int32) + for i in range(num_blocks): + allocated_blocks[i] = self._head + self._head = self._blocks[self._head].item() + self._blocks[allocated_blocks[i]] = -1 # Mark as used + self._free_blocks -= 1 + + return allocated_blocks + + def free(self, blocks: Union[Iterable[int], int]) -> None: + """ + Return a list of blocks to the free pool. If a single invalid block is provided (i.e., + one that is out of range of the allocator or is already free), then an exception is raised + and no blocks are freed. + + Parameters: + blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block + is to be freed, this can be alone as an integer. + """ + if isinstance(blocks, int): + blocks = [blocks] + + for block in blocks: + # Parse all blocks for validity before mutating the list. + if block < 0 or block >= self._num_blocks: + raise ValueError(f'Invalid block {block} provided to free') + + if self._blocks[block] != -1: + raise ValueError(f'Block {block} is already free') + + for block in blocks: + self._blocks[block] = self._head + self._head = block + self._free_blocks += 1 + + @property + def free_blocks(self) -> int: + """ + Return the number of free blocks in the KV-cache. + """ + return self._free_blocks diff --git a/deepspeed/inference/v2/ragged/csrc/fast_host_buffer.cu b/deepspeed/inference/v2/ragged/csrc/fast_host_buffer.cu new file mode 100644 index 000000000000..31347636b50c --- /dev/null +++ b/deepspeed/inference/v2/ragged/csrc/fast_host_buffer.cu @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "fast_host_buffer.h" + +void* get_cuda_fast_buffer(int64_t size) +{ + void* buffer_ptr; + // Host allocation flags that should minimize the host -> accelerator copy latency + unsigned int alloc_flags = + cudaHostAllocPortable | cudaHostAllocMapped | cudaHostAllocWriteCombined; + + cudaHostAlloc(&buffer_ptr, size, alloc_flags); + return buffer_ptr; +} diff --git a/deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp b/deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp new file mode 100644 index 000000000000..ce115f993c3c --- /dev/null +++ b/deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include + +#include "fast_host_buffer.h" + +/* +Similar to doing an empty_like to replicate a Tensor on the host, but will +attempt to optimize for faster host -> accelerator copies. Since this is on the critical +path for the forward pass, this should directly improve performance. +Allocates the shadow buffers for the input_ids, batch, seq and kv_ids tensors. + +Arguments: + device_mirror: A tensor on the accelerator that should be mirrored by the host. + +Returns: + A tensor on the host of the same size and datatype optimized for fast host -> accelerator +copies. +*/ +torch::Tensor allocate_fast_host_buffer(torch::Tensor device_mirror) +{ +#ifdef __HIP_PLATFORM_AMD__ + auto options = + torch::TensorOptions().device(torch::kCPU).pinned_memory(true).dtype(device_mirror.dtype()); + auto buffer = torch::empty(device_mirror.sizes(), options); +#else + + void* buffer_ptr = get_cuda_fast_buffer(device_mirror.numel() * device_mirror.element_size()); + + auto options = torch::TensorOptions().device(torch::kCPU).dtype(device_mirror.dtype()); + auto buffer = torch::from_blob(buffer_ptr, device_mirror.sizes(), options); +#endif + return buffer; +} + +torch::Tensor allocate_view_on(torch::Tensor& tensor, torch::Tensor& buffer, int64_t offset) +{ + int8_t* data = reinterpret_cast(buffer.data_ptr()); + + auto options = tensor.options().device(buffer.device()); + + return at::from_blob(data + offset, tensor.sizes(), tensor.strides(), options); +} + +torch::Tensor allocate_view_like(py::tuple shape, + py::tuple strides, + torch::Tensor& dummy_tensor, + torch::Tensor& buffer, + int64_t offset) +{ + int8_t* data = reinterpret_cast(buffer.data_ptr()); + + auto options = torch::TensorOptions().device(buffer.device()).dtype(dummy_tensor.dtype()); + + return at::from_blob(data + offset, + shape.cast>(), + strides.cast>(), + options); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("allocate_fast_host_buffer", + &allocate_fast_host_buffer, + "Allocate a host mirror of an accelerator Tensor."); + m.def("allocate_view_on", + &allocate_view_on, + "Allocate a view on a Tensor on the same device as the input Tensor."); + m.def("allocate_view_like", + &allocate_view_like, + "Allocate a view on a Tensor on the same device as the input Tensor."); +} diff --git a/deepspeed/inference/v2/ragged/includes/fast_host_buffer.h b/deepspeed/inference/v2/ragged/includes/fast_host_buffer.h new file mode 100644 index 000000000000..81f24ed8fdaa --- /dev/null +++ b/deepspeed/inference/v2/ragged/includes/fast_host_buffer.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +/* +Wrapper around cudaHostAlloc with some specific flags. Returns a pointer to the +memory region of `size` bytes. +*/ +void* get_cuda_fast_buffer(int64_t size); diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py new file mode 100644 index 000000000000..ceba3190b93c --- /dev/null +++ b/deepspeed/inference/v2/ragged/kv_cache.py @@ -0,0 +1,208 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import operator +from functools import reduce +from typing import Any, Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist +from deepspeed.comm.reduce_op import ReduceOp + +from deepspeed.accelerator import get_accelerator +from ..inference_utils import elem_size +from ..logging import inference_logger +from .blocked_allocator import BlockedAllocator +from .manager_configs import AllocationMode, KVCacheConfig, MemoryConfig + + +def split_kv(kv_cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split a KV cache instance into its key and value components. + + Parameters: + kv_cache (torch.Tensor): The KV-cache to split. This should be a 5D tensor with the + following shape: [num_blocks, block_size, 2, num_heads, head_size] + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The key and value components of the KV-cache. Both + tensors will have the shape [num_blocks, block_size, num_heads, head_size]. + """ + if kv_cache.ndim != 5: + raise ValueError(f"KV-cache must have 5 dimensions, got {kv_cache.ndim}.") + + return kv_cache[:, :, 0, :, :], kv_cache[:, :, 1, :, :] + + +class BlockedKVCache: + + _caches: Tuple[torch.Tensor, ...] + """ + Backing storage for all KV caches. This is a 6D tensor with the following shape: + (num_caches, num_blocks, block_size, 2, num_heads, head_size) + """ + + _allocators: Tuple[BlockedAllocator, ...] + """ + Block allocator for tracking cache usage. This manages the GPU cache. + """ + + _configs: Tuple[KVCacheConfig, ...] + """ + Configuration of the KV cache(s). See ``KVCacheConfig`` for more details. This enables the support + for different types/shapes of KV-caches (i.e. the alternating local and global attention in + GPT-Neo). + """ + + def __init__(self, + configs: Tuple[KVCacheConfig, ...], + memory_config: MemoryConfig, + mp_group: Optional[Any] = None, + offload: bool = False) -> None: + """ + Create a container that will maintain the storage and allocations for a set of + blocked KV-caches. + + Parameters: + config (KVCacheConfig): The configuration of the KV-cache. + slack (int): The amount of slack space to reserve in GPU memory for the cache. + enable_offload (bool): Whether to enable offloading of the cache to the host. + blocks (int): The number of blocks to pre-allocate for the cache. If this is set, + slack will be ignored. + """ + self._configs = configs + self._memory_config = memory_config + self._enable_offload = offload + + if self._enable_offload: + raise NotImplementedError("Offloading of KV-caches is not yet supported.") + + if AllocationMode(self._memory_config.mode) is AllocationMode.RESERVE: + # TODO(cmikeh2): Change the weighting based on the type of the KV-cache + + total_per_block_footprint = 0 + for config in self._configs: + per_block_footprint = reduce(operator.mul, config.cache_shape, config.block_size) + per_block_footprint *= 2 # for key and value + total_per_block_footprint += per_block_footprint * elem_size(config.cache_dtype) + + # Perform a dummy nccl call before calculating available memory, on some systems (H100) we've observed higher memory allocations from NCCL + if dist.get_world_size(group=mp_group) > 1: + dummy_tensor = torch.tensor(0, dtype=torch.int32, device=get_accelerator().current_device()) + dist.all_reduce(dummy_tensor, op=ReduceOp.MIN, group=mp_group) + + get_accelerator().empty_cache() + available_kv_memory = get_accelerator().available_memory() - self._memory_config.size + total_memory = get_accelerator().total_memory() + + inference_logger().debug( + f"Memory usage before KV-cache allocation: total_memory={total_memory}, available_kv_memory={available_kv_memory}, total_per_block_footprint={total_per_block_footprint}" + ) + + if available_kv_memory < total_per_block_footprint: + raise ValueError( + f"Insufficient memory to allocate KV-caches. Required: {total_per_block_footprint}, Available: {available_kv_memory}" + ) + + num_blocks = available_kv_memory // total_per_block_footprint + + # In a multi-process setting, we need to ensure that all processes have the same + # KV cache capacity to ensure scheduling guarantees are equivalent on all ranks. + if dist.get_world_size(group=mp_group) > 1: + reduce_tensor = torch.tensor(num_blocks, dtype=torch.int32, device=get_accelerator().current_device()) + dist.all_reduce(reduce_tensor, op=ReduceOp.MIN, group=mp_group) + num_blocks = reduce_tensor.item() + + # This is ugly but don't want the fragmentation of the 8 byte Tensor maybe + # hanging around. + del reduce_tensor + get_accelerator().empty_cache() + else: # AllocationMode.ALLOCATE + num_blocks = self._memory_config.size + + caches = [] + allocators = [] + + for cache_group_id, config in enumerate(self._configs): + num_caches = config.cache_shape[0] + num_heads = config.cache_shape[1] + head_size = config.cache_shape[2] + + alloc_shape = (num_caches, num_blocks, config.block_size, 2, num_heads, head_size) + inference_logger().info( + f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.") + caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype, + device=get_accelerator().current_device())) + allocators.append(BlockedAllocator(num_blocks)) + + self._caches = tuple(caches) + self._allocators = tuple(allocators) + + def reserve(self, num_blocks: int, cache_group: int = 0) -> torch.Tensor: + """ + Reserve a number of blocks from the cache. This will return a 1D tensor of + block_ids that have been marked as reserved. + + Parameters: + num_blocks (int): The number of blocks to reserve. + cache_group (int): The cache group to reserve from. Default is 0. + """ + return self._allocators[cache_group].allocate(num_blocks) + + def free(self, blocks: Iterable[int], cache_group: int = 0) -> None: + """ + Free a set of blocks from the cache. This will mark the blocks as free in the + allocator. + + Parameters: + blocks (Iterable[int]): The blocks to free. + cache_group (int): The cache group to free from. Default is 0. + """ + self._allocators[cache_group].free(blocks) + + def offload(self, blocks: Iterable[int], cache_group: int = 0) -> torch.Tensor: + """ + Offload KV-cache blocks from accelerator memory to the host. + + Parameters: + blocks (Iterable[int]): The blocks to offload. + cache_group (int): The cache group to offload from. Default is 0. + """ + raise NotImplementedError("Offloading is not yet supported.") + + def restore(self, blocks: Iterable[int], cache_group: int = 0) -> torch.Tensor: + """ + Restore KV-cache blocks from the host to accelerator memory. + + Parameters: + blocks (Iterable[int]): The blocks to restore. + cache_group (int): The cache group to restore to. Default is 0. + """ + raise NotImplementedError("Offloading is not yet supported.") + + def get_cache(self, cache_id: int, cache_group: int = 0) -> torch.Tensor: + """ + Get the tensor associated with the given cache ID. + + Parameters: + cache_id (int): The ID of the cache tensor to get. + cache_group (int): The cache group to get from. Default is 0. + """ + return self._caches[cache_group][cache_id] + + @property + def free_blocks(self) -> torch.Tensor: + """ + Return the number of free blocks in each cache + """ + return [allocator.free_blocks for allocator in self._allocators] + + @property + def num_caches(self) -> int: + """ + Return the number of caches + """ + return len(self._caches) diff --git a/deepspeed/inference/v2/ragged/manager_configs.py b/deepspeed/inference/v2/ragged/manager_configs.py new file mode 100644 index 000000000000..17283b8bc0c4 --- /dev/null +++ b/deepspeed/inference/v2/ragged/manager_configs.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum +from typing import Tuple + +from pydantic import PositiveInt, model_validator + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..inference_utils import DtypeEnum + + +class KVCacheType(Enum): + + DENSE = "dense" + """ + Dense KV-cache. This is the default type. + """ + + LOCAL = "local" + """ + KV-cache that attends to only a local (trailing) window of tokens. + """ + + +class KVCacheConfig(DeepSpeedConfigModel): + + type: KVCacheType = KVCacheType.DENSE + """ + Type of KV-cache to use. This may inform the allocator of the expected access/retention pattern + to enable more efficient memory management. + """ + + block_size: int = 128 + """ + Number of tokens that may be contained in each cache block. + """ + + num_allocation_groups: PositiveInt = 1 + """ + Allocation groups are assumed to be able to use the same allocation block size because + the allocation granularity is the same but the number of blocks required in each group + may differ. + + As a concrete example, consider a model with alternating layers of local and global + attention (such as GPTNeo). The local attention layers do not require the same number + of cache blocks as the global layer. However, a static partitioning scheme is sub-optimal since the ratio of local to global KV-cache blocks is not constant across + the range of sequence lengths that may be encountered. + + NOTE: In theory, this functionality could be used to do per-head and per-layer + KV-cache allocation, but it is likely the allocator will struggle with managing that + many blocks. + + NOTE: This will need to be primarily understood and handled by the model implementation + itself, rather than the KV cache manager. However, I'd like to make this explicit. + """ + + cache_shape: Tuple[PositiveInt, PositiveInt, PositiveInt] + """ + The shape of the cache per token. The first dimension is the number of individual + caches, the second is the number of heads, and the third is the head size. The number + of caches argument here is per allocation group. + """ + + cache_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type of the KV-cache. + """ + + max_blocks_per_allocation_group: PositiveInt = 64 + """ + Maximum number of blocks that can be associated with an allocation group. + """ + + +""" +The config above is a little confusing so let's use a couple of concrete examples of +usage: + +Model 1: Llama-13B with a block size of 256 + +Llama is uniform attention so we have a single allocation group. The cache shape is +(40 layers, 40 heads, 128 head size) + +```python +llama_kv_config = KVCacheConfig(block_size=256, + num_allocation_groups=1, + cache_shape=(40, 40, 128)) +``` + +Model 2: GPTNeo-2.7B with a block size of 128 + +GPTNeo has alternating local and global attention layers. We have two allocation groups. +There are 16 layers of each type with 20 heads apiece at 128 head size. + +```python +gptneo_kv_config = KVCacheConfig(num_allocation_groups=2, cache_shape=(16, 20, 128)) +``` +""" + + +class AllocationMode(Enum): + """ + Helper class to describe memory allocation strategies for the KV-cache. + """ + + RESERVE = "reserve" + """ + Reserve a small amount of memory for non-KV cache allocations. + """ + + ALLOCATE = "allocate" + """ + Allocate an explicit number of KV blocks. + """ + + +class MemoryConfig(DeepSpeedConfigModel): + + mode: AllocationMode = AllocationMode.RESERVE + + size: PositiveInt = 1_000_000_000 + """ + Parameter for each of the modes. + + If mode is RESERVE, this is the amount of memory in bytes to reserve after allocating the + KV-cache. If in a tensor-parallel regime, this amount is guaranteed to be reserved on + all devices. + + If mode is ALLOCATE, this is the number of blocks to allocate for the KV-cache. This may + require tuning for model/GPU setups. + """ + + +class DSStateManagerConfig(DeepSpeedConfigModel): + + max_tracked_sequences: PositiveInt = 2048 + """ + How many sequences this engine will track simultaneously. This limit should be greater + than the ``max_ragged_sequence_count``. + """ + + max_ragged_batch_size: PositiveInt = 768 + """ + The maximum number of tokens that can be contained in a single ragged batch. Passing + a larger value than this will raise an exception that must be handled by the runtime. + """ + + max_ragged_sequence_count: PositiveInt = 512 + """ + The maximum number of sequences that can compose a batch. This limitation is only + relevant under CUDA graphing scenarios currently, where the maximum number of blocks + is largely bound by the total number of sequences in the ragged batch. This number cannot + be larger than ``max_tracked_sequences`` or ``max_ragged_batch_size``. + """ + + max_context: PositiveInt = 8192 + """ + The maximum number of tokens (inclusive of generation) that can be contained in a single + sequence. Currently used to bound the size of the KV cache metadata. + """ + + memory_config: MemoryConfig = MemoryConfig() + """ + Directive for how to manage the creation of the KV-cache. See MemoryConfig for more + details. + """ + + offload: bool = False + """ + Enable tracking for offloading KV-cache to host memory. Currently unsupported. + """ + + @model_validator(mode="after") + def max_ragged_sequence_count_validator(self): + # If the attributes below failed their validation they won't appear in the values dict. + assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences" + assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size" + return self diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py new file mode 100644 index 000000000000..ecc3c52a5834 --- /dev/null +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from typing import Any, Dict, Optional, Tuple + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import RaggedUtilsBuilder +from deepspeed.utils.logging import logger + +from .blocked_allocator import BlockedAllocator +from .kv_cache import BlockedKVCache +from .manager_configs import DSStateManagerConfig, KVCacheConfig +from .sequence_descriptor import DSSequenceDescriptor + + +class DSStateManager: + """ + Base abstract class for managing blocked KV caches. Will probably have a single + implementation for now. + """ + + _config: DSStateManagerConfig + """ + Config for state management. See DSStateManagerConfig for more details. The arguments here + should come from the engine config. + """ + + _kv_configs: Tuple[KVCacheConfig] + """ + Config for the KV cache. See KVCacheConfig for more details. These arguments should derive + from the model implementation. + """ + + _kv_cache: BlockedKVCache + """ + Persistent KV cache store. + """ + + # Container for tracking all sequences in the system. + _seqs: Dict[int, DSSequenceDescriptor] + """ + Container for tracking all sequences in the system. + + TODO(cmikeh2): Evaluate if this has any performance implications. + """ + + # Allocator for tracking sequences. + _tracking_allocator: BlockedAllocator + _all_block_ids: Tuple[torch.Tensor, ...] + _all_block_ids_shadow: Tuple[torch.Tensor, ...] + + def __init__(self, + config: DSStateManagerConfig, + kv_configs: Tuple[KVCacheConfig, ...], + base_mp_group: Optional[Any] = None) -> None: + """ + The key + + Parameters: + block_size (int): The number of tokens to allocate in each block. + """ + self._config = config + self._kv_configs = kv_configs + + # Load our helpers for host allocation. + self._ragged_utils = RaggedUtilsBuilder().load() + + # Initialize the allocator for tracking sequences (so this doesn't need to be ad-hoc). + self._tracking_allocator = BlockedAllocator(self._config.max_tracked_sequences) + + all_block_ids = [] + all_block_ids_shadow = [] + + for cache_config in self._kv_configs: + # Storage to back tracking the KV cache allocation. + ids_shape = ( + self._config.max_tracked_sequences, + cache_config.num_allocation_groups, + cache_config.max_blocks_per_allocation_group, + ) + + all_block_ids.append(torch.zeros(ids_shape, dtype=torch.int32, device=get_accelerator().current_device())) + all_block_ids_shadow.append(self._ragged_utils.allocate_fast_host_buffer(all_block_ids[-1])) + + self._all_block_ids = tuple(all_block_ids) + self._all_block_ids_shadow = tuple(all_block_ids_shadow) + + # Initialize the sequence container. + self._seqs = {} + + # Finally initialize the KV cache. + self._kv_cache = BlockedKVCache(self._kv_configs, + self._config.memory_config, + mp_group=base_mp_group, + offload=self._config.offload) + + def get_cache(self, cache_id: int, cache_group: int = 0) -> torch.Tensor: + """ + Return the Tensor associated with the given cache id in the specified cache group. + + Arguments: + cache_group (str): The KV cache group. + cache_id (int): The cache id within that group. + """ + return self._kv_cache.get_cache(cache_id, cache_group=cache_group) + + def flush_sequence(self, uid: int) -> None: + """ + Free all resources associated with the given sequence id. + """ + if uid not in self._seqs: + logger.warning(f"Attempting to flush sequence {uid} which does not exist.") + return + + seq = self._seqs[uid] + for i in range(self.n_kv_cache_groups): + self._kv_cache.free(seq.all_block_ids(cache_group=i), cache_group=i) + + self._tracking_allocator.free(seq.tracking_id) + del self._seqs[uid] + + def get_sequence(self, uid: int) -> Optional[DSSequenceDescriptor]: + """ + Get the sequence descriptor for the given sequence id. If the sequence does not exist, + then None is returned. + """ + return self._seqs.get(uid, None) + + def get_or_create_sequence(self, uid: int) -> DSSequenceDescriptor: + """ + Get the existing sequence descriptor for a given uid or initialize one if + it does not exist. NOTE: This will always return a valid sequence descriptor + if one may be allocated and should not be used from APIs that are attempting + to test the schedulability of a hypothetical batch. + """ + seq = self.get_sequence(uid) + if seq is not None: + return seq + else: + return self._create_sequence(uid) + + def _create_sequence(self, uid: int) -> DSSequenceDescriptor: + """ + Create a new sequence descriptor for the given sequence id. + """ + if uid in self._seqs: + raise ValueError(f"Sequence {uid} already exists.") + + try: + tracking_slot = self._tracking_allocator.allocate(1).item() + except ValueError: + raise RuntimeError( + f"Unable to create tracking slot for sequence {uid} since the metadata buffers are full.") + + seq_block_ids = tuple(all_block_ids[tracking_slot] for all_block_ids in self._all_block_ids) + seq_block_ids_shadow = tuple(all_block_ids_shadow[tracking_slot] + for all_block_ids_shadow in self._all_block_ids_shadow) + + self._seqs[uid] = DSSequenceDescriptor(tracking_slot, + seq_block_ids, + seq_block_ids_shadow, + max_context=self._config.max_context) + # TODO(cmikeh2): Debug call here might be unnecessary and is potentially on critical path. + logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.") + return self._seqs[uid] + + @property + def tracked_sequences(self) -> Dict[int, DSSequenceDescriptor]: + """ + Return the tracked sequences. + """ + return self._seqs + + @property + def n_tracked_sequences(self) -> int: + """ + Return the number of sequences currently tracked. + """ + return len(self._seqs) + + @property + def kv_block_size(self) -> int: + """ + Return the block size of the KV cache. + """ + return self._kv_config.block_size + + @property + def n_kv_cache_groups(self) -> int: + """ + Return the number of KV caches. + """ + return self._kv_cache.num_caches + + @property + def free_blocks(self) -> torch.Tensor: + """ + Return the number of free blocks in the KV cache. + """ + return self._kv_cache.free_blocks + + def allocate_blocks(self, n_blocks: int, cache_group: int = 0) -> torch.Tensor: + return self._kv_cache.reserve(n_blocks, cache_group=cache_group) diff --git a/deepspeed/inference/v2/ragged/ragged_wrapper.py b/deepspeed/inference/v2/ragged/ragged_wrapper.py new file mode 100644 index 000000000000..056ecfa2ac40 --- /dev/null +++ b/deepspeed/inference/v2/ragged/ragged_wrapper.py @@ -0,0 +1,292 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from .sequence_descriptor import DSSequenceDescriptor +from .manager_configs import DSStateManagerConfig + + +def to_padded(original_size: int) -> int: + """ + Pad to a backend friendly granularity. + """ + + def _pad_to_mul_of_pow2(val: int, pow_2_val: int) -> int: + return val + (pow_2_val - 1) & ~(pow_2_val - 1) + + # TODO(cmikeh2): Tune this approach. This is mainly a placeholder right now. + granularity = 64 if original_size <= 512 else 128 + + return _pad_to_mul_of_pow2(original_size, granularity) + + +class RaggedBatchWrapper: + """ + Container for all the auxiliary Tensors used in the management of a ragged batch. + + For each Tensor, we maintain a shadow Tensor on the host. This Tensor is what is + directly populated when constructing the ragged batch. The shadow Tensors, when possible, + should be allocated so as to support fast host-to-accelerator copies. + """ + + # Tensors to populate the ragged batch into. + _input_ids_shadow: torch.Tensor + _input_ids: torch.Tensor + """ + Forward pass input buffer. + """ + + _batch_metadata_storage: torch.Tensor + _batch_metadata_storage_shadow: torch.Tensor + """ + Holds the number of inflight sequences and tokens for the ragged batch. + """ + + _token_to_seq_storage: torch.Tensor + _token_to_seq_storage_shadow: torch.Tensor + """ + Linear mapping for each of the tokens. Let's say we have 8 tokens in the batch, + with the sequence breakdown being [4, 1, 3]. Then, the mapping would be: + [0, 0, 0, 0, 1, 2, 2, 2] + """ + + _inflight_seq_descriptors: torch.Tensor + _inflight_seq_descriptors_shadow: torch.Tensor + """ + For each sequence in the batch, we store the start token in the batch, the number of tokens + the number of tokens in the history of this sequence, and an unused 4th reserved for alignment. + For the above example this would give: + [[0, 4, H0, X], [4, 1, H1, X], [5, 3, H2, X]] + """ + + # Holds the block ids for each sequence in the ragged batch. + _kv_ptrs: torch.Tensor + _kv_ptrs_shadow: torch.Tensor + """ + List of ptrs pointing to the GPU buffer that holds the KV-block ids for each sequence. + If there are multiple allocation groups associated with each of the sequences, then + then accessing the Nth cache will require accessing the Nth block id + """ + + def __init__(self, config: DSStateManagerConfig) -> None: + """ + Convenience wrapper around the data structures used to represent a ragged + batch for inference. Only a single `RaggedBatchWrapper` should be used per + ragged inference engine. + + The underlying data structures are implemented in `ragged_batch_descriptor.h`. + """ + self._config = config + self._input_ids = torch.zeros((self._config.max_ragged_batch_size), + dtype=torch.int64, + device=get_accelerator().current_device()) + + self._batch_metadata_storage = torch.zeros(2, dtype=torch.int32, device=get_accelerator().current_device()) + + self._token_to_seq_storage = torch.zeros((self._config.max_ragged_batch_size), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._inflight_seq_descriptors = torch.zeros((self._config.max_ragged_sequence_count, 4), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._kv_ptrs = torch.zeros((self._config.max_ragged_sequence_count), + dtype=torch.int64, + device=get_accelerator().current_device()) + + self._utils_module = RaggedUtilsBuilder().load() + host_alloc = self._utils_module.allocate_fast_host_buffer + + self._input_ids_shadow = host_alloc(self._input_ids) + self._batch_metadata_storage_shadow = host_alloc(self._batch_metadata_storage) + self._token_to_seq_storage_shadow = host_alloc(self._token_to_seq_storage) + self._inflight_seq_descriptors_shadow = host_alloc(self._inflight_seq_descriptors) + self._kv_ptrs_shadow = host_alloc(self._kv_ptrs) + + # Default behavior should be no padding + self._is_padded = False + + self._current_tokens = 0 + self._current_sequences = 0 + self._batch_tokens = [] + self._inflight_seq_descriptors_shadow_buf = [] + self._kv_blocks_ptr_buf = [] + self._token_to_seq_storage_shadow_buf = [] + + def clear(self) -> None: + """ + Clear the ragged batch. This will reset the number of tokens and sequences to 0. + """ + self._current_tokens = 0 + self._current_sequences = 0 + self._batch_tokens = [] + self._inflight_seq_descriptors_shadow_buf = [] + self._kv_blocks_ptr_buf = [] + self._token_to_seq_storage_shadow_buf = [] + + def insert_sequence(self, seq_descriptor: DSSequenceDescriptor, tokens: torch.Tensor, do_checks=True) -> None: + """ + Incrementally insert a sequence into the ragged batch. This will update the + metadata for the ragged batch and the sequence. + + Arguments: + seq_descriptor () + """ + if tokens.device != torch.device("cpu"): + # This doesn't really fall under schedulability, so we'll unconditionally check for it. + raise RuntimeError(f"Expected tokens to be on host but found device '{tokens.device}'") + + if do_checks and self.current_sequences == self._config.max_ragged_sequence_count: + raise RuntimeError(f"Ragged batch is full due to sequence limit: {self._config.max_ragged_sequence_count}") + + seq_tokens = tokens.numel() + + if do_checks and self.current_tokens + seq_tokens > self._config.max_ragged_batch_size: + raise RuntimeError(f"Ragged batch is full due to capacity limit: {self._config.max_ragged_batch_size})") + + # The values in _inflight_seq_descriptors_shadow_buf, _token_to_seq_storage_shadow_buf, _kv_blocks_ptr_buf, etc., + # are ultimately stored in PyTorch tensors: _inflight_seq_descriptors_shadow, _token_to_seq_storage_shadow, _kv_ptrs_shadow, etc. + # However, we found it inefficient to iterate over and substitute values into tensor slices or to use copy/fill calls for this purpose. + # Therefore, we initially store the values in Python lists or primitive data types and then copy them collectively in the finalize() method, + # instead of updating the tensors directly in each iteration. + self._batch_tokens.append(tokens) + self._inflight_seq_descriptors_shadow_buf.append(self.current_tokens) + self._inflight_seq_descriptors_shadow_buf.append(seq_tokens) + self._inflight_seq_descriptors_shadow_buf.append(seq_descriptor.seen_tokens) + self._inflight_seq_descriptors_shadow_buf.append(0) # alignment + + self._token_to_seq_storage_shadow_buf.extend([self.current_sequences] * seq_tokens) + + self._kv_blocks_ptr_buf.append(seq_descriptor.kv_blocks_ptr) + + self._current_tokens += seq_tokens + self._current_sequences += 1 + + @property + def tensor_toks(self) -> torch.Tensor: + """ + The number of tokens in the in-flight ragged batch. This will not trigger + synchronization with the device. + """ + cur_toks = self.current_tokens + if self._is_padded: + return to_padded(cur_toks) + else: + return cur_toks + + def finalize(self, padding: Optional[bool] = False) -> None: + """ + Completes construction of the ragged batch by flushing the host buffers to the device. + """ + cur_toks = self.current_tokens + + # Batch-copy the values recorded in insert_sequence() into PyTorch tensors to enhance efficiency. + self._inflight_seq_descriptors_shadow.flatten()[:len(self._inflight_seq_descriptors_shadow_buf)].copy_( + torch.tensor(self._inflight_seq_descriptors_shadow_buf)) + self._input_ids_shadow[:self.current_tokens].copy_(torch.cat(self._batch_tokens, dim=0)) + self._token_to_seq_storage_shadow[:len(self._token_to_seq_storage_shadow_buf)].copy_( + torch.tensor(self._token_to_seq_storage_shadow_buf)) + self._kv_ptrs_shadow[:len(self._kv_blocks_ptr_buf)].copy_(torch.tensor(self._kv_blocks_ptr_buf)) + self._batch_metadata_storage_shadow.copy_(torch.tensor([cur_toks, self.current_sequences])) + + if padding: + padded_toks = to_padded(cur_toks) + self._input_ids_shadow[cur_toks:padded_toks].fill_(-1) + self._token_to_seq_storage_shadow[cur_toks:padded_toks].fill_(-1) + self._is_padded = True + else: + padded_toks = cur_toks + self._is_padded = False + + current_sequences = self.current_sequences + + def _noblock_copy(dst: torch.Tensor, src: torch.Tensor) -> None: + dst.copy_(src, non_blocking=True) + + _noblock_copy(self._input_ids[:padded_toks], self._input_ids_shadow[:padded_toks]) + _noblock_copy(self._batch_metadata_storage, self._batch_metadata_storage_shadow) + _noblock_copy(self._token_to_seq_storage[:padded_toks], self._token_to_seq_storage_shadow[:padded_toks]) + _noblock_copy(self._inflight_seq_descriptors[:current_sequences], + self._inflight_seq_descriptors_shadow[:current_sequences]) + _noblock_copy(self._kv_ptrs[:current_sequences], self._kv_ptrs_shadow[:current_sequences]) + + def input_ids(self, on_device: bool = True) -> torch.Tensor: + """ + The input ids tensor for the ragged batch. If the device Tensor is requested, the Tensor + is truncated to the number of tokens in the batch. + """ + if on_device: + return self._input_ids[:self.tensor_toks] + else: + return self._input_ids_shadow + + def batch_metadata_buffer(self, on_device: bool = True) -> torch.Tensor: + """ + Buffer associated with the batch metadata tensor that can + be populated in preparation for passing a new input to the device. + """ + if on_device: + return self._batch_metadata_storage + else: + return self._batch_metadata_storage_shadow + + def tokens_to_seq(self, on_device: bool = True) -> torch.Tensor: + """ + Mapping of token to which sequence it belongs to in the ragged batch. If the device Tensor + is requested, the Tensor is truncated to the number of tokens in the batch. + """ + if on_device: + return self._token_to_seq_storage[:self.tensor_toks] + else: + return self._token_to_seq_storage_shadow + + def inflight_seq_descriptors(self, on_device: bool = True) -> torch.Tensor: + """ + Buffer associated with the metadata of each sequence in the ragged batch. If the device Tensor + is requested, the Tensor is truncated to the number of sequences in the batch. + """ + if on_device: + return self._inflight_seq_descriptors[:self.current_sequences] + else: + return self._inflight_seq_descriptors_shadow + + def kv_ptrs(self, on_device: bool = True) -> torch.Tensor: + """ + Pointer to where the list of KV ids associated with a sequence are. If the device Tensor + is requested, the Tensor is truncated to the number of sequences in the batch. + """ + if on_device: + return self._kv_ptrs[:self.current_sequences] + else: + return self._kv_ptrs_shadow + + def masks(self, on_device: bool = True) -> Optional[torch.Tensor]: + """ + Placeholder for supporting complex masks. Currently not supported. + + Models that will need this will be BERT-like, not generative. + """ + return None + + @property + def current_tokens(self) -> int: + """ + The number of tokens in the in-flight ragged batch. This will not trigger + synchronization with the device. + """ + return self._current_tokens + + @property + def current_sequences(self) -> int: + """ + The number of sequences in the in-flight ragged batch. This will not trigger + synchronization with the device. + """ + return self._current_sequences diff --git a/deepspeed/inference/v2/ragged/sequence_descriptor.py b/deepspeed/inference/v2/ragged/sequence_descriptor.py new file mode 100644 index 000000000000..a66bea505b6f --- /dev/null +++ b/deepspeed/inference/v2/ragged/sequence_descriptor.py @@ -0,0 +1,280 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple, Union + +import torch + + +class BaseSequenceDescriptor: + + @property + def seen_tokens(self) -> int: + """ + The number of tokens for this sequence that have completed a forward pass. + """ + raise NotImplementedError() + + @property + def cur_allocated_blocks(self, cache_group: int = 0) -> int: + """ + The number of KV blocks currently allocated for this sequence. + """ + raise NotImplementedError() + + @property + def kv_blocks_ptr(self, cache_group: int = 0) -> int: + """ + The pointer to the KV blocks for this sequence. + """ + raise NotImplementedError() + + +class PlaceholderSequenceDescriptor(BaseSequenceDescriptor): + """ + The DummySequenceDescriptor is an empty object that allows us to perform schedulability + checks before formally tracking a sequence. + """ + + def __init__(self, seen_tokens=0, cur_allocated_blocks=0, kv_blocks_ptr=0) -> None: + self._seen_tokens = seen_tokens + self._cur_allocated_blocks = cur_allocated_blocks + self._kv_blocks_ptr = kv_blocks_ptr + + @property + def seen_tokens(self) -> int: + return self._seen_tokens + + @property + def cur_allocated_blocks(self, cache_group: int = 0) -> int: + return self._cur_allocated_blocks + + @property + def kv_blocks_ptr(self, cache_group: int = 0) -> int: + return self._kv_blocks_ptr + + +class DSSequenceDescriptor(BaseSequenceDescriptor): + + _seen_tokens: int + """ + Number of tokens in the sequence that have completed a forward pass. + """ + + _in_flight_tokens: int + """ + Number of tokens that have begun a forward pass but not yet completed it. + """ + + _max_context: int + """ + Maximum number of tokens this sequence may eventually include. Currently unused but + may be used in future implementations for speculative caching. + """ + + _num_allocation_groups: Tuple[int, ...] + """ + Number of unique allocation groups associated with the sequence for each cache group. + """ + + _blocks_per_allocation_group: Tuple[torch.IntTensor, ...] + """ + Number of blocks allocated for each allocation group in each cache group. + """ + + # Padded list of KV-cache IDs for the sequence. + _kv_cache_ids: Tuple[torch.Tensor, ...] + _kv_cache_ids_shadow: Tuple[torch.Tensor, ...] + """ + Padded list of KV-cache IDs for the sequence. The padded shape is [num_allocation_groups, max_blocks_per_allocation_group]. + """ + + # The location in the broader ID tensor where the KV-cache IDs for the sequence + # are stored. Used on flush. + _tracking_id: int + + def __init__(self, + tracking_id: int, + kv_cache_ids: Tuple[torch.Tensor, ...], + kv_cache_ids_shadow: Tuple[torch.Tensor, ...], + max_context: int = -1) -> None: + """ + Create the metadata to track a single sequence in the system. + + Arguments: + tracking_id (int): The slot in the tracking buffers used to track this sequence. + kv_cache_ids (Tuple[torch.Tensor, ...]): The KV-cache IDs for the sequence. The shape + of the tensor should be [num_allocation_groups, max_blocks_per_allocation_group]. + There should be one tensor per cache group. + kv_cache_ids_shadow (Tuple[torch.Tensor, ...]): The shadow tensor for the KV-cache IDs. + This tensor should be allocated on the host and should have the same shape as the + tensor provided in ``kv_cache_ids``. There should be one tensor per cache group. + max_context (int): The maximum number of tokens this sequence may eventually include. + Currently unused but may be used in future implementations for speculative caching. + """ + self._tracking_id = tracking_id + self._kv_cache_ids = kv_cache_ids + self._kv_cache_ids_shadow = kv_cache_ids_shadow + self._max_context = max_context + self._n_cache_groups = len(kv_cache_ids) + + self._seen_tokens = 0 + self._in_flight_tokens = 0 + assert kv_cache_ids_shadow is not None # add check before use + + self._num_allocation_groups = tuple(kv_cache_id.shape[0] for kv_cache_id in kv_cache_ids_shadow) + self._blocks_per_allocation_group = tuple( + torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups) + + for cache_group, kv_cache_ids in enumerate(kv_cache_ids): + assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0] + assert len(kv_cache_ids.shape) == 2 + + @property + def seen_tokens(self) -> int: + """ + Number of tokens in the sequence that have completed a forward pass. + """ + return self._seen_tokens + + @property + def in_flight_tokens(self) -> int: + """ + Number of tokens that have begun a forward pass but not yet completed it. + """ + return self._in_flight_tokens + + @property + def max_context(self) -> int: + """ + Maximum number of tokens for this sequence. Currently unused. + """ + return self._max_context + + @property + def tracking_id(self) -> int: + """ + Return the slot in the tracking buffers used to track this sequence. + """ + return self._tracking_id + + @property + def cur_allocated_blocks(self, cache_group: int = 0) -> int: + """ + Returns the number of blocks currently allocated for this sequence in the specified cache group. + + Arguments: + cache_group (int): The cache group to query. + """ + # Currently, there is only one allocation group. + # A shortcut is used here to bypass the overhead of sum(). + if len(self._blocks_per_allocation_group) == 1: + return self._blocks_per_allocation_group[0].item() + return self._blocks_per_allocation_group[cache_group].sum().item() + + def kv_cache_ids(self, cache_group: int = 0, on_device: bool = False) -> torch.Tensor: + """ + Returns the Tensor containing the block IDs for this sequence on the appropriate device + for the specified cache group. + + Arguments: + cache_group (int): The cache group to query. + on_device (bool): Whether or not to return the Tensor on the device or on the host. + """ + if on_device: + return self._kv_cache_ids[cache_group] + else: + return self._kv_cache_ids_shadow[cache_group] + + @property + def kv_blocks_ptr(self, cache_group: int = 0) -> int: + """ + Get the device pointer to the base of the KV-cache ids for the specified cache group and + sequence. + + Arguments: + cache_group (int): The cache group to query. + """ + return self._kv_cache_ids[cache_group].data_ptr() + + #TODO: this was previously a property but causing issues with PR-4668 need to consult w. Connor + def all_block_ids(self, cache_group: int = 0) -> torch.Tensor: + """ + Return the Tensor containing all block IDs for this sequence in the specified cache group. + + Arguments: + cache_group (int): The cache group to query. + """ + block_ids = [] + for allocation_group, num_blocks in zip(self._kv_cache_ids[cache_group], + self._blocks_per_allocation_group[cache_group]): + block_ids.append(allocation_group[:num_blocks]) + return torch.cat(block_ids) + + def pre_forward(self, num_tokens: int) -> None: + """ + Update the state of the sequence before a forward pass. + + Arguments: + num_tokens (int): The number of tokens in the sequence that will be executed during the + next forward pass of the model. + """ + self._in_flight_tokens = num_tokens + + def post_forward(self) -> None: + """ + Update the state of the sequence after a forward pass. This should be called after the forward + pass completes. NOTE: due to the asynchronous nature of the accelerator, this may be called + before the forward pass completes on the device itself. + """ + self._seen_tokens += self._in_flight_tokens + self._in_flight_tokens = 0 + + def extend_kv_cache(self, new_ids: Union[List[torch.IntTensor], torch.IntTensor], cache_group: int = 0) -> None: + """ + Extend the KV-cache for the sequence. + + Arguments: + new_ids (Union[List[torch.IntTensor], torch.IntTensor]): For each allocation group, the IDs + to add to the KV-cache. If there is only one allocation group, a single tensor can be + provided. Otherwise, a list of tensors should be provided. The tensors do not need + to have the same shape. + """ + if isinstance(new_ids, torch.Tensor): + new_ids = [new_ids] + + if len(new_ids) != self._num_allocation_groups[cache_group]: + raise ValueError( + f"Only {len(new_ids)} allocation groups provided, expected {self._num_allocation_groups[cache_group]}") + + for group_id, new_group_ids in enumerate(new_ids): + new_blocks = new_group_ids.numel() + + if new_blocks == 0: + # If we have multiple groups, it's possible to have an empty group. + continue + + shadow_alloc_group = self._kv_cache_ids_shadow[cache_group][group_id] + alloc_group = self._kv_cache_ids[cache_group][group_id] + cur_blocks = self._blocks_per_allocation_group[cache_group][group_id] + + shadow_alloc_group[cur_blocks:cur_blocks + new_blocks].copy_(new_group_ids) + alloc_group[cur_blocks:cur_blocks + new_blocks].copy_(shadow_alloc_group[cur_blocks:cur_blocks + + new_blocks], + non_blocking=True) + + self._blocks_per_allocation_group[cache_group][group_id] += new_blocks + + def free_kv_cache(self, free_ids: Union[List[torch.IntTensor], torch.IntTensor], cache_group: int = 0) -> None: + """ + Free blocks from the KV-cache for the sequence. + + Arguments: + free_ids (Union[List[torch.IntTensor], torch.IntTensor]): The ids of blocks to free + from the KV-cache. If there is only one allocation group, a single tensor can be + provided. Otherwise, a list of tensors should be provided. The tensors do not need + to have the same shape. + """ + raise NotImplementedError("Partial KV-cache freeing is not yet supported.") diff --git a/deepspeed/inference/v2/scheduling_utils.py b/deepspeed/inference/v2/scheduling_utils.py new file mode 100644 index 000000000000..6d3818d46675 --- /dev/null +++ b/deepspeed/inference/v2/scheduling_utils.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum + + +class SchedulingResult(Enum): + + Success = 0 + """ + The proposed batch is valid and can be scheduled. + """ + + EngineSequenceLimitExceeded = 1 + """ + The proposed batch would would overflow the number of concurrent sequences the engine may support. + """ + + BatchSequenceLimitExceeded = 2 + """ + The proposed batch contains more sequences than the engine was configured + to support in a single forwardp + """ + + BatchTokenLimitExceeded = 3 + """ + The proposed batch contains more tokens than the engine was configured + to support in a single forward. + """ + + KVCacheLimitExceeded = 4 + """ + The proposed batch would require more KV cache to be allocated than the engine + currently has available. + """ + + SequenceTokenLimitExceeded = 5 + """ + The proposed batch contains a sequence that is longer than the engine/model can support. + """ + + +class SchedulingError(RuntimeError): + + result: SchedulingResult + """ + The failed result of the scheduling check. Guaranteed to not be SchedulingResult.Success. + """ + + def __init__(self, result: SchedulingResult) -> None: + self.result = result + super().__init__(f"Batch scheduling failed with result {result}") diff --git a/deepspeed/io/__init__.py b/deepspeed/io/__init__.py new file mode 100644 index 000000000000..aab8353b8e17 --- /dev/null +++ b/deepspeed/io/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .py_file_writer import PyFileWriter +from .fast_file_writer import FastFileWriter, FastFileWriterConfig +from .mock_file_writer import MockFileWriter diff --git a/deepspeed/io/base_file_writer.py b/deepspeed/io/base_file_writer.py new file mode 100644 index 000000000000..ef7d9148450c --- /dev/null +++ b/deepspeed/io/base_file_writer.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .constants import * + +BASE_STAT_KEYS = [ + CLOSE_COUNT_KEY, FILENO_COUNT_KEY, FLUSH_COUNT_KEY, WRITE_COUNT_KEY, WRITE_BYTES_KEY, WRITE_SEC_KEY, + WRITE_SPEED_KEY +] + + +class BaseFileWriter(object): + + def __init__(self, file_path): + self._file_path = file_path + self._stats = {k: 0 for k in BASE_STAT_KEYS} + + def close(self): + pass + + def fileno(self): + pass + + def flush(self): + pass + + def write(self, buffer): + pass + + def file_path(self): + return self._file_path + + def _incr_stats(self, key, incr=1): + self._stats[key] += incr + + def _dump_state(self): + if self._stats[WRITE_SEC_KEY] > 0: + self._stats[WRITE_SPEED_KEY] = (self._stats[WRITE_BYTES_KEY] / self._stats[WRITE_SEC_KEY] / (1024**3)) + state = self._stats + state[FILE_PATH_KEY] = self.file_path() + print(f'stats = {self._stats}') diff --git a/deepspeed/io/base_io_buffer.py b/deepspeed/io/base_io_buffer.py new file mode 100644 index 000000000000..d06a20e01656 --- /dev/null +++ b/deepspeed/io/base_io_buffer.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + + +class Base_IO_Buffer(object): + + def __init__(self, pinned_tensor, dnvme_handle): + assert pinned_tensor.numel() % dnvme_handle.get_alignment() == 0 + self._dnvme_handle = dnvme_handle + self._pinned_tensor = pinned_tensor + + def fill(self, src_tensor, src_offset): + pass + + def drain(self, num_bytes, fd, file_offset): + pass + + def is_empty(self): + pass + + def is_full(self): + pass + + def get_buffer(self): + pass + + def get_offset(self): + pass + + def get_aligned_num_bytes(self): + pass + + def get_unaligned_num_bytes(self): + pass + + def reset(self): + pass + + def complete_ongoing_drain(self): + pass + + def _drain(self, num_bytes, fd, file_offset, blocking=False): + assert num_bytes <= self.get_offset() + assert num_bytes % self._dnvme_handle.get_alignment() == 0 + buffer = self.get_buffer() + r = self._dnvme_handle.async_pwrite(torch.narrow(buffer, 0, 0, num_bytes), fd, file_offset) + assert 0 == r + if blocking: + assert 1 == self._dnvme_handle.wait() + + @staticmethod + def fill_buffer(src_tensor, src_offset, buffer_tensor, buffer_offset): + src_bytes = src_tensor.numel() - src_offset + assert src_bytes > 0 + + dst_bytes = buffer_tensor.numel() - buffer_offset + copy_bytes = min(src_bytes, dst_bytes) + assert (buffer_offset + copy_bytes) <= buffer_tensor.numel() + + if copy_bytes > 0: + src_slice = torch.narrow(src_tensor, 0, src_offset, copy_bytes) + dst_slice = torch.narrow(buffer_tensor, 0, buffer_offset, copy_bytes) + dst_slice.data.copy_(src_slice.data) + + return copy_bytes diff --git a/deepspeed/io/constants.py b/deepspeed/io/constants.py new file mode 100644 index 000000000000..e402a365a0d3 --- /dev/null +++ b/deepspeed/io/constants.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +INVALID_FD = -1 + +FILE_PATH_KEY = 'path' +FLUSH_COUNT_KEY = 'flush' +WRITE_COUNT_KEY = 'write' +CLOSE_COUNT_KEY = 'close' +FILENO_COUNT_KEY = 'fileno' +WRITE_BYTES_KEY = 'bytes' +WRITE_SEC_KEY = 'write_secs' +WRITE_SPEED_KEY = 'write_GB/s' + +AIO_WRITE_SEC_KEY = 'aio_write_secs' +AIO_WRITE_BYTES_KEY = 'aio_bytes' +AIO_SPEED_KEY = 'aio_GB/s' +SLOW_WRITE_BYTES_KEY = 'slow_bytes' +SLOW_WRITE_SEC_KEY = 'slow_write_secs' +AIO_FILL_BUFFER_SEC_KEY = 'fill_buffer_secs' +AIO_FILL_BUFFER_COUNT_KEY = 'fill_buffer_count' +AIO_FILL_BUFFER_SPEED_KEY = 'fill_buffer_GB/s' + +SAVE_STORAGE_KEY = 'save_storage' +SAVE_STORAGE_BYTES_KEY = 'save_storage_bytes' +SAVE_STORAGE_SEC_KEY = 'save_storage_secs' +STORAGE_OBJ_SIZE = 8 + +RANK_KEY = 'rank' diff --git a/deepspeed/io/double_io_buffer.py b/deepspeed/io/double_io_buffer.py new file mode 100644 index 000000000000..0491ff1ac93e --- /dev/null +++ b/deepspeed/io/double_io_buffer.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from .base_io_buffer import Base_IO_Buffer + +NUM_BUFFERS = 2 +INVALID_BUFFER_INDEX = -1 + + +class Double_IO_Buffer(Base_IO_Buffer): + + def __init__(self, pinned_tensor, dnvme_handle): + super(Double_IO_Buffer, self).__init__(pinned_tensor, dnvme_handle) + assert self._pinned_tensor.numel() % (NUM_BUFFERS * self._dnvme_handle.get_alignment()) == 0 + self._buffers = self._split_buffer() + self._fill_index = 0 + self._drain_index = INVALID_BUFFER_INDEX + self._buffer_offset = 0 + + def fill(self, src_tensor, src_offset): + self._validate_buffer_index(self._fill_index) + copy_bytes = Base_IO_Buffer.fill_buffer(src_tensor, src_offset, self._buffers[self._fill_index], + self._buffer_offset) + self._buffer_offset += copy_bytes + return copy_bytes + + def drain(self, num_bytes, fd, file_offset): + self._validate_buffer_index(self._fill_index) + self.complete_ongoing_drain() + assert self._drain_index == INVALID_BUFFER_INDEX + self._drain(num_bytes, fd, file_offset, blocking=False) + self._drain_index = self._fill_index + self._fill_index = (self._fill_index + 1) % NUM_BUFFERS + self._buffer_offset = 0 + + def get_buffer(self): + self._validate_buffer_index(self._fill_index) + return self._buffers[self._fill_index] + + def get_offset(self): + self._validate_buffer_index(self._fill_index) + return self._buffer_offset + + def get_aligned_num_bytes(self): + self._validate_buffer_index(self._fill_index) + aligned_size = self._dnvme_handle.get_alignment() + return (self._buffer_offset // aligned_size) * aligned_size + + def get_unaligned_num_bytes(self): + self._validate_buffer_index(self._fill_index) + return self._buffer_offset % self._dnvme_handle.get_alignment() + + def is_full(self): + self._validate_buffer_index(self._fill_index) + return self._buffer_offset == self._buffers[self._fill_index].numel() + + def is_empty(self): + self._validate_buffer_index(self._fill_index) + return self._buffer_offset == 0 and not self._is_ongoing_drain() + + def reset(self): + self._buffer_offset = 0 + + def complete_ongoing_drain(self): + if self._is_ongoing_drain(): + self._wait_for_drain() + + def _split_buffer(self): + buffer_size = self._pinned_tensor.numel() // NUM_BUFFERS + return [torch.narrow(self._pinned_tensor, 0, (i * buffer_size), buffer_size) for i in range(NUM_BUFFERS)] + + def _validate_buffer_index(self, index): + assert index in [0, 1] + + def _wait_for_drain(self): + self._validate_buffer_index(self._drain_index) + assert 1 == self._dnvme_handle.wait() + self._drain_index = INVALID_BUFFER_INDEX + + def _is_ongoing_drain(self): + return self._drain_index != INVALID_BUFFER_INDEX diff --git a/deepspeed/io/fast_file_writer.py b/deepspeed/io/fast_file_writer.py new file mode 100644 index 000000000000..fd4470de571b --- /dev/null +++ b/deepspeed/io/fast_file_writer.py @@ -0,0 +1,273 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch +import time +from dataclasses import dataclass + +from .constants import * +from .base_file_writer import BaseFileWriter +from .single_io_buffer import Single_IO_Buffer +from .double_io_buffer import Double_IO_Buffer +from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.accelerator import get_accelerator + +from .utils import (tensor_to_bytes, bytes_to_tensor, obj_serialization_details) + +FASTIO_STAT_KEYS = [ + AIO_WRITE_SEC_KEY, + AIO_WRITE_BYTES_KEY, + AIO_SPEED_KEY, + SLOW_WRITE_BYTES_KEY, + SLOW_WRITE_SEC_KEY, + AIO_FILL_BUFFER_COUNT_KEY, + AIO_FILL_BUFFER_SEC_KEY, + AIO_FILL_BUFFER_SPEED_KEY, + SAVE_STORAGE_KEY, + SAVE_STORAGE_BYTES_KEY, +] + + +@dataclass +class FastFileWriterConfig: + dnvme_handle: object + pinned_tensor: torch.Tensor + double_buffer: bool = True + num_parallel_writers: int = 1 + writer_rank: int = 0 + global_rank: int = 0 + + +class FastFileWriter(BaseFileWriter): + + def __init__(self, file_path, config): + super(FastFileWriter, self).__init__(file_path) + self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_CREAT | os.O_WRONLY) + self._dnvme_handle = config.dnvme_handle + self._file_offset = 0 + io_buffer_type = Double_IO_Buffer if config.double_buffer else Single_IO_Buffer + self._io_buffer = io_buffer_type(config.pinned_tensor, self._dnvme_handle) + self._cast_to_byte_tensor = UtilsBuilder().load().cast_to_byte_tensor + self._get_serialization_details = obj_serialization_details() + self._num_parallel_writers = config.num_parallel_writers + self._writer_rank = config.writer_rank + self._global_rank = config.global_rank + + for k in FASTIO_STAT_KEYS: + self._stats[k] = 0 + + def write(self, buffer): + assert self._file_offset % self._dnvme_handle.get_alignment() == 0 + buffer_num_bytes = len(buffer) + num_written_bytes = self._write_from_tensor(bytes_to_tensor(buffer)) + assert buffer_num_bytes == num_written_bytes + return buffer_num_bytes + + def split_index_list(self, storage_obj_list, num_splits): + assert num_splits > 0 + split_list = [-1] * num_splits + # t[0] is data, t[1] is data_type + tensor_bytes_list = [len(t[0]) for t in storage_obj_list] + print(tensor_bytes_list) + total_bytes = sum(tensor_bytes_list) + bytes_per_group = total_bytes / num_splits + split_counter = 0 + tmp_size = 0 + for i in range(len(tensor_bytes_list)): + tmp_size += tensor_bytes_list[i] + if tmp_size > bytes_per_group: + split_list[split_counter] = i + tmp_size = 0 + split_counter += 1 + if split_list[num_splits - 1] == -1: + split_list[num_splits - 1] = len(tensor_bytes_list) + return split_list + + def save_torch_storage_object_list(self, storage_obj_list, save_size): + assert self._file_offset % self._dnvme_handle.get_alignment() == 0 + num_bytes_written = self._save_storage_list(storage_obj_list, save_size) + return num_bytes_written + + def close(self): + self._fini() + self._incr_stats(CLOSE_COUNT_KEY) + + def fileno(self): + self._incr_stats(FILENO_COUNT_KEY) + return INVALID_FD # self._aio_fd + + def flush(self): + self._incr_stats(FLUSH_COUNT_KEY) + + def __del__(self): + self._fini() + assert self._aio_fd == INVALID_FD + assert self._io_buffer.get_offset() == 0, \ + f'__del__ assert: pinned_offset {self._io_buffer.get_offset()} != 0' + assert self._file_offset == self._stats[WRITE_BYTES_KEY], \ + f'__del__ assert: file_offset != write_bytes - {self._file_offset} != {self._stats[WRITE_BYTES_KEY]}' + + def _fini(self): + if not self._io_buffer_is_empty(): + self._force_drain() + self._io_buffer.reset() + fd = self._aio_fd + self._aio_fd = INVALID_FD + if fd != INVALID_FD: + try: + os.fsync(fd) + finally: + os.close(fd) + + def _fill_io_buffer(self, src_tensor, src_offset): + st = time.time() + copy_bytes = self._io_buffer.fill(src_tensor, src_offset) + self._incr_stats(AIO_FILL_BUFFER_SEC_KEY, time.time() - st) + self._incr_stats(AIO_FILL_BUFFER_COUNT_KEY) + return copy_bytes + + def _drain_io_buffer(self, num_bytes): + st = time.time() + self._io_buffer.drain(num_bytes, self._aio_fd, self._file_offset) + self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st) + self._incr_stats(AIO_WRITE_BYTES_KEY, num_bytes) + self._file_offset += num_bytes + + def _io_buffer_is_full(self): + return self._io_buffer.is_full() + + def _io_buffer_is_empty(self): + return self._io_buffer.is_empty() + + def _force_drain(self): + st = time.time() + aligned_num_bytes = self._io_buffer.get_aligned_num_bytes() + # Important to retrieve unaligned drain bytes and tensor before doing aligned drain because of the side effects. + # TODO: Need to eliminate this dependency + unaligned_num_bytes = self._io_buffer.get_unaligned_num_bytes() + unaligned_tensor = torch.narrow(self._io_buffer.get_buffer(), 0, aligned_num_bytes, unaligned_num_bytes) + + if aligned_num_bytes > 0: + self._drain_io_buffer(aligned_num_bytes) + + self._io_buffer.complete_ongoing_drain() + self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st) + + if unaligned_num_bytes > 0: + self._unaligned_drain(unaligned_tensor) + self._incr_stats(WRITE_SEC_KEY, time.time() - st) + + def _unaligned_drain(self, unaligned_tensor): + os.close(self._aio_fd) + st = time.time() + fp = open(self._file_path, 'ab') + fp.write(tensor_to_bytes(unaligned_tensor.cpu())) + fp.close() + self._file_offset += unaligned_tensor.numel() + self._incr_stats(SLOW_WRITE_SEC_KEY, time.time() - st) + self._incr_stats(SLOW_WRITE_BYTES_KEY, unaligned_tensor.numel()) + self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_WRONLY | os.O_APPEND) + + def _dump_state(self): + if self._stats[AIO_WRITE_SEC_KEY] > 0: + self._stats[AIO_SPEED_KEY] = (self._stats[AIO_WRITE_BYTES_KEY] / self._stats[AIO_WRITE_SEC_KEY] / + (1024**3)) + if self._stats[AIO_FILL_BUFFER_SEC_KEY] > 0: + self._stats[AIO_FILL_BUFFER_SPEED_KEY] = (self._stats[AIO_WRITE_BYTES_KEY] / + self._stats[AIO_FILL_BUFFER_SEC_KEY] / (1024**3)) + super()._dump_state() + + def _update_write_stats(self, num_bytes, secs_latency): + self._incr_stats(WRITE_COUNT_KEY) + self._incr_stats(WRITE_BYTES_KEY, num_bytes) + self._incr_stats(WRITE_SEC_KEY, secs_latency) + + def _write_from_tensor(self, buffer_tensor): + st = time.time() + buffer_offset = 0 + while (buffer_offset < buffer_tensor.numel()): + num_copied_bytes = self._fill_io_buffer(buffer_tensor, buffer_offset) + if self._io_buffer_is_full(): + self._drain_io_buffer(self._io_buffer.get_offset()) + buffer_offset += num_copied_bytes + + self._update_write_stats(buffer_offset, time.time() - st) + + return buffer_offset + + def _save_storage_list(self, obj_list, save_size): + byte_tensor_list, byte_tensor_nbytes = self._convert_to_byte_tensors(obj_list, save_size) + if self._num_parallel_writers > 1: + my_byte_tensor_list = self._partition_byte_tensors(byte_tensor_list, byte_tensor_nbytes, + self._num_parallel_writers, self._writer_rank) + else: + my_byte_tensor_list = byte_tensor_list + + num_object_bytes_written = 0 + for byte_tensor in my_byte_tensor_list: + num_object_bytes_written += self._write_from_tensor(byte_tensor) + + self._incr_stats(SAVE_STORAGE_KEY, len(obj_list)) + self._incr_stats(SAVE_STORAGE_BYTES_KEY, num_object_bytes_written) + return num_object_bytes_written + + # Convert list of storage objects into list of byte tensors of object and size bytes + def _convert_to_byte_tensors(self, obj_list, save_size): + tensor_list = [] + num_bytes = 0 + for storage_obj in obj_list: + details = self._get_serialization_details(storage_obj) + if save_size: + tensor_list.append( + torch.tensor( + details.size, + dtype=torch.int64, + ).to(get_accelerator().device_name())) + tensor_list.append(torch.empty(0, dtype=details.dtype, device=details.obj.device).set_(details.obj)) + num_bytes += details.nbytes + if save_size: + num_bytes += STORAGE_OBJ_SIZE * len(obj_list) + + return self._cast_to_byte_tensor(tensor_list), num_bytes + + def _partition_byte_tensors(self, byte_tensor_list, byte_tensor_nbytes, num_ranks, my_rank): + assert my_rank >= 0, f'Invalid for rank number to be negative: {my_rank}' + assert num_ranks > my_rank, f'Number of ranks {num_ranks} must be greater than rank {my_rank}' + + partition_size = int(byte_tensor_nbytes // num_ranks) + num_remainder_bytes = byte_tensor_nbytes % num_ranks + if num_remainder_bytes == 0: + partition_start = partition_size * my_rank + else: + # Spread extra bytes evenly among early ranks + if num_remainder_bytes > my_rank: + partition_size += 1 + partition_start = partition_size * my_rank + else: + # Account for allocation of extra bytes to earlier ranks + partition_start = (partition_size * my_rank) + num_remainder_bytes + + partition_end = min(partition_start + partition_size, byte_tensor_nbytes) + partition_tensor_list = [] + current_offset = 0 + for byte_tensor in byte_tensor_list: + byte_tensor_end = current_offset + byte_tensor.numel() + if current_offset < partition_end and byte_tensor_end > partition_start: + fragment_start = max(current_offset, partition_start) + fragment_end = min(byte_tensor_end, partition_end) + assert fragment_start < fragment_end, \ + f'fragment start {fragment_start} should be < fragment_end {fragment_end}' + + fragment_numel = fragment_end - fragment_start + partition_tensor_list.append(byte_tensor.narrow(0, fragment_start - current_offset, fragment_numel)) + + current_offset += byte_tensor.numel() + + actual_partition_nbytes = sum([t.numel() for t in partition_tensor_list]) + assert actual_partition_nbytes == partition_size, \ + f'Incorrect partition bytes for rank {my_rank}, expected = {partition_size} actual = {actual_partition_nbytes}' + + return partition_tensor_list diff --git a/deepspeed/io/mock_file_writer.py b/deepspeed/io/mock_file_writer.py new file mode 100644 index 000000000000..5957ad771974 --- /dev/null +++ b/deepspeed/io/mock_file_writer.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .constants import * +from .base_file_writer import BaseFileWriter +from .utils import obj_serialization_details + + +class MockFileWriter(BaseFileWriter): + + def __init__(self, file_path): + super(MockFileWriter, self).__init__(file_path) + self._fp = open(file_path, 'wb') + self._stats[SAVE_STORAGE_KEY] = 0 + self._stats[SAVE_STORAGE_BYTES_KEY] = 0 + self._get_serialization_details = obj_serialization_details() + + def close(self): + self._incr_stats(CLOSE_COUNT_KEY) + self._fp.close() + + def fileno(self): + self._incr_stats(FILENO_COUNT_KEY) + return INVALID_FD # self._fp.fileno() + + def flush(self): + self._incr_stats(FLUSH_COUNT_KEY) + self._fp.flush() + + def write(self, buffer): + return self._write(len(buffer)) + + def save_torch_storage_object_list(self, storage_obj_list, save_size): + num_bytes = sum([self._save_torch_storage_object(obj, save_size) for obj in storage_obj_list]) + return num_bytes + + def _save_torch_storage_object(self, storage_obj, save_size): + details = self._get_serialization_details(storage_obj) + self._incr_stats(SAVE_STORAGE_KEY) + self._incr_stats(SAVE_STORAGE_BYTES_KEY, details.size) + num_written_bytes = self._write(STORAGE_OBJ_SIZE) if save_size else 0 + return num_written_bytes + self._write(details.size) + + def _write(self, num_bytes): + self._incr_stats(WRITE_COUNT_KEY) + self._incr_stats(WRITE_BYTES_KEY, num_bytes) + return num_bytes diff --git a/deepspeed/io/py_file_writer.py b/deepspeed/io/py_file_writer.py new file mode 100644 index 000000000000..51849cb89c8a --- /dev/null +++ b/deepspeed/io/py_file_writer.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +from .constants import * +from .base_file_writer import BaseFileWriter + + +class PyFileWriter(BaseFileWriter): + + def __init__(self, file_path): + super(PyFileWriter, self).__init__(file_path) + self._fp = open(file_path, 'wb') + + def close(self): + self._incr_stats(CLOSE_COUNT_KEY) + self._fp.close() + + def fileno(self): + self._incr_stats(FILENO_COUNT_KEY) + return INVALID_FD # self._fp.fileno() + + def flush(self): + self._incr_stats(FLUSH_COUNT_KEY) + self._fp.flush() + + def write(self, buffer): + st = time.time() + self._fp.write(buffer) + self._incr_stats(WRITE_SEC_KEY, time.time() - st) + self._incr_stats(WRITE_COUNT_KEY) + self._incr_stats(WRITE_BYTES_KEY, len(buffer)) + return len(buffer) diff --git a/deepspeed/io/single_io_buffer.py b/deepspeed/io/single_io_buffer.py new file mode 100644 index 000000000000..57fbfe38c45d --- /dev/null +++ b/deepspeed/io/single_io_buffer.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .base_io_buffer import Base_IO_Buffer + + +class Single_IO_Buffer(Base_IO_Buffer): + + def __init__(self, pinned_tensor, dnvme_handle): + super(Single_IO_Buffer, self).__init__(pinned_tensor, dnvme_handle) + self._pinned_offset = 0 + + def fill(self, src_tensor, src_offset): + copy_bytes = Base_IO_Buffer.fill_buffer(src_tensor, src_offset, self._pinned_tensor, self._pinned_offset) + self._pinned_offset += copy_bytes + return copy_bytes + + def drain(self, num_bytes, fd, file_offset): + self._drain(num_bytes, fd, file_offset, blocking=True) + self._pinned_offset = 0 + + def get_buffer(self): + return self._pinned_tensor + + def get_offset(self): + return self._pinned_offset + + def get_aligned_num_bytes(self): + aligned_size = self._dnvme_handle.get_alignment() + return (self._pinned_offset // aligned_size) * aligned_size + + def get_unaligned_num_bytes(self): + return self._pinned_offset % self._dnvme_handle.get_alignment() + + def is_full(self): + return self._pinned_offset == self._pinned_tensor.numel() + + def is_empty(self): + return self._pinned_offset == 0 + + def reset(self): + self._pinned_offset = 0 diff --git a/deepspeed/io/utils.py b/deepspeed/io/utils.py new file mode 100644 index 000000000000..f811aeb84577 --- /dev/null +++ b/deepspeed/io/utils.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy +import torch +from dataclasses import dataclass + + +@dataclass +class serialize_details: + obj: object + dtype: torch.dtype + size: int + nbytes: int + + +def tensor_to_bytes(tensor): + return tensor.numpy().tobytes() + + +def bytes_to_tensor(buffer): + return torch.from_numpy(numpy.array(numpy.frombuffer(buffer, dtype=numpy.uint8))) + + +def required_minimum_torch_version(major_version, minor_version): + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR < major_version: + return False + + return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version + + +# torch < 1.12 +def _legacy_obj_serialization_details(storage_obj): + nbytes = storage_obj.element_size() * storage_obj.size() + return serialize_details(obj=storage_obj, dtype=storage_obj.dtype, size=nbytes, nbytes=nbytes) + + +# torch >= 1.12 +def _new_obj_serialization_details(storage_obj): + obj, dtype = storage_obj + return serialize_details(obj=obj, + dtype=dtype, + size=obj.size() // torch._utils._element_size(dtype), + nbytes=obj.size()) + + +def obj_serialization_details(): + if required_minimum_torch_version(1, 12): + return _new_obj_serialization_details + + return _legacy_obj_serialization_details diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py index 838dd66fcb30..366ae0b236f8 100644 --- a/deepspeed/launcher/constants.py +++ b/deepspeed/launcher/constants.py @@ -8,6 +8,7 @@ OPENMPI_LAUNCHER = 'openmpi' MPICH_LAUNCHER = 'mpich' +IMPI_LAUNCHER = 'impi' SLURM_LAUNCHER = 'slurm' MVAPICH_LAUNCHER = 'mvapich' MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile' diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index d84fe9586b96..7d177501064d 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -19,13 +19,13 @@ import time import signal import psutil -import distutils from collections import defaultdict from typing import Dict from argparse import ArgumentParser, REMAINDER -from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE +from deepspeed.accelerator import get_accelerator from ..nebula.constants import DLTS_POD_ENV_PATH -from ..utils import logger +from ..utils import logger, get_numactl_cmd, set_log_level_from_string from ..elasticity import is_torch_elastic_compatible from .constants import ELASTIC_TRAINING_ID_DEFAULT @@ -102,6 +102,18 @@ def parse_args(): "numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not " "specified, all cores on system would be used rank binding") + # TODOV1: change the default to 'warning' + parser.add_argument("--log_level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical'], + help="Set launcher loglevel. The default is 'info'") + + parser.add_argument("-q", + "--quiet", + action="store_true", + help="Try to be as quiet as possible. Aliases to `--log_level error`") + # positional parser.add_argument("training_script", type=str, @@ -130,93 +142,14 @@ def terminate_process_tree(pid): p.kill() -def parse_range(rng): - try: - value = int(rng) - return range(value, value + 1) - except ValueError: - # value is not a single number - parts = rng.split('-') - if len(parts) != 2: - raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" % - (rng, )) - start = int(parts[0]) - end = int(parts[1]) - if start > end: - raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, )) - return range(start, end + 1) - - -# parse comma and dash separated range list into list -# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6] -# rules: -# 1. Range list numser be comma sepeaated, each item are either a single number, -# or a range marked by two numbers (both number are included in the range) -# 2. Sub ranges must be in ascend order and not overlap with each other -# 3. No space in the range expression -def parse_range_list(range_str): - number_list = [] - last = -1 - range_list = range_str.split(',') - for sub_range in range_list: - sub_number_list = parse_range(sub_range) - if sub_number_list[0] <= last: - raise ValueError( - "Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" % - (range_str, )) - last = sub_number_list[-1] - number_list.extend(sub_number_list) - return number_list - - -# return a list of list for cores to numa mapping -# [ -# [ cores for numa 0 ] -# [ cores belong to numa 1 ] -# ... -# ] -def get_numa_cores(): - ret = [] - output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8") - lines = output.split('\n') - for line in lines: - if line.startswith('available:'): - num_numas = int(line.split(' ')[1]) - break - for numa in range(num_numas): - for line in lines: - if line.startswith(f'node {numa} cpus:'): - cores = line.split(' ')[3:] - ret.append([int(core) for core in cores]) - return ret - - -def check_for_numactl_pkg(): - libs = dict( - dpkg=["-l", "numactl", "apt"], - pacman=["-Q", "numactl", "pacman"], - rpm=["-q", "numactl", "yum"], - ) - - found = False - for pkgmgr, data in libs.items(): - flag, lib, tool = data - path = distutils.spawn.find_executable(pkgmgr) - if path is not None: - cmd = f"{pkgmgr} {flag} {lib}" - result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) - if result.wait() == 0: - found = True - else: - print(f"please install the {lib} package with {tool}") - break - return found - - def main(): args = parse_args() current_env = os.environ.copy() + if args.quiet: + args.log_level = "error" + set_log_level_from_string(args.log_level) + for k in current_env.keys(): if "NCCL" in k: logger.info(f"{args.node_rank} {k}={current_env[k]}") @@ -230,8 +163,8 @@ def main(): node_list = list(world_info.keys()) args.nnodes = len(node_list) local_node = node_list[args.node_rank] - local_gpu_ids = world_info[local_node] - num_local_procs = len(local_gpu_ids) + local_accelerator_ids = world_info[local_node] + num_local_procs = len(local_accelerator_ids) logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}") global_rank_mapping = defaultdict(list) @@ -245,15 +178,17 @@ def main(): curr_global_rank += 1 logger.info(f"global_rank_mapping={global_rank_mapping}") logger.info(f"dist_world_size={dist_world_size}") - current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids)) - logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}") + + get_accelerator().set_visible_devices_envs(current_env, local_accelerator_ids) + for env in get_accelerator().visible_devices_envs(): + logger.info(f"Setting {env}={current_env[env]}") # set PyTorch distributed related environmental variables current_env["MASTER_ADDR"] = args.master_addr current_env["MASTER_PORT"] = str(args.master_port) current_env["WORLD_SIZE"] = str(dist_world_size) - current_env["CROSS_RANK"] = str(args.node_rank) - current_env["CROSS_SIZE"] = str(args.nnodes) + current_env[CROSS_RANK] = str(args.node_rank) + current_env[CROSS_SIZE] = str(args.nnodes) current_env["LOCAL_SIZE"] = str(num_local_procs) if args.save_pid: @@ -269,7 +204,7 @@ def main(): if not is_torch_elastic_compatible(): if args.enable_elastic_training: - logger.info(f"Disabling elastic training support as \ + logger.info("Disabling elastic training support as \ PyTorch version should be greater than 1.11.x") args.enable_elastic_training = False @@ -299,48 +234,19 @@ def main(): raise ValueError(f"unable to create directory {args.enable_each_rank_log} for each rank log.") log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for local_rank in range(0, num_local_procs): + for local_proc in range(0, num_local_procs): # each process's rank - dist_rank = global_rank_mapping[local_node][local_rank] + dist_rank = global_rank_mapping[local_node][local_proc] + local_rank = dist_rank % num_local_procs current_env["RANK"] = str(dist_rank) current_env["LOCAL_RANK"] = str(local_rank) # spawn the processes cmd = [] if args.bind_cores_to_rank: - check_for_numactl_pkg() - if 'KMP_AFFINITY' in os.environ.keys(): - raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl " - "because it interfere with how many CPU cores numactl can set. " - "Unset KMP_AFFINITY before launching deepspeed.\n\n" - "\t$ unset KMP_AFFINITY\n" - "\t$ deepspeed ") - if args.bind_core_list != None: - core_list = parse_range_list(args.bind_core_list) - total_cores = len(core_list) - else: - total_cores = psutil.cpu_count(logical=False) - core_list = range(total_cores) - cores_per_rank = total_cores // num_local_procs - assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank" - core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)] + cores_per_rank, numactl_cmd = get_numactl_cmd(args.bind_core_list, num_local_procs, local_rank) current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}" - cmd.append("numactl") - - # check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter - numa_cores = get_numa_cores() - num_numas = len(numa_cores) - for i in range(num_numas): - if set(core_list_for_rank) <= set(numa_cores[i]): - cmd.append("-m") - cmd.append(f"{i}") - break - - cmd.append("-C") - core_list_str = f"{core_list_for_rank[0]}" - for core_id in core_list_for_rank[1:]: - core_list_str = f"{core_list_str},{core_id}" - cmd.append(f"{core_list_str}") + cmd = cmd + numactl_cmd if not args.no_python: cmd.append(sys.executable) cmd.append("-u") @@ -362,7 +268,8 @@ def main(): process = subprocess.Popen(cmd, env=current_env, stdout=log_fd, stderr=log_fd) else: process = subprocess.Popen(cmd, env=current_env) - + # logs the command from processes + logger.info(f"process {process.pid} spawned with command: {cmd}") processes.append(process) else: from ..elasticity import DSElasticAgent diff --git a/deepspeed/launcher/launcher_helper.py b/deepspeed/launcher/launcher_helper.py new file mode 100644 index 000000000000..b7d65f896820 --- /dev/null +++ b/deepspeed/launcher/launcher_helper.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import sys +import argparse +import subprocess +from deepspeed.utils import logger +from deepspeed.launcher.constants import MPICH_LAUNCHER + + +def parse_args(args=None): + parser = argparse.ArgumentParser(description="DeepSpeed launcher helper to map environment variables for" + "multi-node/multi-gpu training jobs.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--launcher", + default=MPICH_LAUNCHER, + type=str, + help="(optional) choose launcher backend for multi-node " + "training. Options currently include MPICH.") + + parser.add_argument("--module", + action="store_true", + help="Change each process to interpret the launch " + "script as a Python module, executing with the same " + "behavior as 'python -m'.") + + parser.add_argument("--no_python", + action="store_true", + help="Skip prepending the training script with " + "'python' - just execute it directly.") + + parser.add_argument("user_script", type=str, help="User script to launch, followed by any required " + "arguments.") + + parser.add_argument('user_args', nargs=argparse.REMAINDER) + + parser.add_argument("--bind_cores_to_rank", + action="store_true", + help="Bind each rank to different cores of the host") + + parser.add_argument("--bind_core_list", + type=str, + default=None, + help="List of cores to bind to with comma separated list of " + "numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not " + "specified, all cores on system would be used rank binding") + + return parser.parse_args(args=args) + + +def env_mapping(env, rank_name_list=None, local_rank_name_list=None): + rank = None + for rank_name in rank_name_list: + if rank_name in env: + if rank == None: + rank = env.get(rank_name) + elif rank != env.get(rank_name): + raise EnvironmentError("rank number doesn't match!") + if rank == None: + raise EnvironmentError("rank number is not in current env!") + env['RANK'] = rank + + local_rank = None + for local_rank_name in local_rank_name_list: + if local_rank_name in env: + if local_rank == None: + local_rank = env.get(local_rank_name) + elif local_rank != env.get(local_rank_name): + raise EnvironmentError("local_rank number doesn't match!") + if local_rank == None: + raise EnvironmentError("rank number is not in current env!") + env['LOCAL_RANK'] = local_rank + + return env + + +def main(args=None): + args = parse_args(args) + + env = os.environ.copy() + + args.launcher = args.launcher.lower() + if args.launcher == MPICH_LAUNCHER: + rank_name_list = ["PMIX_RANK"] + ["PMI_RANK"] + local_rank_name_list = ["PALS_LOCAL_RANKID"] + ["MPI_LOCALRANKID"] + env = env_mapping(env, rank_name_list=rank_name_list, local_rank_name_list=local_rank_name_list) + else: + raise NotImplementedError(f"Unknown launcher {args.launcher}") + + python_exec = [] + if not args.no_python: + python_exec += [sys.executable, "-u"] + if args.module: + python_exec.append("-m") + cmd = python_exec + [args.user_script] + args.user_args + + logger.info(f"launcher_helper cmd = {' '.join(cmd)}") + + result = subprocess.Popen(cmd, env=env, close_fds=False) + result.wait() + + +if __name__ == "__main__": + main() diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index f974c8daf946..5171765f48cd 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -8,10 +8,11 @@ import shutil import subprocess import warnings +import re from shlex import split from abc import ABC, abstractmethod from deepspeed.accelerator import get_accelerator -from ..utils import logger +from ..utils import logger, get_numactl_cmd from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE @@ -34,7 +35,10 @@ def get_cmd(self, environment, active_resources): """Return the command to execute on node""" def add_export(self, key, var): - self.exports[key.strip()] = var.strip() + var = var.strip() + if re.search(r'[^\w@%+=:,./-]', var): + var = f"\"{var}\"" + self.exports[key.strip()] = var def parse_user_args(self): return self.args.user_args @@ -56,15 +60,26 @@ def __init__(self, args, world_info_base64): def backend_exists(self): return shutil.which('pdsh') + def parse_user_args(self): + processed_args = [] + for arg in self.args.user_args: + # With pdsh, if we are passing a string as an argument, it will get + # split on whitespace. To avoid this and support strings that + # contain '"', we do this extra processing step: + if " " in arg: + arg = '"{}"'.format(arg.replace('"', '\\"')) + processed_args.append(arg) + return processed_args + @property def name(self): return "pdsh" - def parse_user_args(self): - return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args)) - def get_cmd(self, environment, active_resources): environment['PDSH_RCMD_TYPE'] = 'ssh' + if self.args.ssh_port is not None: # only specify ssh port if it is specified + environment["PDSH_SSH_ARGS_APPEND"] = f"{environment.get('PDSH_SSH_ARGS_APPEND', '')} \ + -p {self.args.ssh_port}" active_workers = ",".join(active_resources.keys()) logger.info("Running on the following workers: %s" % active_workers) @@ -85,6 +100,8 @@ def get_cmd(self, environment, active_resources): f'--world_info={self.world_info_base64}', "--node_rank=%n", f"--master_addr={self.args.master_addr}", f"--master_port={self.args.master_port}" ] + if self.args.venv_script is not None: + deepspeed_launch = [f"source {self.args.venv_script};"] + deepspeed_launch if self.args.no_python: deepspeed_launch.append("--no_python") if self.args.module: @@ -93,6 +110,8 @@ def get_cmd(self, environment, active_resources): deepspeed_launch.append("--no_local_rank") if self.args.save_pid: deepspeed_launch += ["--save_pid", f"{os.getpid()}"] + if self.args.enable_each_rank_log: + deepspeed_launch.append(f"--enable_each_rank_log={self.args.enable_each_rank_log}") if self.args.elastic_training: deepspeed_launch.append("--enable_elastic_training") deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}") @@ -101,7 +120,7 @@ def get_cmd(self, environment, active_resources): cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]] kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]] - return pdsh_cmd_args + deepspeed_launch + [self.user_script] + self.user_arguments, kill_command + return pdsh_cmd_args + deepspeed_launch + [self.user_script] + self.user_arguments, kill_command, environment class OpenMPIRunner(MultiNodeRunner): @@ -121,6 +140,7 @@ def name(self): def validate_args(self): super().validate_args() + #TODO: Allow for include/exclude at node-level but not gpu-level if self.args.include != "" or self.args.exclude != "": raise ValueError(f"{self.name} backend does not support worker include/exclusion") @@ -130,6 +150,17 @@ def validate_args(self): def get_cmd(self, environment, active_resources): total_process_count = sum(self.resource_pool.values()) + launcher_args = split(self.args.launcher_args) + + # If btl_tcp_if_include option is provided through launcher_args, we use it. Otherwise, we add + # `--mca btl_tcp_if_include eth0` option as a default value for compatibility. + btl_tcp_opt = ['--mca', 'btl_tcp_if_include', 'eth0'] + if len(launcher_args) >= 2: + for i in range(len(launcher_args) - 1): + if launcher_args[i] in ['-mca', '--mca'] and launcher_args[i + 1] == 'btl_tcp_if_include': + btl_tcp_opt = [] + break + mpirun_cmd = [ 'mpirun', '-n', @@ -139,10 +170,7 @@ def get_cmd(self, environment, active_resources): '--mca', 'btl', '^openib', - '--mca', - 'btl_tcp_if_include', - 'eth0', - ] + split(self.args.launcher_args) + ] + btl_tcp_opt + launcher_args export_cmd = [] for k, v in self.exports.items(): @@ -184,6 +212,8 @@ def get_cmd(self, environment, active_resources): devices_per_node = self.resource_pool.values() total_process_count = sum(devices_per_node) process_per_node = list(devices_per_node)[0] + if not all([n == process_per_node for n in devices_per_node]): + raise ValueError("MPICH requires same number of devices per node") mpirun_cmd = [ 'mpirun', @@ -195,14 +225,121 @@ def get_cmd(self, environment, active_resources): export_cmd = [] for k, v in self.exports.items(): - export_cmd += ['-x', "{}={}".format(k, v)] + export_cmd += ['-genv', "{}={}".format(k, v)] + + export_cmd += ['-genv', 'MASTER_ADDR', str(self.args.master_addr)] + export_cmd += ['-genv', 'MASTER_PORT', str(self.args.master_port)] + export_cmd += ['-genv', 'WORLD_SIZE', str(total_process_count)] + export_cmd += ['-genv', 'LOCAL_SIZE', str(process_per_node)] + + export_cmd += ['-hosts'] + hosts = "" + for i, host in enumerate(self.resource_pool.keys()): + if i == 0: + hosts = f"{host}" + else: + hosts += f",{host}" + export_cmd += [hosts] + helper_args = ["--launcher"] + [self.args.launcher] python_exec = [] if not self.args.no_python: - python_exec = [sys.executable, "-u"] + python_exec += [sys.executable, "-u"] if self.args.module: python_exec.append("-m") - return mpirun_cmd + python_exec + [self.user_script] + self.user_arguments + helper_args.append("--module") + else: + helper_args.append("--no_python") + + helper_cmd = str(os.path.dirname(os.path.realpath(__file__))) + '/launcher_helper.py' + helper_cmd = [helper_cmd] + helper_args + [self.user_script] + self.user_arguments + + return mpirun_cmd + export_cmd + python_exec + helper_cmd + + +class IMPIRunner(MultiNodeRunner): + + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + + def backend_exists(self): + #TODO: if IB is available we should suggestion mpich + return shutil.which('mpirun') #mpich_info + + @property + def name(self): + return "impi" + + def validate_args(self): + super().validate_args() + #TODO: Allow for include/exclude at node-level but not gpu-level + if self.args.include != "" or self.args.exclude != "": + raise ValueError(f"{self.name} backend does not support worker include/exclusion") + + if self.args.num_nodes != -1 or self.args.num_gpus != -1: + raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus") + + def get_cmd(self, environment, active_resources): + devices_per_node = self.resource_pool.values() + total_process_count = sum(devices_per_node) + process_per_node = list(devices_per_node)[0] + if not all([n == process_per_node for n in devices_per_node]): + raise ValueError("Intel MPI requires same number of devices per node") + + mpirun_cmd = [ + 'mpirun', + '-ppn', + f'{process_per_node}', + ] + split(self.args.launcher_args) + export_cmd = [] + + for k, v in self.exports.items(): + export_cmd += ['-genv', f'{k}', f'{v}'] + + if self.args.bind_cores_to_rank: + cores_per_rank, _ = get_numactl_cmd(self.args.bind_core_list, process_per_node, 0) + export_cmd += ['-genv', 'OMP_NUM_THREADS', str(cores_per_rank)] + + export_cmd += ['-genv', 'MASTER_ADDR', str(self.args.master_addr)] + export_cmd += ['-genv', 'MASTER_PORT', str(self.args.master_port)] + export_cmd += ['-genv', 'WORLD_SIZE', str(total_process_count)] + export_cmd += ['-genv', 'LOCAL_SIZE', str(process_per_node)] + + # turn off IMPI core binding, use deepspeed's own core binding + export_cmd += ['-genv', 'I_MPI_PIN', '0'] + + export_cmd += ['-hosts'] + hosts = "" + for i, host in enumerate(self.resource_pool.keys()): + if i == 0: + hosts = f"{host}" + else: + hosts += f",{host}" + export_cmd += [hosts] + + per_host_cmd = [] + + for i in range(total_process_count): + local_rank = i % process_per_node + python_exec = [] + if self.args.bind_cores_to_rank: + _, numactl_cmd = get_numactl_cmd(self.args.bind_core_list, process_per_node, local_rank) + python_exec += numactl_cmd + + if not self.args.no_python: + python_exec += [sys.executable, "-u"] + if self.args.module: + python_exec.append("-m") + env_mapping = ['-env', 'RANK', str(i)] + env_mapping += ['-env', 'LOCAL_RANK', str(local_rank)] + if i == 0: + per_host_cmd = ['-n', '1'] + env_mapping + python_exec + [self.user_script] + self.user_arguments + else: + per_host_cmd = per_host_cmd + [':', '-n', '1'] + env_mapping + python_exec + [self.user_script + ] + self.user_arguments + print(mpirun_cmd + export_cmd + per_host_cmd) + return mpirun_cmd + export_cmd + per_host_cmd class SlurmRunner(MultiNodeRunner): @@ -286,7 +423,7 @@ def backend_exists(self): if not mpiname_exists: warnings.warn("mpiname does not exist, mvapich is not installed properly") else: - results = subprocess.check_output('mpiname', shell=True) + results = subprocess.check_output(['mpiname']) mpiname_results = results.decode('utf-8').strip() if "MVAPICH2-GDR" in mpiname_results: exists = True diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index cc34af81b2fe..1a1271c1fda2 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -20,27 +20,35 @@ from copy import deepcopy import signal import time +from typing import Tuple, List, Dict +from collections import defaultdict +import shlex -from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner -from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER +from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner +from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..nebula.constants import NEBULA_EXPORT_ENVS -from ..utils import logger +from ..utils import logger, set_log_level_from_string from ..autotuning import Autotuner from deepspeed.accelerator import get_accelerator DLTS_HOSTFILE = "/job/hostfile" -EXPORT_ENVS = ['MLFLOW', 'NCCL', 'PYTHON', 'MV2', 'UCX'] +EXPORT_ENVS = ['MLFLOW', 'PYTHON', 'MV2', 'UCX'] EXPORT_ENVS += NEBULA_EXPORT_ENVS -DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env" +DEEPSPEED_ENVIRONMENT_NAME = os.getenv("DS_ENV_FILE", ".deepspeed_env") DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.'] PDSH_MAX_FAN_OUT = 1024 +# On AISC compute, each node sets environment variables independently, want to prevent +# exporting rank-0 env variables in case of heterogeneous compute. +EXCLUDE_ENVS = {'AISC_JOB_NAME': ['NCCL_IB_HCA', 'UCX_NET_DEVICES']} + def parse_args(args=None): parser = argparse.ArgumentParser(description="DeepSpeed runner to help launch distributed " - "multi-node/multi-gpu training jobs.") + "multi-node/multi-gpu training jobs.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("-H", "--hostfile", @@ -94,6 +102,7 @@ def parse_args(args=None): "Default is num_nodes when elastic training is enabled") parser.add_argument("--num_gpus", + "--num_accelerators", type=int, default=-1, help="Max number of GPUs to use on each node, will use " @@ -111,11 +120,17 @@ def parse_args(args=None): help="(optional) IP address of node 0, will be " "inferred via 'hostname -I' if not specified.") + parser.add_argument("--node_rank", + default=-1, + type=int, + help="ID of each node in the range [0:N). " + "Only required when --no_ssh is set.") + parser.add_argument("--launcher", default=PDSH_LAUNCHER, type=str, help="(optional) choose launcher backend for multi-node " - "training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.") + "training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH, IMPI.") parser.add_argument("--launcher_args", default="", @@ -139,6 +154,10 @@ def parse_args(args=None): help="Do not pass local_rank as an argument when calling " "the user's training script.") + parser.add_argument("--no_ssh", + action="store_true", + help="Launch training independently on each node without ssh setup.") + parser.add_argument("--no_ssh_check", action="store_true", help="Do not perform ssh check in multi-node launcher model") @@ -172,16 +191,39 @@ def parse_args(args=None): parser.add_argument("user_script", type=str, help="User script to launch, followed by any required " "arguments.") + parser.add_argument('user_args', nargs=argparse.REMAINDER) + parser.add_argument("--bind_cores_to_rank", action="store_true", help="Bind each rank to different cores of the host") + parser.add_argument("--bind_core_list", type=str, default=None, help="List of cores to bind to with comma separated list of " "numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not " "specified, all cores on system would be used rank binding") + + parser.add_argument("--ssh_port", type=int, default=None, help="SSH port to use for remote connections") + + parser.add_argument("--venv_script", + type=str, + default=None, + help="Python virtual environment activation script for job.") + + # TODOV1: change the default to 'warning' + parser.add_argument("--log_level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical'], + help="Set runner loglevel. The default is 'info'") + + parser.add_argument("-q", + "--quiet", + action="store_true", + help="Try to be as quiet as possible. Aliases to `--log_level error`") + return parser.parse_args(args=args) @@ -221,7 +263,7 @@ def _parse_hostfile(hostfile_lines): resource_pool[host] = num_slots else: logger.error(f"Bad hostfile text: {hostfile_lines}") - raise ValueError("Hostfile contains a bad entry: {line}, unable to proceed with launching") + raise ValueError(f"Hostfile contains a bad entry: {line}, unable to proceed with launching") if len(resource_pool) == 0: logger.error(f"Bad hostfile text: {hostfile_lines}") @@ -240,6 +282,31 @@ def _stable_remove_duplicates(data): return new_list +def parse_node_config(node_config: str) -> Tuple[str, List[int]]: + SLOT_LIST_START = ':' + SLOT_SEP = ',' + + if SLOT_LIST_START not in node_config: + return node_config, [] + + hostname, slots = node_config.split(SLOT_LIST_START) + slots = [int(x) for x in slots.split(SLOT_SEP)] + + return hostname, slots + + +def parse_node_config_list(node_config_list: List[str]) -> Dict[str, List[int]]: + NODE_SEP = '@' + + node_configs = defaultdict(list) + + for node_config in node_config_list.split(NODE_SEP): + hostname, slots = parse_node_config(node_config) + node_configs[hostname] += slots + + return {k: sorted(list(set(v))) for k, v in node_configs.items()} + + def parse_resource_filter(host_info, include_str="", exclude_str=""): '''Parse an inclusion or exclusion string and filter a hostfile dictionary. @@ -254,11 +321,6 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): slot 0 on worker-1. ''' - # Constants that define our syntax - NODE_SEP = '@' - SLOT_LIST_START = ':' - SLOT_SEP = ',' - # Ensure include/exclude are mutually exclusive if (include_str != "") and (exclude_str != ""): raise ValueError('include_str and exclude_str are mutually exclusive.') @@ -276,12 +338,9 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): parse_str = exclude_str # foreach node in the list - for node_config in parse_str.split(NODE_SEP): + for hostname, slots in parse_node_config_list(parse_str).items(): # Node can either be alone or node:slot,slot,slot - if SLOT_LIST_START in node_config: - hostname, slots = node_config.split(SLOT_LIST_START) - slots = [int(x) for x in slots.split(SLOT_SEP)] - + if len(slots) > 0: # sanity checks if hostname not in host_info: raise ValueError(f"Hostname '{hostname}' not found in hostfile") @@ -299,7 +358,6 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): # User just specified the whole node else: - hostname = node_config # sanity check hostname if hostname not in host_info: raise ValueError(f"Hostname '{hostname}' not found in hostfile") @@ -332,8 +390,10 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): def parse_inclusion_exclusion(resource_pool, inclusion, exclusion): active_resources = collections.OrderedDict() + node_configs = parse_node_config_list(inclusion) + for hostname, slots in resource_pool.items(): - active_resources[hostname] = list(range(slots)) + active_resources[hostname] = node_configs[hostname] if hostname in node_configs else list(range(slots)) return parse_resource_filter(active_resources, include_str=inclusion, exclude_str=exclusion) @@ -376,23 +436,28 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool): def main(args=None): args = parse_args(args) + if args.quiet: + args.log_level = "error" + set_log_level_from_string(args.log_level) + if args.elastic_training: assert args.master_addr != "", "Master Addr is required when elastic training is enabled" resource_pool = fetch_hostfile(args.hostfile) - # respect CUDA_VISIBLE_DEVICES for a single node and no explicit resource filters - cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if not resource_pool and len(cuda_visible_devices): - detected_str = f"Detected CUDA_VISIBLE_DEVICES={cuda_visible_devices}" + # respect VISIBLE_DEVICES for a single node and no explicit resource filters + visible_devices_env = get_accelerator().visible_devices_envs()[0] + visible_devices = os.environ.get(visible_devices_env, "") + if not resource_pool and len(visible_devices): + detected_str = f"Detected VISIBLE_DEVICES={visible_devices}" if len(args.include) or len(args.exclude) or args.num_nodes > 1 or args.num_gpus > 0: print( f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed." ) else: - args.include = f"localhost:{cuda_visible_devices}" + args.include = f"localhost:{visible_devices}" print(f"{detected_str}: setting --include={args.include}") - del os.environ["CUDA_VISIBLE_DEVICES"] + del os.environ[visible_devices_env] if args.num_nodes >= 0 or args.num_gpus >= 0: if args.include != "" or args.exclude != "": @@ -415,13 +480,13 @@ def main(args=None): env = os.environ.copy() # validate that passwordless-ssh is workly properly with this hostfile - if multi_node_exec and not args.no_ssh_check: + if multi_node_exec and not args.no_ssh_check and not args.no_ssh: first_host = list(active_resources.keys())[0] try: - subprocess.check_call(f'ssh -o PasswordAuthentication=no {first_host} hostname', - stderr=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - shell=True) + ssh_check_cmd = ("ssh -o PasswordAuthentication=no " + + (f"-p {args.ssh_port} " if args.ssh_port is not None else "") + f"{first_host} hostname") + safe_ssh_cmd = shlex.split(ssh_check_cmd) + subprocess.check_call(safe_ssh_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) except subprocess.CalledProcessError: raise RuntimeError( f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh." @@ -430,18 +495,22 @@ def main(args=None): if not args.master_addr: assert multi_node_exec first_host = list(active_resources.keys())[0] - hostname_cmd = [f"ssh {first_host} hostname -I"] + ssh_check_cmd = "ssh " + if args.ssh_port is not None: + ssh_check_cmd += f" -p {args.ssh_port}" + ssh_check_cmd += f" {first_host} hostname -I" + hostname_cmd = shlex.split(ssh_check_cmd) try: - result = subprocess.check_output(hostname_cmd, shell=True) + result = subprocess.check_output(hostname_cmd) except subprocess.CalledProcessError as err: logger.error( - "Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr" + "Unable to detect suitable master address via 'hostname -I', please manually specify one via --master_addr" ) raise err args.master_addr = result.decode('utf-8').split()[0] if not args.master_addr: raise RuntimeError( - f"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr" + "Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr" ) logger.info(f"Using IP address of {args.master_addr} for node {first_host}") @@ -466,16 +535,22 @@ def main(args=None): if args.elastic_training: assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training" + if args.no_ssh: + assert (0 <= args.node_rank < + len(active_resources)), "Launching training without ssh, but --node_rank is not set correctly." + # encode world info as base64 to make it easier to pass via command line world_info_base64 = encode_world_info(active_resources) - multi_node_exec = args.force_multi or len(active_resources) > 1 + multi_node_exec = (args.force_multi or len(active_resources) > 1) and not args.no_ssh if not multi_node_exec: deepspeed_launch = [ sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}", f"--master_addr={args.master_addr}", f"--master_port={args.master_port}" ] + if args.no_ssh: + deepspeed_launch.append(f"--node_rank={args.node_rank}") if args.no_python: deepspeed_launch.append("--no_python") if args.module: @@ -492,8 +567,12 @@ def main(args=None): deepspeed_launch.append(f"--min_elastic_nodes={args.min_elastic_nodes}") if args.bind_cores_to_rank: deepspeed_launch.append("--bind_cores_to_rank") - if args.bind_core_list != None: + if args.bind_core_list is not None: deepspeed_launch.append(f"--bind_core_list={args.bind_core_list}") + if args.quiet: + deepspeed_launch.append("--quiet") + deepspeed_launch.append(f"--log_level={args.log_level}") + cmd = deepspeed_launch + [args.user_script] + args.user_args else: args.launcher = args.launcher.lower() @@ -503,6 +582,8 @@ def main(args=None): runner = OpenMPIRunner(args, world_info_base64, resource_pool) elif args.launcher == MPICH_LAUNCHER: runner = MPICHRunner(args, world_info_base64, resource_pool) + elif args.launcher == IMPI_LAUNCHER: + runner = IMPIRunner(args, world_info_base64, resource_pool) elif args.launcher == MVAPICH_LAUNCHER: runner = MVAPICHRunner(args, world_info_base64, resource_pool) elif args.launcher == SLURM_LAUNCHER: @@ -519,21 +600,30 @@ def main(args=None): else: env['PYTHONPATH'] = curr_path - exports = "" + excluded_vars = [] + for exclude_key, var_list in EXCLUDE_ENVS.items(): + if exclude_key in env.keys(): + # key exists in launcher env -> var list should be used + excluded_vars += var_list + + # load envs from accelerator + exports = EXPORT_ENVS + get_accelerator().export_envs() for var in env.keys(): - if any([var.startswith(name) for name in EXPORT_ENVS]): - runner.add_export(var, env[var]) + if any([var.startswith(name) for name in exports]): + if not any([var == name for name in excluded_vars]): + runner.add_export(var, env[var]) for environ_path in DEEPSPEED_ENVIRONMENT_PATHS: environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME) if os.path.isfile(environ_file): + logger.info(f"deepspeed_env file = {environ_file}") with open(environ_file, 'r') as fd: for var in fd.readlines(): key, val = var.split('=', maxsplit=1) runner.add_export(key, val) if args.launcher == PDSH_LAUNCHER: - cmd, kill_cmd = runner.get_cmd(env, active_resources) + cmd, kill_cmd, env = runner.get_cmd(env, active_resources) else: cmd = runner.get_cmd(env, active_resources) @@ -549,8 +639,9 @@ def sigkill_handler(signum, frame): time.sleep(1) sys.exit(1) - if args.launcher == PDSH_LAUNCHER: + if args.launcher == PDSH_LAUNCHER and multi_node_exec: signal.signal(signal.SIGINT, sigkill_handler) + signal.signal(signal.SIGTERM, sigkill_handler) result.wait() diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py new file mode 100644 index 000000000000..9931a95a0a40 --- /dev/null +++ b/deepspeed/linear/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .optimized_linear import OptimizedLinear +from .config import LoRAConfig, QuantizationConfig +from .context_manager import Init, init_lora diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py new file mode 100644 index 000000000000..1459704a32c5 --- /dev/null +++ b/deepspeed/linear/config.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass, field +from typing import List + +import torch + + +@dataclass +class LoRAConfig: + """ + Configuration settings for LoRAOptimizedLinear. + + Attributes: + lora_r (int): LoRA attention dimension, also known as the rank. Defaults is 64. + lora_alpha (float): LoRA scaling factor, default is 16. + base_weight_sharding (int): The degree to which the base weights are sharded, + should typically be set to the data-parallel world size to maximize the memory + reduction benefits. Defaults to 1, which means this feature is disabled. + offload (bool): offload frozen parameters to cpu when not in use + offload_ratio (float): ratio of parameters to offload to cpu when not in use + delay_lora_init (bool): initialize lora parameters at time of model init or allow manual init later + target_mods (str): target module names to apply LoRA to, defaults to llama-3.1 arch + """ + lora_r: int = 64 + lora_alpha: float = 16. + base_weight_sharding: int = 1 + offload: bool = False + offload_ratio: float = 0.0 + delay_lora_init: bool = False + target_mods: List[str] = field( + default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']) + + +@dataclass +class QuantizationConfig: + """ + Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear, + and QuantizedParameter + + Attributes: + q_bits (int): The number of bits used for quantization. Default is 8. + mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3. + group_size (int): The number of elements used for quantization. Default is 512. + q_dtype (torch.dtype): The data type to quantize to. Default is uint8. (in CUDA, buffers are allocated as + uint8, but inside the kernels the quantization is done to fp8) + """ + q_bits: int = 8 + mantissa_bits: int = 3 + group_size: int = 512 + q_dtype: torch.dtype = torch.uint8 diff --git a/deepspeed/linear/context_manager.py b/deepspeed/linear/context_manager.py new file mode 100644 index 000000000000..204fa0fe9c1d --- /dev/null +++ b/deepspeed/linear/context_manager.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .optimized_linear import LoRAOptimizedLinear, OptimizedLinear + +import torch + +try: + import transformers +except ImportError: + transformers = None + + +def init_lora(model): + model.requires_grad_(False) + for m in model.modules(): + if isinstance(m, LoRAOptimizedLinear): + m.init_lora() + + +class Init(object): + """ + Init context wrapper similar in style to zero.Init. Allows for injecting OptimizedLinear during model + construction which will shard base weights and reduce overall memory usage during model init. Primarily + useful when initializing a model via transformers.AutoModelForCausalLM. + + Example usage: + lora_config = deepspeed.linear.LoRAConfig(..) + quant_config = deepspeed.linear.QuantizationConfig(..) + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B") + + """ + + def __init__(self, lora_config=None, quant_config=None): + self._orig_nn_linear = torch.nn.Linear + self._orig_causallm_pretrained = None + if transformers != None: + self._orig_causallm_pretrained = transformers.AutoModelForCausalLM.from_pretrained + self._orig_causallm_config = transformers.AutoModelForCausalLM.from_config + self.lora_config = lora_config + self.quant_config = quant_config + self._post_init_complete = False + + def __enter__(self): + + class OptLinearWrapper: + _orig_nn_linear = self._orig_nn_linear + _lora_config = self.lora_config + _quant_config = self.quant_config + + def __new__(self, *args, **kwargs): + self._lora_config.delay_lora_init = True + kwargs['lora_config'] = self._lora_config + kwargs['quantization_config'] = self._quant_config + kwargs['linear_cls'] = self._orig_nn_linear + return OptimizedLinear(*args, **kwargs) + + def _model_init(model): + if self.lora_config != None: + init_lora(model) + self._post_init_complete = True + return model + + # ensures non-lora params are frozen and lora weights are initialized + def from_pretrained(*args, **kwargs): + model = self._orig_causallm_pretrained(*args, **kwargs) + return _model_init(model) + + def from_config(*args, **kwargs): + model = self._orig_causallm_config(*args, **kwargs) + return _model_init(model) + + torch.nn.Linear = OptLinearWrapper + if transformers != None: + transformers.AutoModelForCausalLM.from_pretrained = from_pretrained + transformers.AutoModelForCausalLM.from_config = from_config + + def __exit__(self, *args, **kwargs): + torch.nn.Linear = self._orig_nn_linear + if not self._post_init_complete: + print('WARNING: For some reason LoRA modules are not initialized, this is usually done automatically ' + 'if using transformers via (AutoModelForCausalLM from_pretrained/from_config). ' + 'You must call `init_lora` on each module in order to use DeepSpeed LoRA, otherwise ' + 'you will error out during runtime.') + else: + transformers.AutoModelForCausalLM.from_pretrained = self._orig_causallm_pretrained + transformers.AutoModelForCausalLM.from_config = self._orig_causallm_config diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py new file mode 100644 index 000000000000..3720196aa255 --- /dev/null +++ b/deepspeed/linear/optimized_linear.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import is_dataclass +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +from .config import LoRAConfig, QuantizationConfig +from .quantization import QuantizedParameter, QuantizedLinear + + +class OptimizedLinear(nn.Module): + """ + Optimized version of nn.Linear that adds features such as: + * LoRA w. base weight sharding + * FP [6,8,12] quantization + + Arguments: + input_dim: Required: size of each input sample + output_dim: Required: size of each output sample + bias: Optional: If set to False, the layer will not learn an additive bias. Default: False + lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree + quantization_config: Optional: QuantizationConfig defining quantization features + dtype: Optional: parameter dtype, only supports bfloat16 currently + + Returns: + Returns a new nn.Module depending on the input config. Either native + torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear. + """ + + def __new__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + device=None, + dtype=torch.bfloat16, + linear_cls=nn.Linear): + + if quantization_config is not None and not is_dataclass(quantization_config): + raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") + if lora_config is not None and not is_dataclass(lora_config): + raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") + if lora_config is None and quantization_config is None: + # Everything disabled, fall back to normal nn.Linear + self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device) + + elif lora_config: + # lora enabled, quantization may or may not be + self = LoRAOptimizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=dtype, + device=device, + linear_cls=linear_cls) + + elif quantization_config: + # only quantization enabled, no lora + self = QuantizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + quantization_config=quantization_config, + dtype=dtype) + return self + + +class LoRAOptimizedLinear(nn.Module): + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + device=None, + dtype=torch.bfloat16, + linear_cls=nn.Linear): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.bias = bias + self.lora_config = lora_config + self.quantization_config = quantization_config + self.device = get_accelerator().current_device_name() if device is None else device + self.linear_cls = linear_cls + self.dtype = dtype + assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" + assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear" + self.zero_shards = self.lora_config.base_weight_sharding + self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) + if self.zero_shards > 1: + assert self.zero_shards == dist.get_world_size( + ), "base weight sharding is only supported across world size" + w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype), + requires_grad=False) + else: + w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False) + torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim)) + + if self.quantization_config is not None: + assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" + self.weight = QuantizedParameter(w, quantization_config=quantization_config) + else: + self.weight = w + + self.disabled = False + self._initialized = False + if not self.lora_config.delay_lora_init: + self.init_lora() + + def disable(self): + self.disabled = True + self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype), + requires_grad=False) + + def init_lora(self): + if self.disabled: + return + + if self.quantization_config is not None: + # ensure quant-param wasn't stripped, in some cases transformers will do this during model init + if not isinstance(self.weight, QuantizedParameter): + self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config) + + self._initialized = True + self.weight.requires_grad = False + + # Mark base weight to prevent broadcast and ensure proper offload behavior + self.weight.ds_optim_param = True + + self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r + + # Keeping lora weights in bf16 precision for ease of training. + self.lora_weight_1 = self.linear_cls(self.input_dim, + self.lora_config.lora_r, + bias=self.bias, + device=self.device, + dtype=self.dtype) + self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r, + self.output_dim, + bias=self.bias, + device=self.device, + dtype=self.dtype) + + # initialize "A" with kaiming uniform and "B" with zeros following this + # https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155 + nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_weight_2.weight) + self.lora_weight_1.weight.requires_grad = True + self.lora_weight_2.weight.requires_grad = True + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + if not any([target in prefix for target in self.lora_config.target_mods]): + # module does not match any target_mods, we must revert to normal nn.Linear via disable + self.disable() + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs) + + if self.zero_shards > 1: + if not dist.is_initialized(): + raise RuntimeError( + "attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first." + ) + rank = dist.get_rank() + shape_local = self.output_dim * self.sharded_weight_size + base_weight_name = f"{prefix}weight" + incoming_param = state_dict[base_weight_name] + state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local) + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + def full_weight(self): + base_weight = self.weight + if getattr(base_weight, 'ds_offload', False): + # move to gpu so we can dequant and all-gather + assert base_weight.device == torch.device('cpu'), \ + f"expected base weight on cpu but found {base_weight.device}" + base_weight.offload(revert=True) + local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight + base_weight.offload() + else: + local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight + + tensor_out = torch.empty(self.output_dim * self.input_dim, + dtype=local_weight.dtype, + device=local_weight.device) + dist.all_gather_into_tensor(tensor_out, local_weight) + return tensor_out.reshape(self.output_dim, self.input_dim) + + def linear_without_F_linear(self, input, weight): + output = torch.mm(input.reshape(-1, input.shape[-1]), weight) + output = output.view(*input.shape[:-1], weight.shape[1]) + return output + + def forward(self, input_tensor): + if self.disabled: + return F.linear(input_tensor, self.weight) + assert self._initialized, "init_lora was never called, please initialize before proceeding" + + # Gather the sharded base weight + if self.zero_shards > 1: + with torch.no_grad(): + base_weight = self.full_weight() + elif self.quantization_config: + base_weight = self.weight.dequantized() + else: + base_weight = self.weight + + base_weight_output = F.linear(input_tensor, base_weight) + lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) + return base_weight_output + self.lora_scaling_factor * lora_output diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py new file mode 100644 index 000000000000..beabd4f930e4 --- /dev/null +++ b/deepspeed/linear/quantization.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize +from .config import QuantizationConfig + + +class QuantizedParameter(nn.Parameter): + """ + Quantized parameter class that implements weight quantization. Weights + are stored in quantized form on GPUs, and can be dequantized on-the-fly when + needed by the model. The weights are actually quantized during any `.to(device)`. + + Arguments: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. Defaults + to False and is not supported to be True. Argument provided only for interface + compatibility with torch.nn.Parameter. + quantization_config (QuantizationConfig, optional): + quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer + that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also + required since the quantizer is stashed in the Parameter itself, some models + may clone the Parameter by passing an attribute __dict__. For an example, see + tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone + """ + + def __new__( + cls, + data: Optional[torch.Tensor] = None, + requires_grad: bool = False, # quantized weights must be frozen + quantization_config: QuantizationConfig = None, + quantizer: Quantizer = None, + ): + if requires_grad: + raise ValueError("requires_grad=True is not supported with QuantizedParameter") + if data is None: + data = torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config + if quantizer is not None: + self.quantizer = quantizer + else: + # if FPQuantizerBuilder is not compatible in this env this init will fail + self.quantizer = FP_Quantize(quantization_config=self.quantization_config) + self._ensure_quantized(self) + return self + + def _ensure_quantized(self, tensor: torch.Tensor): + # If the tensor is on the accelerator and is not quantized, then quantize it in-place. + if get_accelerator().on_accelerator(tensor) and tensor.dtype != self.quantization_config.q_dtype: + with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): + tensor.data = self.quantizer.quantize(tensor.data, + q_bits=self.quantization_config.q_bits, + q_mantisa_bits=self.quantization_config.mantissa_bits) + assert tensor.dtype == self.quantization_config.q_dtype + + def dequantized(self) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + if get_accelerator().on_accelerator(self.data) and self.data.dtype == self.quantization_config.q_dtype: + with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): + return self.quantizer.dequantize(self.data, + q_bits=self.quantization_config.q_bits, + q_mantisa_bits=self.quantization_config.mantissa_bits) + return self.data + + def offload(self, revert=False): + if getattr(self, 'ds_offload', False): + if revert: + self.data = self.to(get_accelerator().current_device_name()) + else: + self.data = self.to('cpu') + + def __getstate__(self): + state = self.__dict__ + state["data"] = self.data + state["quantization_config"] = self.quantization_config + state["requires_grad"] = self.requires_grad + return state + + def __setstate__(self, state): + self.quantizer = state["quantizer"] + self.quantization_config = state["quantization_config"] + self.data = state["data"] + self.requires_grad = state["requires_grad"] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + new_instance.quantizer = copy.deepcopy(state["quantizer"]) + new_instance.quantization_config = copy.deepcopy(state["quantization_config"]) + new_instance.data = copy.deepcopy(state["data"]) + return new_instance + + def __copy__(self): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def cuda(self, device=None, non_blocking=False): + device = "cuda" if device is None else device + self.quantizer.to(device, non_blocking=non_blocking) + return self.to(device, non_blocking=non_blocking) + + def to(self, *args, **kwargs): + """ + Move the parameter to the given device. Then, if the device is a cuda device, + quantize it. + """ + tensor = super().to(*args, **kwargs) + self.quantizer.to(*args, **kwargs) + self._ensure_quantized(tensor) + return tensor + + +class QuantizedLinear(nn.Linear): + """ + Linear layer that implements weight quantization. Parameters + are stored via `QuantizedParameter` and are dequantized on-the-fly during any + forward pass. + """ + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + super().__init__(input_dim, output_dim, bias=bias, dtype=dtype) + assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype" + self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.dequantized(), self.bias) diff --git a/deepspeed/model_implementations/diffusers/unet.py b/deepspeed/model_implementations/diffusers/unet.py index 6086d9fb9862..8d5ddd95437a 100644 --- a/deepspeed/model_implementations/diffusers/unet.py +++ b/deepspeed/model_implementations/diffusers/unet.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -29,7 +30,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def forward(self, *inputs, **kwargs): @@ -53,16 +54,23 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True - def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None): + def _forward(self, + sample, + timestamp, + encoder_hidden_states, + return_dict=True, + cross_attention_kwargs=None, + timestep_cond=None, + added_cond_kwargs=None): if cross_attention_kwargs: return self.unet(sample, timestamp, diff --git a/deepspeed/model_implementations/diffusers/vae.py b/deepspeed/model_implementations/diffusers/vae.py index 445a9843921a..ce50ade647a8 100644 --- a/deepspeed/model_implementations/diffusers/vae.py +++ b/deepspeed/model_implementations/diffusers/vae.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -27,10 +28,10 @@ def _graph_replay_decoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_decoder_kwargs[k].copy_(kwargs[k]) - self._decoder_cuda_graph.replay() + get_accelerator().replay_graph(self._decoder_cuda_graph) return self.static_decoder_output - def _decode(self, x, return_dict=True): + def _decode(self, x, return_dict=True, generator=None): return self.vae.decode(x, return_dict=return_dict) def _create_cuda_graph_decoder(self, *inputs, **kwargs): @@ -43,11 +44,11 @@ def _create_cuda_graph_decoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._decoder_cuda_graph = torch.cuda.CUDAGraph() + self._decoder_cuda_graph = get_accelerator().create_graph() self.static_decoder_inputs = inputs self.static_decoder_kwargs = kwargs - with torch.cuda.graph(self._decoder_cuda_graph): + with get_accelerator().capture_to_graph(self._decoder_cuda_graph): self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs) self.decoder_cuda_graph_created = True @@ -70,7 +71,7 @@ def _graph_replay_encoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_encoder_kwargs[k].copy_(kwargs[k]) - self._encoder_cuda_graph.replay() + get_accelerator().replay_graph(self._encoder_cuda_graph) return self.static_encoder_output def _encode(self, x, return_dict=True): @@ -86,11 +87,11 @@ def _create_cuda_graph_encoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._encoder_cuda_graph = torch.cuda.CUDAGraph() + self._encoder_cuda_graph = get_accelerator().create_graph() self.static_encoder_inputs = inputs self.static_encoder_kwargs = kwargs - with torch.cuda.graph(self._encoder_cuda_graph): + with get_accelerator().capture_to_graph(self._encoder_cuda_graph): self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs) self.encoder_cuda_graph_created = True @@ -113,7 +114,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._all_cuda_graph.replay() + get_accelerator().replay_graph(self._all_cuda_graph) return self.static_output def forward(self, *inputs, **kwargs): @@ -137,11 +138,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._all_cuda_graph = torch.cuda.CUDAGraph() + self._all_cuda_graph = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._all_cuda_graph): + with get_accelerator().capture_to_graph(self._all_cuda_graph): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.all_cuda_graph_created = True diff --git a/deepspeed/model_implementations/transformers/clip_encoder.py b/deepspeed/model_implementations/transformers/clip_encoder.py index 8d9291896986..848a5b48dcf1 100644 --- a/deepspeed/model_implementations/transformers/clip_encoder.py +++ b/deepspeed/model_implementations/transformers/clip_encoder.py @@ -38,7 +38,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[self.iter][k].copy_(kwargs[k]) - self._cuda_graphs[self.iter].replay() + get_accelerator().replay_graph(self._cuda_graphs[self.iter]) return self.static_output[self.iter] def forward(self, *inputs, **kwargs): @@ -63,11 +63,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph() + self._cuda_graphs[self.iter] = get_accelerator().create_graph() self.static_inputs[self.iter] = inputs self.static_kwargs[self.iter] = kwargs - with torch.cuda.graph(self._cuda_graphs[self.iter]): + with get_accelerator().capture_to_graph(self._cuda_graphs[self.iter]): self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter], **self.static_kwargs[self.iter]) diff --git a/deepspeed/model_implementations/transformers/ds_llama2.py b/deepspeed/model_implementations/transformers/ds_llama2.py new file mode 100644 index 000000000000..325bfb4f7e18 --- /dev/null +++ b/deepspeed/model_implementations/transformers/ds_llama2.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference + + +class DeepSpeedLlama2Inference(DeepSpeedTransformerInference): + """Initialize the DeepSpeed OPT Transformer Layer. + """ + + def __init__(self, + config, + mp_group=None, + quantize_scales=None, + quantize_groups=1, + merge_count=1, + mlp_extra_grouping=False): + super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping) + + def forward(self, *args, **kwargs): + + input = args[0] + input_mask = None + get_present = True + + self.allocate_workspace(input.size()) + + # We set the prev key/value to None when there is a prompt + if input.shape[1] > 1: + self.layer_past = None + layer_past = self.layer_past + + input_type = input.dtype + + if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \ + and input.dtype == torch.float: + target_dtype = torch.half if self.dtype == torch.int8 else self.dtype + input = input.to(target_dtype) + + with torch.no_grad(): + attention_output, key, value, context_outputtn_ctx, inp_norm = \ + self.attention(input, + input_mask, + None, + layer_past, + get_present, + None, None, None, + self.norm_w, + self.norm_b, + None) + self.layer_past = (key, value) + output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob) + + output = output.to(input_type) + return output diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index 6ef838cea741..db6359701be8 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -6,14 +6,17 @@ import torch import torch.nn as nn from deepspeed import comm as dist +from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp from deepspeed.utils.logging import log_dist from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention +from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import InferenceBuilder - -inference_cuda_module = None +import deepspeed +if deepspeed.HAS_TRITON and get_accelerator().is_triton_supported(): + from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP + from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention class DeepSpeedTransformerInference(nn.Module): @@ -33,6 +36,7 @@ class DeepSpeedTransformerInference(nn.Module): for specific downstream tasks. """ layer_id = 0 + workspace = None def __init__(self, config, @@ -47,22 +51,28 @@ def __init__(self, self.config.layer_id = DeepSpeedTransformerInference.layer_id DeepSpeedTransformerInference.layer_id += 1 - data_type = torch.half if config.fp16 else torch.float - global inference_cuda_module - if inference_cuda_module is None: - builder = InferenceBuilder() - inference_cuda_module = builder.load() + data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype if DeepSpeedTransformerInference.layer_id == 1: log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0]) + if deepspeed.HAS_TRITON and self.config.use_triton: + log_dist("Injecting Triton kernels ...", [0]) if self.config.bigscience_bloom: self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count) + assert not self.config.use_triton + else: + if deepspeed.HAS_TRITON and self.config.use_triton: + self.attention = TritonSelfAttention(self.config) + else: + self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, + merge_count) + + if deepspeed.HAS_TRITON and self.config.use_triton: + self.mlp = TritonMLP(self.config) else: - self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, - merge_count) - self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count, - mlp_extra_grouping) + self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count, + mlp_extra_grouping) device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu' if self.config.set_empty_params: @@ -74,14 +84,25 @@ def __init__(self, self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), requires_grad=False) self.layer_past = None - self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \ - inference_cuda_module.allocate_workspace_fp16 - self._alloc_workspace = True + self.layer_norm = LayerNormOp() + if DeepSpeedTransformerInference.workspace is None: + DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config) + self._should_allocate_workspace = True + + def allocate_workspace(self, size): + # Allocate memory only on first layer forward + if self.config.layer_id == 0 and self._should_allocate_workspace: + DeepSpeedTransformerInference.workspace.allocate_workspace( + self.config.hidden_size, self.config.heads, size[1], size[0], DeepSpeedTransformerInference.layer_id, + self.config.mp_size, self.config.bigscience_bloom, + dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens, + self.config.min_out_tokens) + self._should_allocate_workspace = False @classmethod def reset_cache(cls): - if inference_cuda_module is not None: - inference_cuda_module.reset_cache() + if cls.workspace is not None: + cls.workspace.reset_cache() def forward( self, @@ -112,17 +133,12 @@ def forward( if "hidden_states" in kwargs: input = kwargs["hidden_states"] + if layer_past is not None and past_key_value is not None: + raise ValueError("Only one of `layer_past` or `past_key_value` can be present.") + input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask - # Allocate memory only on first layer forward - if self.config.layer_id == 0 and self._alloc_workspace: - self.allocate_workspace(self.config.hidden_size, self.config.heads, - input.size()[1], - input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size, - self.config.bigscience_bloom, - dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens, - self.config.min_out_tokens) - self._alloc_workspace = False + self.allocate_workspace(input.size()) get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask @@ -130,7 +146,7 @@ def forward( # We set the prev key/value to None when there is a prompt if input.shape[1] > 1: self.layer_past = None - layer_past = layer_past if layer_past is not None else self.layer_past + _layer_past = layer_past or past_key_value or self.layer_past head_mask = layer_head_mask if layer_head_mask is not None else head_mask attn_mask = None @@ -139,29 +155,32 @@ def forward( input = input[0] input_type = input.dtype - if (self.config.fp16 or self.config.q_int8) \ + if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \ and input.dtype == torch.float: - input = input.half() + target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype + input = input.to(target_dtype) + with torch.no_grad(): attention_output, key, value, context_outputtn_ctx, inp_norm = \ self.attention(input, input_mask, head_mask, - layer_past, + _layer_past, get_present, encoder_hidden_states, encoder_attention_mask, output_attentions, self.norm_w, self.norm_b, - alibi) + alibi, + **kwargs) presents = (key, value) - self.layer_past = presents if layer_past is None else None + self.layer_past = presents if layer_past is None and past_key_value is None else None output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob) if not self.config.pre_layer_norm: - output = inference_cuda_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon) + output = self.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon) output = output.to(input_type) if get_present: diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py index 4bdabf383b26..2299ef6e7c3a 100755 --- a/deepspeed/module_inject/__init__.py +++ b/deepspeed/module_inject/__init__.py @@ -6,5 +6,6 @@ from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection from .module_quantize import quantize_transformer_layer from .replace_policy import HFBertLayerPolicy -from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize +from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode, SubParamLinearLayer, SubParamLinearAllreduce from .policy import DSPolicy +from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType, AutoTPPresets, merge_autotp_configs diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py old mode 100644 new mode 100755 index bf49df9781f5..852c492f8b8e --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -8,10 +8,214 @@ from torch import nn from .replace_policy import replace_policies +from typing import Optional +import torch +from deepspeed import comm as dist +from .layers import * +from deepspeed.accelerator import get_accelerator +from .fusedqkv_utils import require_tp_fused_qkvw +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list +from deepspeed.utils import groups +from deepspeed.module_inject.layers import is_autotp_training_mode +from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType + + +def move(tensor, device, copy=True): + if tensor.is_meta: + return torch.empty_like(tensor, device=device) + else: + # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). + # Using copy=True instead of clone() will help in case of cpu --> cpu. + # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. + return tensor.to(device, copy=copy) + + +class ReplaceWithTensorSlicing: + + def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0): + if mp_group is not None: + self.gpu_index = dist.get_rank(group=mp_group) + else: + self.gpu_index = 0 + self.out_dim = out_dim + self.in_dim = in_dim + self.mp_size = mp_size + + def merge_assert(self, dim1, dim2): + assert dim1 > dim2, \ + 'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\ + for merging your checkpoints before replacing the transformer layer with\ + inference-kernels' + + def strided_copy(self, + dst: Optional[torch.Tensor], + src: Optional[torch.Tensor], + num_splits: int, + int8: bool = False, + allocate_tensor: bool = False): + if src is None: + return src + src_shape = src.shape + dst_shape = dst.shape + + outer_dim = 0 if int8 else -1 + + if allocate_tensor: + dst = torch.empty_like(dst) + + src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim) + if (len(src_shape) == 2 and len(dst_shape) == 2): + if src_shape[outer_dim] == dst_shape[self.out_dim]: + try: + dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape) + except Exception: + print(dst.shape, src.shape) + exit() + dst = torch.nn.parameter.Parameter(dst, requires_grad=False) + if hasattr(src, 'scale'): + dst.scale = src.scale + return dst + self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim]) + qkv_size = dst_shape[self.out_dim] // num_splits + qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split] + weight_split = [ + torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0])) + ] + dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape( + weight_split[self.gpu_index].shape) + else: + if src_shape[0] == dst_shape[0]: + return torch.nn.parameter.Parameter(src) + qkv_size = dst_shape[0] // num_splits + qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split] + bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))] + dst.data.copy_(bias_split[self.gpu_index].contiguous()) + + dst = torch.nn.parameter.Parameter(dst, requires_grad=False) + if hasattr(src, 'scale'): + dst.scale = src.scale + return dst + + def copy(self, dst, src, int8=False, allocate_tensor=False): + if src is None: + return src + assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors + if allocate_tensor: + dst = torch.empty_like(dst) + outer_dim = 0 if int8 else 1 + inner_dim = 1 if int8 else 0 + src_shape = src.shape + dst_shape = dst.shape + if (len(src_shape) == 2 and len(dst_shape) == 2): + + if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]: + dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape) + else: + if src_shape[inner_dim] != dst_shape[self.in_dim]: + self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim]) + dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \ + src[self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim], :]) + else: + self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim]) + dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \ + src[self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim], :]) + else: + if src_shape[0] == dst_shape[0]: + dst = src if src.dtype == dst.dtype else dst.data.copy_(src) + else: + dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]]) + dst = torch.nn.parameter.Parameter(dst, requires_grad=False) + if hasattr(src, 'scale'): + dst.scale = src.scale + return dst + + +class Loading(): + + def is_load_module(module): + load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] + load_layer_names = [ + "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", + "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", + "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm", + "Qwen3RMSNorm", "Qwen3MoeRMSNorm", "DeepseekV2RMSNorm", "DeepseekV3RMSNorm", + "DeepseekV2YarnRotaryEmbedding", "DeepseekV3YarnRotaryEmbedding", "MoEGate" + ] + return module.__class__ in load_layers or module._get_name() in load_layer_names + + def load_buffer(module, state_dict, prefix): + for name in module._buffers.keys(): + if module._buffers[name].data.is_meta: + module._buffers[name] = torch.nn.parameter.Parameter( + data=torch.empty_like(module._buffers[name].data, device="cpu"), + requires_grad=module._buffers[name].data.requires_grad) + if prefix + name in state_dict.keys(): + module._buffers[name].data.copy_(state_dict[prefix + name]) + + def load(module, state_dict, prefix, mp_group=None): + mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) + if hasattr(module, 'weight'): + if module.weight.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"), + requires_grad=module.weight.data.requires_grad) + if 'query_key_value' in prefix: + module.weight = mp_replace.strided_copy(module.weight.data, + state_dict[prefix + 'weight'], + num_splits=3) + else: + module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight']) + else: + if hasattr(module, 'norm') and hasattr(module.norm, 'weight'): + if module.norm.weight.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.norm.weight = torch.nn.parameter.Parameter( + data=torch.empty_like(module.norm.weight.data, device="cpu"), + requires_grad=module.norm.weight.data.requires_grad) + module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight']) + + if prefix + 'bias' in state_dict.keys(): + if hasattr(module, 'bias'): + if module.bias.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"), + requires_grad=module.bias.data.requires_grad) + module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias']) + else: + if hasattr(module, 'norm') and hasattr(module.norm, 'bias'): + if module.norm.bias.data.is_meta: + # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here + module.norm.bias = torch.nn.parameter.Parameter( + data=torch.empty_like(module.norm.bias.data, device="cpu"), + requires_grad=module.norm.bias.data.requires_grad) + module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias']) class AutoTP(): + def __init__(self, + module, + all_reduce_linears, + prefix, + state_dict, + linear_layer_setting, + orig_layer_impl, + keep_module_on_host=False, + partition_config: Optional[AutoTPConfig] = None): + self.module = module + self.all_reduce_linears = all_reduce_linears + self.prefix = prefix + self.state_dict = state_dict + + self.mp_size = None + self.mp_group = None + self.linear_layer_setting = linear_layer_setting + self.orig_layer_impl = orig_layer_impl + self.linear_policies = None + self.conv_linear_layer = False + self.partition_config = partition_config + TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host) + def in_module_list(module, module_list): for item in module_list: if type(item).__name__ == type(module).__name__: @@ -32,7 +236,7 @@ def get_module_list(model): return mlist def supported(model): - unsupported = ['codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet'] + unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet'] model = str(model) key = re.search(r": (.*?)Model", model) if key is None: @@ -90,11 +294,13 @@ def tp_parser(model): module_list = AutoTP.get_module_list(model) assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \ if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy." + norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2'] + #ln_1 , ln_2 for Qwen for module in module_list: for key, submodule in module._modules.items(): if isinstance(submodule, nn.Linear): layer_list = layer_list + ["." + key] - elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm': + elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list: layer_list = layer_list + ["ln"] else: layer_list = layer_list + AutoTP.get_layers(key, submodule) @@ -104,6 +310,25 @@ def tp_parser(model): gem_list = gem_list + [layer_list[i - 1]] elif 'out_proj' in layer: gem_list = gem_list + [layer] + elif 'o_proj' in layer: + gem_list = gem_list + [layer] + elif 'down_proj' in layer: + gem_list = gem_list + [layer] + elif 'attention.dense' in layer and 'GPTNeoX' in str(model): + gem_list = gem_list + [layer] + elif 'self_attention.dense' in layer and 'falcon' in str( + type(module)): # this is a hack to get the right linear layer for this model! + gem_list = gem_list + [layer] + # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce. + elif 'w2' in layer and 'Mixtral' in str(type(module)): + gem_list = gem_list + [layer] + elif 'self_attn.dense' in layer and 'Phi' in str(type(module)): + gem_list = gem_list + [layer] + elif 'self_attention.dense' in layer and 'ChatGLM' in str(model): + gem_list = gem_list + [layer] + elif 'dense_4h_to_h' in layer and 'ChatGLM' in str(model): + gem_list = gem_list + [layer] + layer_list = [] if gem_list != []: gem_list = list(set(gem_list)) @@ -112,3 +337,294 @@ def tp_parser(model): assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \ if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy." return policy_list + + def set_tensor_parallel_config(self, mp_size, mp_group): + + if is_autotp_training_mode(): + self.mp_group = groups.get_tensor_model_parallel_group() + self.mp_size = groups.get_tensor_model_parallel_world_size() + return + + self.mp_size = mp_size + self.mp_group = mp_group + + def _replace(self, child, name, conv_linear_layer): + # This function should clearly define the routing rules for specific layers + # and avoid any complex shard-related logic. + if getattr(child, "replaced", False) == True: + return + + weight_shape = child.weight.shape + mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) + + # If partition_config is provided, use the new configurable API + if self.partition_config is not None: + return self._replace_with_config(child, name) + + # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip + if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or ( + ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))): + return child + # For Yuan model + if 'Yuan' in str(self.module): + if 'v_proj' in name: + return Yuan_LinearLayer(child, self.mp_group) + + elif 'o_proj' in name: + return Yuan_LinearAllreduce(child, self.mp_group) + + # For MLP including chunk layer. + if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)): + return GateUpPack_LinearLayer(child, self.mp_group) + # For Arctic model, bypass to all_reduce replacement for w2 weights + arctic_w2_all_reduce_linear = False + if 'Arctic' in str(self.module) and 'w2' in name: + arctic_w2_all_reduce_linear = True + # For MoE MLP model, e.g., deepseek and jamba + down_proj = False + if 'down_proj' in name: + down_proj = True + if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj: + + setattr(child, "replaced", True) + if self.conv_linear_layer: + return Conv_LinearALlreduce(child, self.mp_group, name=name) + elif name == "lm_head" or name == 'embed_out': + return LmHeadLinearAllreduce(child, self.mp_group) + + return LinearAllreduce(child, self.mp_group, name=name) + else: + + setattr(child, "replaced", True) + if self.conv_linear_layer: + conv_LinearLayer(child, self.mp_group) + elif require_tp_fused_qkvw(name, self.mp_size): + #Check and handle fused qkv for TP + return fused_LinearLayer(child, self.mp_group, fused_module=self.module) + + return LinearLayer(child, self.mp_group, name=name) + + def _replace_with_config(self, child, name): + """ + Replace layer using the new configurable AutoTP API. + + This method uses TPLayerSpec to determine how to partition the layer. + """ + if getattr(child, "replaced", False) == True: + return child + + # Build the full parameter name for pattern matching + param_name = name + ".weight" if not name.endswith(".weight") else name + + # Find matching spec + model_type = self._get_model_type() + spec = self.partition_config.find_matching_spec(param_name, model_type) + + if spec is None: + # No matching spec found + if self.partition_config.strict_mode: + raise ValueError(f"No matching spec for {param_name}") + # With partition_config, rely only on explicit specs and skip unmatched layers. + return child + + setattr(child, "replaced", True) + + if spec.partition_type == PartitionType.SKIP: + return child + + if spec.partition_type == PartitionType.ROW: + return self._create_row_parallel_layer(child, spec, name) + else: + return self._create_column_parallel_layer(child, spec, name) + + def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str): + """Create row-parallel layer (AllReduce after forward).""" + if self.conv_linear_layer: + return Conv_LinearALlreduce(module, self.mp_group, name=name) + # Check for lm_head / embed_out + if name == "lm_head" or name == 'embed_out': + return LmHeadLinearAllreduce(module, self.mp_group) + + if spec.shape is not None: + return SubParamLinearAllreduce( + module, + self.mp_group, + shape=spec.shape, + partition_dim=spec.get_partition_dim(), + name=name, + ) + return LinearAllreduce(module, self.mp_group, name=name) + + def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str): + """Create column-parallel layer (AllReduce in backward).""" + if self.conv_linear_layer: + return conv_LinearLayer(module, self.mp_group, name=name) + # Only use fused-QKV heuristics when no partition_config is provided. + elif self.partition_config is None and require_tp_fused_qkvw(name, self.mp_size): + # Check and handle fused qkv for TP + return fused_LinearLayer(module, self.mp_group, fused_module=self.module) + if spec.shape is not None: + return SubParamLinearLayer( + module, + self.mp_group, + shape=spec.shape, + partition_dim=spec.get_partition_dim(), + name=name, + ) + return LinearLayer(module, self.mp_group, name=name) + + def _get_model_type(self) -> Optional[str]: + """Extract model type from module config or class name.""" + config = getattr(self.module, "config", None) + if config is not None: + model_type = getattr(config, "model_type", None) + if model_type: + return str(model_type).lower() + module_str = str(type(self.module)) + # Try to extract model type from class name (e.g., "LlamaDecoderLayer" -> "llama") + patterns = [ + r"(\w+)DecoderLayer", + r"(\w+)Block", + r"(\w+)Layer", + ] + for pattern in patterns: + match = re.search(pattern, module_str) + if match: + return match.group(1).lower() + return None + + def _slice_embedding(self, child, name, conv_linear_layer): + if getattr(child, "replaced", False) == True: + return + + mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) + + if hasattr(child.weight, 'ds_tensor'): + data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1) + else: + data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1) + data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) + data = torch.nn.parameter.Parameter(data, requires_grad=False) + + new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size, name)) + new_embedding.weight.data.copy_(data) + setattr(child, "replaced", True) + return new_embedding + + def update_mp_params(self, child): + if getattr(child, "replaced", False) == True: + return + param_list = [ + "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size", + "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model", + "num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition" + ] + for param in param_list: + if "Yuan" in str(child) and 'embed_dim' in param_list: + param_list.remove('embed_dim') + if hasattr(child, param): + param_val = getattr(child, param) + setattr(child, param, get_shard_size(param_val, self.mp_size)) + setattr(child, "replaced", True) + + def update_linear_policies(self): + self.conv_linear_layer = False + if self.linear_layer_setting is not None: + self.linear_policies = {self.linear_layer_setting[0]: self._replace} + if len(self.linear_layer_setting) == 2: + self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding}) + else: + import transformers + if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block: + try: + self.conv_linear_layer = True + self.linear_policies = {transformers.pytorch_utils.Conv1D: self._replace} + except ImportError: + self.linear_policies = {nn.Linear: self._replace} + else: + self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding} + + def _replace_module(self, r_module, prev_name='', prev_class_name=''): + for name, child in r_module.named_children(): + if prev_class_name == "": + class_name = prev_name + elif prev_name == "": + class_name = prev_class_name + else: + class_name = prev_class_name + '.' + prev_name + checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.' + if Loading.is_load_module(child) and self.state_dict is not None: + if any(checking_key in item for item in self.state_dict): + Loading.load(child, self.state_dict, checking_key, self.mp_group) + else: + continue + if len(child._buffers) != 0 and self.state_dict is not None: + Loading.load_buffer(child, self.state_dict, checking_key) + + # When using partition_config (custom patterns/presets), use pattern-based routing + # instead of linear_policies. This keeps all pattern logic centralized here. + if self.partition_config is not None: + full_name = prev_name + '.' + name if prev_name else name + if isinstance(child, nn.Embedding): + # Check if embedding matches any pattern + param_name = full_name + ".weight" + model_type = self._get_model_type() + spec = self.partition_config.find_matching_spec(param_name, model_type) + if spec is not None and spec.partition_type != PartitionType.SKIP: + new_child = self._slice_embedding(child, full_name, False) + if new_child is not None: + setattr(r_module, name, new_child) + # If no pattern matched or skip, leave embedding unchanged + elif hasattr(child, "weight") and getattr(child.weight, "dim", lambda: 0)() == 2: + new_child = self._replace_with_config(child, full_name) + if new_child is not None: + setattr(r_module, name, new_child) + else: + self.update_mp_params(child) + self._replace_module(child, full_name, class_name) + # Traditional path: use linear_policies for type-based routing + elif child.__class__ in self.linear_policies: + setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name, + self.conv_linear_layer)) + elif any(isinstance(child, lp) for lp in self.linear_policies): + # Added for falcon model support + # Note: isinstance will account for class inheritance, child.__class__ does not + key = None + for lp in self.linear_policies: + if isinstance(child, lp): + key = lp + break + assert key is not None + setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name, + self.conv_linear_layer)) + else: + self.update_mp_params(child) + self._replace_module(child, name, class_name) + return r_module + + def get_model_num_kv_heads(self, config): + num_kv_heads = None + # multi_query_group_num is for chatglm2 & chatglm3 + kv_head_names = [ + 'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads', + 'attention_heads' + ] + for name in kv_head_names: + if hasattr(config, name): + num_kv_heads = getattr(config, name) + if num_kv_heads is not None: + break + return num_kv_heads + + def _replace_last_linear_module(self, r_module): + if hasattr(r_module, "lm_head"): + name = "lm_head" + child = r_module.lm_head + elif hasattr(r_module, "embed_out"): + name = "embed_out" + child = r_module.embed_out + else: + return r_module + if child.__class__ in self.linear_policies: + setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer)) + return r_module diff --git a/deepspeed/module_inject/auto_tp_model_utils.py b/deepspeed/module_inject/auto_tp_model_utils.py new file mode 100644 index 000000000000..a71b1a54d6f6 --- /dev/null +++ b/deepspeed/module_inject/auto_tp_model_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed import comm as dist +import torch +from typing import Optional +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list + + +def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2**math.floor(math.log2(num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size()) + offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()]) + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + +def get_alibi_mask(self, tensor, seq_length_with_past): + mask = self.get_alibi_mask_orig(tensor, seq_length_with_past) + if not self.training and dist.is_initialized(): + num_heads_per_rank = get_shard_size(self.n_head, dist.get_world_size()) + offset = sum(get_shard_size_list(self.n_head, dist.get_world_size())[0:dist.get_rank()]) + mask = mask[offset:num_heads_per_rank + offset, :seq_length_with_past, :seq_length_with_past] + + return mask + + +def build_mpt_atten_bias_tensor(self, + device, + dtype, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None): + (attn_bias, attention_mask) = self._attn_bias_orig(device, + dtype, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id) + if dist.is_initialized(): + num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size()) + offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()]) + attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :] + return attn_bias, attention_mask + + +def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor: + r""" + Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation. This implementation has been copied from + the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi: + https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292 + """ + alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device) + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size()) + offset = dist.get_rank() * num_heads_per_rank + alibi = alibi[offset:num_heads_per_rank + offset, :, :] + return alibi diff --git a/deepspeed/module_inject/autotp_config.py b/deepspeed/module_inject/autotp_config.py new file mode 100644 index 000000000000..4bafea806829 --- /dev/null +++ b/deepspeed/module_inject/autotp_config.py @@ -0,0 +1,569 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Configurable AutoTP API + +This module provides a unified specification for tensor parallel layer partitioning. +The design is inspired by Universal Checkpointing's SubparamShape and provides +a single, well-defined format that users can easily understand, customize, and extend. +""" + +import re +from dataclasses import dataclass, field +from typing import List, Tuple, Union, Optional +from enum import Enum +from deepspeed.utils.logging import warning_once + + +class PartitionType(Enum): + """How the layer should be partitioned for tensor parallelism.""" + COLUMN = "column" # Partition output dim, AllReduce in backward + ROW = "row" # Partition input dim, AllReduce in forward + SKIP = "skip" # Do not partition this layer + + +@dataclass +class TPLayerSpec: + """ + Unified specification for tensor parallel layer partitioning. + + This is inspired by Universal Checkpointing's SubparamShape but extended + for AutoTP's needs (forward/backward communication patterns). + + The `shape` parameter supports at most 1-level nesting at the partition dimension: + - (3, -1) -> 3 equal-size sub-params + - ((q, k, v), -1) -> 3 unequal-size sub-params (1-level nesting) + + Examples: + # Simple row-parallel layer (e.g., o_proj, down_proj) + TPLayerSpec( + patterns=[".*\\.o_proj$", ".*\\.down_proj$"], + partition_type=PartitionType.ROW, + ) + + # Simple column-parallel layer (e.g., q_proj, k_proj, v_proj) + TPLayerSpec( + patterns=[".*\\.[qkv]_proj$"], + partition_type=PartitionType.COLUMN, + ) + + # Fused QKV - GLM style [Q, K, V] concatenated on dim 0 + TPLayerSpec( + patterns=[".*\\.query_key_value\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(3, -1), # 3 equal sub-params, -1 = infer + partition_dim=0, + ) + + # Fused QKV - Bloom style [q1,k1,v1,q2,k2,v2,...] + TPLayerSpec( + patterns=[".*\\.query_key_value\\.weight$"], + partition_type=PartitionType.COLUMN, + # No reshape needed, just split along dim 0 + ) + + # GQA with different Q/K/V sizes (1-level nesting) + TPLayerSpec( + patterns=[".*\\.qkv_proj\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=((q_size, k_size, v_size), -1), # Unequal sub-params + partition_dim=0, + ) + + # Chunked MLP (gate_up_proj) + TPLayerSpec( + patterns=[".*\\.gate_up_proj\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(2, -1), # [gate, up] packed + partition_dim=0, + ) + + # MoE FFN with expert dimension + TPLayerSpec( + patterns=[".*\\.experts\\..*\\.w1\\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(num_experts, -1, hidden_in), # View as 3D + partition_dim=1, # Partition the hidden_out dimension + ) + + # Skip layer (e.g., MoE gate) + TPLayerSpec( + patterns=[".*\\.gate$", ".*\\.router$"], + partition_type=PartitionType.SKIP, + ) + """ + + # Layer identification - regex patterns to match parameter names + patterns: List[str] + + # Partition type determines communication pattern + partition_type: PartitionType = PartitionType.COLUMN + + # Optional: logical shape for partitioning + # - Use -1 for dimensions that should be inferred + # - Use tuple of ints at partition_dim for unequal sub-params (1-level nesting only) + # Examples: + # (3, -1) -> 3 equal sub-params + # ((4096, 1024, 1024), -1) -> 3 unequal sub-params (GQA) + # (n_experts, -1, hidden) -> MoE reshape + shape: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None + + # Which dimension to partition (after optional reshape) + # Default: 0 for COLUMN, 1 for ROW (standard 2D weight matrix) + partition_dim: Optional[int] = None + + # Optional: model type constraint (only apply for specific models) + model_types: Optional[List[str]] = None + + def __post_init__(self): + if isinstance(self.partition_type, str): + self.partition_type = PartitionType(self.partition_type.lower()) + if self.shape is not None: + self.shape = self._normalize_shape(self.shape) + self._validate_shape_format() + + @staticmethod + def _normalize_shape(shape): + if isinstance(shape, list): + return tuple(TPLayerSpec._normalize_shape(item) for item in shape) + if isinstance(shape, tuple): + return tuple(TPLayerSpec._normalize_shape(item) if isinstance(item, list) else item for item in shape) + return shape + + def _validate_shape_format(self): + if not isinstance(self.shape, tuple): + raise ValueError("AutoTP shape must be a tuple of ints or a tuple at partition_dim.") + partition_dim = self.get_partition_dim() + if partition_dim < 0 or partition_dim >= len(self.shape): + raise ValueError( + f"AutoTP partition_dim {partition_dim} is out of range for shape length {len(self.shape)}.") + nested_tuple_seen = False + for idx, dim in enumerate(self.shape): + if isinstance(dim, tuple): + if idx != partition_dim: + raise ValueError( + f"AutoTP shape nested tuple only allowed at partition_dim={partition_dim}, got at {idx}.") + if nested_tuple_seen: + raise ValueError("AutoTP shape supports only 1-level nesting at partition_dim.") + nested_tuple_seen = True + if len(dim) == 0: + raise ValueError("AutoTP shape nested tuple cannot be empty.") + for val in dim: + if isinstance(val, tuple): + raise ValueError("AutoTP shape supports only 1-level nesting at partition_dim.") + if not isinstance(val, int) or val <= 0: + raise ValueError("AutoTP nested sub-parameter sizes must be positive integers.") + elif isinstance(dim, int): + if dim == 0 or dim < -1: + raise ValueError("AutoTP shape dimensions must be positive integers or -1.") + else: + raise ValueError("AutoTP shape must contain only integers or a tuple at partition_dim.") + + def get_partition_dim(self) -> int: + """Get effective partition dimension.""" + if self.partition_dim is not None: + return self.partition_dim + # Default based on partition type for 2D weight matrices + return 0 if self.partition_type == PartitionType.COLUMN else 1 + + def has_unequal_sub_params(self) -> bool: + """Check if this spec has unequal sub-parameters (nested tuple at partition_dim).""" + if self.shape is None: + return False + dim = self.get_partition_dim() + if dim >= len(self.shape): + return False + return isinstance(self.shape[dim], tuple) + + def get_sub_param_sizes(self) -> Optional[Tuple[int, ...]]: + """Get sub-parameter sizes if using unequal sub-params.""" + if not self.has_unequal_sub_params(): + return None + return self.shape[self.get_partition_dim()] + + def get_num_sub_params(self) -> Optional[int]: + """Get the number of sub-parameters.""" + if self.shape is None: + return None + dim = self.get_partition_dim() + if dim >= len(self.shape): + return None + if isinstance(self.shape[dim], tuple): + return len(self.shape[dim]) + elif isinstance(self.shape[dim], int) and self.shape[dim] > 0: + return self.shape[dim] + return None + + def matches(self, param_name: str, model_type: Optional[str] = None) -> bool: + """Check if this spec matches the given parameter.""" + # Check model type constraint + if self.model_types: + if model_type is None: + return False + model_type_norm = str(model_type).lower() + model_types_norm = [str(mt).lower() for mt in self.model_types] + if model_type_norm not in model_types_norm: + return False + # Check pattern match + return any(re.match(pattern, param_name) for pattern in self.patterns) + + +@dataclass +class AutoTPConfig: + """ + Configuration for Automatic Tensor Parallelism. + + Example usage: + config = AutoTPConfig( + tp_size=4, + layer_specs=[ + # Row-parallel layers (AllReduce after forward) + TPLayerSpec( + patterns=[".*\\.o_proj", ".*\\.down_proj"], + partition_type=PartitionType.ROW, + ), + # Column-parallel layers + TPLayerSpec( + patterns=[".*\\.[qkv]_proj", ".*\\.up_proj", ".*\\.gate_proj"], + partition_type=PartitionType.COLUMN, + ), + # Skip MoE gates + TPLayerSpec( + patterns=[".*\\.gate$"], + partition_type=PartitionType.SKIP, + ), + ], + ) + """ + + tp_size: int = 1 + + # Unified layer specifications + layer_specs: List[TPLayerSpec] = field(default_factory=list) + + # Embedding configuration + embedding_partition_dim: int = 1 # Usually partition vocab dim + + # LM head configuration + lm_head_patterns: List[str] = field(default_factory=lambda: ["lm_head", "embed_out"]) + + # Behavior flags + use_default_specs: bool = True # Merge with built-in specs + strict_mode: bool = False # Fail if unmatched Linear layers found + + def find_matching_spec(self, param_name: str, model_type: Optional[str] = None) -> Optional[TPLayerSpec]: + """Find the first matching spec for a parameter.""" + matches = [spec for spec in self.layer_specs if spec.matches(param_name, model_type)] + if not matches: + return None + if len(matches) > 1: + matched_patterns = [spec.patterns for spec in matches] + warning_once(f"AutoTPConfig: parameter {param_name} matched multiple layer_specs {matched_patterns}; " + "using the first match.") + return matches[0] + + @classmethod + def from_dict(cls, config_dict: dict) -> "AutoTPConfig": + """Create config from dictionary (JSON config).""" + layer_specs = [] + for spec_dict in config_dict.get("layer_specs", []): + # Convert partition_type string to enum + partition_type_str = spec_dict.get("partition_type", "column") + if isinstance(partition_type_str, str): + partition_type = PartitionType(partition_type_str.lower()) + else: + partition_type = partition_type_str + + # Convert shape from list to tuple if necessary + shape = spec_dict.get("shape") + if shape is not None: + shape = cls._convert_shape(shape) + + layer_specs.append( + TPLayerSpec( + patterns=spec_dict.get("patterns", []), + partition_type=partition_type, + shape=shape, + partition_dim=spec_dict.get("partition_dim"), + model_types=spec_dict.get("model_types"), + )) + + return cls( + tp_size=config_dict.get("tp_size", 1), + layer_specs=layer_specs, + embedding_partition_dim=config_dict.get("embedding_partition_dim", 1), + lm_head_patterns=config_dict.get("lm_head_patterns", ["lm_head", "embed_out"]), + use_default_specs=config_dict.get("use_default_specs", True), + strict_mode=config_dict.get("strict_mode", False), + ) + + @staticmethod + def _convert_shape(shape): + """Convert shape from list to tuple, handling nested structures.""" + if isinstance(shape, list): + return tuple(AutoTPConfig._convert_shape(item) if isinstance(item, list) else item for item in shape) + return shape + + +class AutoTPPresets: + """Built-in presets for common model architectures.""" + + @staticmethod + def llama() -> AutoTPConfig: + """LLaMA-style models (separate Q, K, V projections).""" + return AutoTPConfig(layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def llama_gqa(num_heads: int, num_kv_heads: int, head_dim: int) -> AutoTPConfig: + """LLaMA with Grouped Query Attention (fused QKV variant).""" + q_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # Fused QKV with unequal sizes (GQA) + TPLayerSpec( + patterns=[r".*\.self_attn\.qkv_proj\.weight$"], + partition_type=PartitionType.COLUMN, + shape=((q_size, kv_size, kv_size), -1), # 1-level nesting + partition_dim=0, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def bloom() -> AutoTPConfig: + """BLOOM-style models (fused QKV with interleaved heads).""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attention\.dense\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attention\.query_key_value\.weight$"], + partition_type=PartitionType.COLUMN, + # Bloom style: [q1,k1,v1,q2,k2,v2,...] - no reshape needed + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_4h_to_h\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_h_to_4h\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def chatglm() -> AutoTPConfig: + """ChatGLM-style models (GLM-style fused QKV).""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attention\.dense\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attention\.query_key_value\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(3, -1), # [Q, K, V] concatenated + partition_dim=0, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_4h_to_h\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.dense_h_to_4h\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(2, -1), # [gate, up] packed + partition_dim=0, + ), + ], ) + + @staticmethod + def mixtral() -> AutoTPConfig: + """Mixtral MoE model.""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # MoE experts + TPLayerSpec( + patterns=[r".*\.block_sparse_moe\.experts\.\d+\.w2\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.block_sparse_moe\.experts\.\d+\.w[13]\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # Skip MoE gate + TPLayerSpec( + patterns=[r".*\.block_sparse_moe\.gate\.weight$"], + partition_type=PartitionType.SKIP, + ), + ], ) + + @staticmethod + def deepseek_v2() -> AutoTPConfig: + """DeepSeek-V2 with MLA (Multi-head Latent Attention).""" + return AutoTPConfig( + layer_specs=[ + # Standard attention output + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # MLA uses compressed KV, skip low-rank projections + TPLayerSpec( + patterns=[r".*\.self_attn\.(q_a_proj|kv_a_proj_with_mqa)\.weight$"], + partition_type=PartitionType.SKIP, + ), + # Q/K/V projections from latent + TPLayerSpec( + patterns=[r".*\.self_attn\.(q_b_proj|kv_b_proj)\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # MoE experts + TPLayerSpec( + patterns=[r".*\.mlp\.experts\.\d+\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.experts\.\d+\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + # Skip MoE gate + TPLayerSpec( + patterns=[r".*\.mlp\.gate\.weight$"], + partition_type=PartitionType.SKIP, + ), + # Shared expert + TPLayerSpec( + patterns=[r".*\.mlp\.shared_experts\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.shared_experts\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def qwen2() -> AutoTPConfig: + """Qwen2 model.""" + return AutoTPConfig(layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.self_attn\.[qkv]_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.(up|gate)_proj\.weight$"], + partition_type=PartitionType.COLUMN, + ), + ], ) + + @staticmethod + def phi3() -> AutoTPConfig: + """Phi3 model with fused QKV and chunked MLP.""" + return AutoTPConfig( + layer_specs=[ + TPLayerSpec( + patterns=[r".*\.self_attn\.o_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # Phi3 has fused qkv_proj + TPLayerSpec( + patterns=[r".*\.self_attn\.qkv_proj\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(3, -1), # [Q, K, V] concatenated + partition_dim=0, + ), + TPLayerSpec( + patterns=[r".*\.mlp\.down_proj\.weight$"], + partition_type=PartitionType.ROW, + ), + # Phi3 has gate_up_proj fused + TPLayerSpec( + patterns=[r".*\.mlp\.gate_up_proj\.weight$"], + partition_type=PartitionType.COLUMN, + shape=(2, -1), # [gate, up] packed + partition_dim=0, + ), + ], ) + + @staticmethod + def get_preset(model_type: str) -> Optional[AutoTPConfig]: + """Get a preset configuration by model type name.""" + presets = { + "llama": AutoTPPresets.llama, + "bloom": AutoTPPresets.bloom, + "chatglm": AutoTPPresets.chatglm, + "mixtral": AutoTPPresets.mixtral, + "deepseek_v2": AutoTPPresets.deepseek_v2, + "qwen2": AutoTPPresets.qwen2, + "phi3": AutoTPPresets.phi3, + } + preset_fn = presets.get(model_type.lower()) + if preset_fn: + return preset_fn() + return None + + +def merge_autotp_configs(base: AutoTPConfig, override: AutoTPConfig) -> AutoTPConfig: + """Merge two AutoTP configs, with override taking precedence.""" + # Combine layer specs - override specs come first (higher priority) + merged_specs = list(override.layer_specs) + list(base.layer_specs) + + return AutoTPConfig( + tp_size=override.tp_size if override.tp_size > 1 else base.tp_size, + layer_specs=merged_specs, + embedding_partition_dim=override.embedding_partition_dim, + lm_head_patterns=override.lm_head_patterns or base.lm_head_patterns, + use_default_specs=override.use_default_specs, + strict_mode=override.strict_mode, + ) diff --git a/deepspeed/module_inject/containers/__init__.py b/deepspeed/module_inject/containers/__init__.py index 4655b29b5ba6..993d14071659 100644 --- a/deepspeed/module_inject/containers/__init__.py +++ b/deepspeed/module_inject/containers/__init__.py @@ -10,6 +10,9 @@ from .gptj import DS_GPTJContainer, HFGPTJLayerPolicy from .gptneo import DS_GPTNEOContainer, HFGPTNEOLayerPolicy from .gptneox import DS_GPTNEOXContainer, GPTNEOXLayerPolicy +from .llama import DS_LLAMAContainer, LLAMALayerPolicy +from .llama2 import LLAMA2LayerPolicy, DS_LLAMA2Container +from .internlm import DS_InternLMContainer, InternLMLayerPolicy from .megatron_gpt import DS_MegatronGPTContainer, MegatronLayerPolicy from .megatron_gpt_moe import DS_MegatronGPTMoEContainer, MegatronMoELayerPolicy from .opt import DS_OPTContainer, HFOPTLayerPolicy diff --git a/deepspeed/module_inject/containers/base.py b/deepspeed/module_inject/containers/base.py index 20a664668f87..83e109167ffe 100644 --- a/deepspeed/module_inject/containers/base.py +++ b/deepspeed/module_inject/containers/base.py @@ -5,11 +5,17 @@ # Create a container object to save model-specific tensors using the policy file above. from abc import ABC + import torch +import deepspeed from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig from deepspeed.accelerator import get_accelerator +# If the intermediate size attribute is set DEFAULT_INTERMEDIATE_SIZE +# it is assumed the intermediate size is 4x the embedding dimension +DEFAULT_INTERMEDIATE_SIZE = -1 + class BaseConvolutionContainer(ABC): # not implemented @@ -32,11 +38,12 @@ def __init__(self, policy, config, model_config, layer_id, child): # configuration for models. todo: can this be moved to a pydantic model config? self.hidden_size = None + self.intermediate_size = None self.num_attention_heads = None self.mp_size = self.config.tensor_parallel.tp_size self.pre_layer_norm = self.model_config.do_layer_norm_before if \ hasattr(self.model_config, 'do_layer_norm_before') else self.policy.pre_attn_norm - self.fp16 = False + self.dtype = self.config.dtype self.attn_linear_layer = self.policy.linear_layer self.mlp_linear_layer = self.policy.linear_layer self.return_tuple = self.config.return_tuple @@ -45,6 +52,7 @@ def __init__(self, policy, config, model_config, layer_id, child): self.model_config, 'attention_layers') else False) self.window_size = getattr(self.model_config, "window_size", 1) self.mlp_act_func_type = self.policy.mlp_act_func_type + self.norm_type = self.policy.norm_type self.training_mp_size = self.config.training_mp_size self.bigscience_bloom = False self.max_out_tokens = self.config.max_out_tokens @@ -52,9 +60,7 @@ def __init__(self, policy, config, model_config, layer_id, child): self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False) self.use_mup = self.policy.use_mup self.return_single_tuple = False - self.rotary_dim = self.model_config.rotary_dim if hasattr(self.model_config, 'rotary_dim') \ - else self.child.attention.rotary_ndims if \ - hasattr(self.child, 'attention') and hasattr(self.child.attention,'rotary_ndims') else -1 + self.rotary_dim = self.get_rotary_dim() self.mlp_after_attn = (self.rotary_dim is None or self.rotary_dim < 0) # Attention tensors @@ -74,6 +80,10 @@ def __init__(self, policy, config, model_config, layer_id, child): self.input_nb = None self.mp_group = None + self.use_triton = False + + # Triton + self.use_triton = config.use_triton and deepspeed.HAS_TRITON def create_ds_model_config(self): self.set_hidden_heads(*self.policy.get_hidden_heads()) @@ -83,12 +93,13 @@ def create_ds_model_config(self): self.ds_model_config = DeepSpeedInferenceConfig( hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, heads=self.num_attention_heads, layer_norm_eps=self.layernorm_epsilon, - fp16=self.fp16, + dtype=self.dtype, pre_layer_norm=self.pre_layer_norm, + norm_type=self.norm_type, mp_size=self.mp_size, - q_int8=self.quantize if hasattr(self, 'quantize') else False, return_tuple=self.return_tuple, triangular_masking=self.triangular_masking, local_attention=self.local_attention, @@ -104,34 +115,53 @@ def create_ds_model_config(self): use_mup=self.use_mup, return_single_tuple=self.return_single_tuple, set_empty_params=self.config.set_empty_params, - transposed_mode=self.config.transposed_mode) + transposed_mode=self.config.transposed_mode, + use_triton=self.use_triton, + triton_autotune=self.config.triton_autotune) + + if self.use_triton and deepspeed.HAS_TRITON: + from .bert import DS_BERTContainer + if not isinstance(self, DS_BERTContainer): + raise NotImplementedError("Triton kernels are only for BERT-like models yet") + + if not self.config.triton_autotune: + from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul + fp16_matmul.skip_autotune() return self.ds_model_config + def check_meta_tensor_support(self): + if hasattr(self.qkvw, 'is_meta'): + if self.qkvw.is_meta: + assert self.ckpt_load_enabled, "Meta tensors are not supported for this model currently." + else: + raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+") + def initialize_tensors(self, enable_training=False): # Set the tensors from policy (user module) to container (DS module) self.set_attention(*self.policy.attention(enable_training=enable_training)) - self.set_mlp(*self.policy.mlp()) + self.set_mlp(*self.policy.mlp(enable_training=enable_training)) self.set_layernorm(*self.policy.layernorm()) - self.set_lora_params(self.policy.get_lora_params()) - self.q_k_v = self.policy.get_q_k_v() - if self.q_k_v is not None: - self.set_q_k_v(*self.q_k_v) + #self.check_meta_tensor_support() - def convert_to_required_dtype(self, dtype): + def convert_to_required_dtype(self): # Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy - if dtype == torch.half: + if self.dtype in [torch.half, torch.bfloat16]: for k, v in self.__dict__.items(): # The list comprehension is used for MoE tensor lists if isinstance(v, list) and all((isinstance(tensor, torch.Tensor) \ or isinstance(tensor, torch.nn.Parameter)) for tensor in v): - self.__dict__[k] = [moe_tensor.half() for moe_tensor in v] + self.__dict__[k] = [moe_tensor.to(self.dtype) for moe_tensor in v] if isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter): - self.__dict__[k] = v.half() + self.__dict__[k] = v.to(self.dtype) - def set_dtype(self, fp16=False): - self.fp16 = fp16 + def get_rotary_dim(self): + if hasattr(self.model_config, 'rotary_dim'): + return self.model_config.rotary_dim + if hasattr(self.child, 'attention') and hasattr(self.child.attention, 'rotary_ndims'): + return self.child.attention.rotary_ndims + return -1 def set_moe(self, moe=False): self.moe = moe @@ -140,12 +170,23 @@ def set_tensor_parallel_config(self, mp_size, mp_group): self.mp_size = mp_size self.mp_group = mp_group - def set_quantization_config(self, quantize, quantizer): - self.quantize = quantize + def set_quantization_config(self, quantizer): self.quantizer = quantizer - def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon): + def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon, intermediate_size): + """ + Args: + hidden_size: embedding dimension of the model + num_attention_heads: number of attention heads in the model + epsilon: epsilon value for layer norm (same value used for all norms) + intermediate_size: Size of MLP projection. If `DEFAULT_INTERMEDIATE_SIZE` is passed + it is assumed to be `4 * hidden_size` + """ self.hidden_size = hidden_size + if intermediate_size == DEFAULT_INTERMEDIATE_SIZE: + self.intermediate_size = 4 * hidden_size + else: + self.intermediate_size = intermediate_size self.num_attention_heads = num_attention_heads self.layernorm_epsilon = epsilon @@ -155,17 +196,6 @@ def set_attention(self, qkvw, qkvb, dense_w, dense_b): self.dense_w = dense_w self.dense_b = dense_b - def set_lora_params(self, lora_params): - self.lora_params = lora_params - - def set_q_k_v(self, qw, qb, kw, kb, vw, vb): - self.qw = qw - self.qb = qb - self.kw = kw - self.kb = kb - self.vw = vw - self.vb = vb - def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b): self._h4h_w = _h4h_w self._h4h_b = _h4h_b @@ -193,159 +223,63 @@ def mlp_quantization(self): self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w) self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w) - def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None): - reversed_dim = False - if mp_replace is None: - from deepspeed.module_inject import ReplaceWithTensorSlicing - mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=tp_size, out_dim=0, in_dim=1) - reversed_dim = True + def apply_tensor_parallelism(self, mp_replace): # setup the new Attention module - if self.module.attention.attn_qkvw is None: - self.attention_q_k_v_mp(mp_replace, reversed_dim=reversed_dim) - else: - self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim) - self.attention_o_mp(mp_replace, reversed_dim=reversed_dim) + self.attention_qkv_mp(mp_replace) + self.attention_o_mp(mp_replace) # setup the new MLP module - self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim) - self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim) + self.mlp_inter_mp(mp_replace) + self.mlp_output_mp(mp_replace) # Apply weight quantization + # TODO(cmikeh2): Re-enable this once verified #self.apply_weight_quantization() def attention_qkv_mp(self, mp_replace, reversed_dim=False): - self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw, - self.qkvw, - int8=reversed_dim) - self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb, - self.qkvb, - int8=reversed_dim) - - def attention_q_k_v_mp(self, mp_replace, reversed_dim=False): - self.module.attention.attn_qw = mp_replace.copy(self.module.attention.attn_qw[:self.qw.shape[0] // - mp_replace.mp_size], - self.qw, - int8=reversed_dim, - allocat_tensor=reversed_dim) - self.module.attention.attn_kw = mp_replace.copy(self.module.attention.attn_kw[:self.qw.shape[0] // - mp_replace.mp_size], - self.kw, - int8=reversed_dim, - allocat_tensor=reversed_dim) - self.module.attention.attn_vw = mp_replace.copy(self.module.attention.attn_vw[:self.qw.shape[0] // - mp_replace.mp_size], - self.vw, - int8=reversed_dim, - allocat_tensor=reversed_dim) - self.module.attention.attn_qb = mp_replace.copy(self.module.attention.attn_qb[:self.qw.shape[0] // - mp_replace.mp_size], - self.qb, - int8=reversed_dim, - allocat_tensor=reversed_dim) - self.module.attention.attn_kb = mp_replace.copy(self.module.attention.attn_kb[:self.qw.shape[0] // - mp_replace.mp_size], - self.kb, - int8=reversed_dim, - allocat_tensor=reversed_dim) - self.module.attention.attn_vb = mp_replace.copy(self.module.attention.attn_vb[:self.qw.shape[0] // - mp_replace.mp_size], - self.vb, - int8=reversed_dim, - allocat_tensor=reversed_dim) + self.module.attention.attn_qkvw = mp_replace.strided_copy(self.module.attention.attn_qkvw, + self.qkvw, + num_splits=3, + int8=reversed_dim) + self.module.attention.attn_qkvb = mp_replace.strided_copy(self.module.attention.attn_qkvb, + self.qkvb, + num_splits=3, + int8=reversed_dim) def attention_o_mp(self, mp_replace, reversed_dim=False): - if reversed_dim: - self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow[:, :self.dense_w.shape[1] // - mp_replace.mp_size], - self.dense_w, - int8=reversed_dim, - allocat_tensor=reversed_dim) - else: - self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, - self.dense_w, - int8=reversed_dim) + self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, self.dense_w, int8=reversed_dim) self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob, self.dense_b, int8=reversed_dim, - allocat_tensor=reversed_dim) + allocate_tensor=reversed_dim) def mlp_inter_mp(self, mp_replace, reversed_dim=False): - if reversed_dim: - self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w[:self._h4h_w.shape[0] // - mp_replace.mp_size], - self._h4h_w, - int8=reversed_dim, - allocat_tensor=reversed_dim) - self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b[:self._h4h_w.shape[0] // - mp_replace.mp_size], - self._h4h_b, - int8=reversed_dim, - allocat_tensor=reversed_dim) - else: - self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim) - self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim) + self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim) + self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim) def mlp_output_mp(self, mp_replace, reversed_dim=False): - if reversed_dim: - self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w[:, :self._4hh_w.shape[1] // - mp_replace.mp_size], - self._4hh_w, - int8=reversed_dim, - allocat_tensor=reversed_dim) - else: - self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim) + self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim) self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, self._4hh_b, int8=reversed_dim, - allocat_tensor=reversed_dim) - - def release_qkv(self): - del self.module.attention.attn_qkvw - del self.module.attention.attn_qkvb - self.module.attention.attn_qkvw = None - self.module.attention.attn_qkvb = None - - qkv_data = [self.module.attention.attn_qw.data, \ - self.module.attention.attn_qb.data, \ - self.module.attention.attn_kw.data, \ - self.module.attention.attn_kb.data, \ - self.module.attention.attn_vw.data, \ - self.module.attention.attn_vb.data] - for data in qkv_data: - del data - - self.module.attention.attn_qw = self.qw - self.module.attention.attn_qb = self.qb - self.module.attention.attn_kw = self.kw - self.module.attention.attn_kb = self.kb - self.module.attention.attn_vw = self.vw - self.module.attention.attn_vb = self.vb - - def release_memory(self): - self.release_qkv() - del self.module.attention.attn_ow - del self.module.attention.attn_ob - self.module.attention.attn_ow = self.dense_w - self.module.attention.attn_ob = self.dense_b - del self.module.mlp.inter_w - del self.module.mlp.inter_b - del self.module.mlp.output_w - del self.module.mlp.output_b - self.module.mlp.inter_w = self._h4h_w - self.module.mlp.inter_b = self._h4h_b - self.module.mlp.output_w = self._4hh_w - self.module.mlp.output_b = self._4hh_b + allocate_tensor=reversed_dim) def copy_data_to_new_module(self): - if self.attn_nw is None: - self.module.mlp.attn_nw = self.attn_nw - self.module.mlp.attn_nb = self.attn_nb - else: - self.module.mlp.attn_nw.data.copy_(self.attn_nw.to(get_accelerator().current_device_name())) - self.module.mlp.attn_nb.data.copy_(self.attn_nb.to(get_accelerator().current_device_name())) + params = {'attn_nw': self.attn_nw, 'attn_nb': self.attn_nb} + for key in params: + if params[key] is None: + setattr(self.module.mlp, key, None) + else: + setattr(self.module.mlp, key, + torch.nn.parameter.Parameter(params[key].to(get_accelerator().current_device_name()))) - self.module.norm_w.data.copy_(self.input_nw.to(get_accelerator().current_device_name())) - self.module.norm_b.data.copy_(self.input_nb.to(get_accelerator().current_device_name())) + params = {'norm_w': self.input_nw, 'norm_b': self.input_nb} + for key in params: + if params[key] is None: + setattr(self.module, key, None) + else: + setattr(self.module, key, + torch.nn.parameter.Parameter(params[key].to(get_accelerator().current_device_name()))) def transpose(self): self.transpose_attention() @@ -368,105 +302,21 @@ def transpose_impl(self, data): data.to(get_accelerator().current_device_name()) return data - def reset_qkv_experimental(self): - if self.module.attention.attn_qkvw is None: - self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3, - self.qw.shape[0], - dtype=self.qw.dtype, - device=self.qw.device) - self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3, - dtype=self.qw.dtype, - device=self.qw.device) - self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data - self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data - self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data - self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data - self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data - self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data - - qkv_data = [self.qw.data, \ - self.qb.data, \ - self.kw.data, \ - self.kb.data, \ - self.vw.data, \ - self.vb.data] - - self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]] - self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]] - self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] - self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] - self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] - self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] - - for data in qkv_data: - del data - - def reset_qkv(self): - self.qkvw.data[:self.qw.shape[0]] = self.qw.data - self.qkvb.data[:self.qw.shape[0]] = self.qb.data - self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data - self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data - self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data - self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data - - qkv_data = [self.qw.data, \ - self.qb.data, \ - self.kw.data, \ - self.kb.data, \ - self.vw.data, \ - self.vb.data] - - self.qw.data = self.qkvw.data[:self.qw.shape[0]] - self.qb.data = self.qkvb.data[:self.qw.shape[0]] - self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] - self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] - self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:] - self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:] - - for data in qkv_data: - del data - - def set_params_wo_copy(self, Z3_enabled=False): - self.module.mlp.attn_nw = self.attn_nw - self.module.mlp.attn_nb = self.attn_nb - self.module.norm_w = self.input_nw - self.module.norm_b = self.input_nb - self.module.mlp.inter_w = self._h4h_w - self.module.mlp.inter_b = self._h4h_b - self.module.mlp.output_w = self._4hh_w - self.module.mlp.output_b = self._4hh_b - self.module.attention.attn_ow = self.dense_w - self.module.attention.attn_ob = self.dense_b - if not Z3_enabled or self.q_k_v is None: - self.module.attention.attn_qkvw = self.qkvw - self.module.attention.attn_qkvb = self.qkvb - if self.q_k_v is not None: - if Z3_enabled: - self.module.attention.attn_qw = self.qw - self.module.attention.attn_qb = self.qb - self.module.attention.attn_kw = self.kw - self.module.attention.attn_kb = self.kb - self.module.attention.attn_vw = self.vw - self.module.attention.attn_vb = self.vb - else: - self.qw.data = self.qkvw[:self.qw.shape[0], :] - self.qb.data = self.qkvb[:self.qw.shape[0]] - self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :] - self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]] - self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :] - self.vb.data = self.qkvb[self.qw.shape[0] * 2:] + def get_all_params(self): + params = [ + self.attn_nw, + self.attn_nb, + self.input_nw, + self.input_nb, + ] - def get_lora_params(self): - return self.lora_params + params.extend(self.get_attn_params()) + params.extend(self.get_mlp_params()) - def get_all_params(self): - if self.q_k_v is not None: - return [ - self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w, - self._4hh_b, self.qw, self.qb, self.kw, self.kb, self.vw, self.vb, self.dense_w, self.dense_b - ] - else: - return [ - self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w, - self._4hh_b, self.qkvw, self.qkvb, self.dense_w, self.dense_b - ] + return params + + def get_attn_params(self): + return [self.qkvw, self.qkvb, self.dense_w, self.dense_b] + + def get_mlp_params(self): + return [self._h4h_w, self._h4h_b, self._4hh_w, self._4hh_b] diff --git a/deepspeed/module_inject/containers/bert.py b/deepspeed/module_inject/containers/bert.py index f8070655283e..2bb520e7449d 100644 --- a/deepspeed/module_inject/containers/bert.py +++ b/deepspeed/module_inject/containers/bert.py @@ -18,6 +18,7 @@ def __init__(self, **kwargs): # All model specific things should be defined here instead of the base class. self.return_tuple = True self.triangular_masking = False + self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON def create_module(self, config=None): _config = config if config is not None else self.ds_model_config @@ -40,7 +41,7 @@ def __init__(self, client_module, inference=False): transformers.models.bert.modeling_bert.BertLayer, transformers.models.roberta.modeling_roberta.RobertaLayer ] - except: + except Exception: HFBertLayerPolicy._orig_layer_class = None def get_hidden_heads(self): @@ -50,10 +51,8 @@ def get_hidden_heads(self): attention_layernorm = self.client_module.attention.output.LayerNorm return self.client_module.attention.self.query.weight.shape[1], \ self.client_module.attention.self.num_attention_heads, \ - attention_layernorm.eps - - def get_q_k_v(self): - return None + attention_layernorm.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): qw = self.client_module.attention.self.query.weight @@ -71,7 +70,7 @@ def attention(self, enable_training=False): self.client_module.attention.output.dense.weight, \ self.client_module.attention.output.dense.bias, \ - def mlp(self): + def mlp(self, enable_training=False): if self.pre_attn_norm: intermediate_ff = self.client_module.intermediate.dense_act else: @@ -92,6 +91,3 @@ def layernorm(self): attention_layernorm.bias, \ transformer_layernorm.weight, \ transformer_layernorm.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py index 7bcf6943de60..5ce96f023675 100644 --- a/deepspeed/module_inject/containers/bloom.py +++ b/deepspeed/module_inject/containers/bloom.py @@ -5,33 +5,68 @@ from .base import * from .features.meta_tensor import MetaTensorContainer +from .features.hybrid_engine import HybridEngineContainer from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference from ..policy import TransformerPolicy from ..policy import transformer_param_names from ..policy import maybe_copy +from ..policy import maybe_get_lora + supported_models = {None} -class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer): +class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer): def __init__(self, **kwargs): + # Check transformers version, error if > 4.43.4 (breaks at 4.44.0) + from importlib.metadata import version + v_transformers = version('transformers') + vers = v_transformers.split('.') + major = int(vers[0]) + minor = int(vers[1]) + if major > 4 or (major == 4 and minor > 43): + raise RuntimeError( + f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported." + ) + super().__init__(**kwargs) # All model specific things should be defined here instead of the base class. self.bigscience_bloom = True + self.triangular_masking = False def create_module(self, config=None): _config = config if config is not None else self.ds_model_config self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group) self.module.config.scale_attention = self.scale_attention + self.module.config.invert_mask = False return self.module def attention_qkv_mp(self, mp_replace, reversed_dim=False): self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw) self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb) + def get_lora_matched_pair(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params() + ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)] + return ret + + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.dense_h_to_4h, self.policy.client_module.mlp.dense_4h_to_h, self.policy. + client_module.self_attention.query_key_value, self.policy.client_module.self_attention.dense + ] + ] + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): param_names = ( 'self_attention.query_key_value.weight', \ @@ -85,10 +120,8 @@ def __init__(self, client_module, inference=True, use_load_prefix=True, split_qk def get_hidden_heads(self): return self.client_module.self_attention.hidden_size, \ self.client_module.self_attention.num_heads, \ - self.client_module.input_layernorm.eps - - def get_q_k_v(self): - return None + self.client_module.input_layernorm.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): return self.client_module.self_attention.query_key_value.weight, \ @@ -96,7 +129,7 @@ def attention(self, enable_training=False): self.client_module.self_attention.dense.weight, \ self.client_module.self_attention.dense.bias, - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.mlp.dense_h_to_4h.weight, \ self.client_module.mlp.dense_h_to_4h.bias, \ self.client_module.mlp.dense_4h_to_h.weight, \ @@ -107,6 +140,3 @@ def layernorm(self): self.client_module.post_attention_layernorm.bias, \ self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/clip.py b/deepspeed/module_inject/containers/clip.py index 144f1b823a1a..8cf42267caa8 100644 --- a/deepspeed/module_inject/containers/clip.py +++ b/deepspeed/module_inject/containers/clip.py @@ -35,18 +35,16 @@ def __init__(self, client_module, inference=False): try: import transformers HFCLIPLayerPolicy._orig_layer_class = transformers.models.clip.modeling_clip.CLIPEncoderLayer - except: + except Exception: HFCLIPLayerPolicy._orig_layer_class = None def get_hidden_heads(self): return self.client_module.self_attn.q_proj.weight.shape[1], \ self.client_module.self_attn.num_heads, \ - self.client_module.layer_norm1.eps + self.client_module.layer_norm1.eps, \ + DEFAULT_INTERMEDIATE_SIZE - def get_q_k_v(self): - return None - - def attention(self): + def attention(self, enable_training=False): qw = self.client_module.self_attn.q_proj.weight qb = self.client_module.self_attn.q_proj.bias kw = self.client_module.self_attn.k_proj.weight @@ -54,15 +52,15 @@ def attention(self): vw = self.client_module.self_attn.v_proj.weight vb = self.client_module.self_attn.v_proj.bias - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) - qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training) return qkvw, \ qkvb, \ self.client_module.self_attn.out_proj.weight, \ self.client_module.self_attn.out_proj.bias - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.mlp.fc1.weight, \ self.client_module.mlp.fc1.bias, \ self.client_module.mlp.fc2.weight, \ @@ -73,6 +71,3 @@ def layernorm(self): self.client_module.layer_norm2.bias, \ self.client_module.layer_norm1.weight, \ self.client_module.layer_norm1.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/distil_bert.py b/deepspeed/module_inject/containers/distil_bert.py index 792b965399e2..37107fd28506 100644 --- a/deepspeed/module_inject/containers/distil_bert.py +++ b/deepspeed/module_inject/containers/distil_bert.py @@ -18,6 +18,7 @@ def __init__(self, **kwargs): # All model specific things should be defined here instead of the base class. self.triangular_masking = False self.return_single_tuple = True + self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON def create_module(self, config=None): _config = config if config is not None else self.ds_model_config @@ -40,16 +41,14 @@ def __init__(self, client_module, inference=False, preln=False): HFDistilBertLayerPolicy._orig_layer_class = [ transformers.models.distilbert.modeling_distilbert.TransformerBlock, ] - except: + except Exception: HFDistilBertLayerPolicy._orig_layer_class = None def get_hidden_heads(self): return self.client_module.attention.q_lin.weight.shape[1], \ self.client_module.attention.n_heads, \ - self.client_module.sa_layer_norm.eps - - def get_q_k_v(self): - return None + self.client_module.sa_layer_norm.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): qw = self.client_module.attention.q_lin.weight @@ -67,7 +66,7 @@ def attention(self, enable_training=False): self.client_module.attention.out_lin.weight, \ self.client_module.attention.out_lin.bias - def mlp(self): + def mlp(self, enable_training=False): intermediate_ff = self.client_module.ffn.lin1 return intermediate_ff.weight, intermediate_ff.bias, \ @@ -81,6 +80,3 @@ def layernorm(self): attention_layernorm.bias, \ transformer_layernorm.weight, \ transformer_layernorm.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/features/__init__.py b/deepspeed/module_inject/containers/features/__init__.py index 9bf65591925d..fc2eb2a65531 100644 --- a/deepspeed/module_inject/containers/features/__init__.py +++ b/deepspeed/module_inject/containers/features/__init__.py @@ -3,5 +3,7 @@ # DeepSpeed Team +from .gated_mlp import HybridGatedMLPContainer from .megatron import MegatronContainer from .meta_tensor import MetaTensorContainer +from .split_qkv import HybridSplitQKVContainer diff --git a/deepspeed/module_inject/containers/features/gated_mlp.py b/deepspeed/module_inject/containers/features/gated_mlp.py new file mode 100644 index 000000000000..24f0826db14e --- /dev/null +++ b/deepspeed/module_inject/containers/features/gated_mlp.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod + +from .hybrid_engine import HybridEngineContainer + + +class HybridGatedMLPContainer(HybridEngineContainer): + """ + The HybridGatedMLPContainer supports models for which the first MLP layer + is represented with two separate weights, one for the activation function + and one for the gating function. + """ + + def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b): + super().set_mlp(_h4h_w, _h4h_b, _4hh_w, _4hh_b) + self.set_mlp_gate() + + @abstractmethod + def set_mlp_gate(self): + """ + In `set_mlp_gate`, it is necessary to populate the following variables (where appropriate) + for the given model: + self.inter_up_w: inter up weight + self.inter_up_b: inter up bias + self.inter_gate_w: inter gate weight + self.inter_gate_b: inter gate bias + If the parameter does not exist in the original model, set the attribute to None. + """ + raise NotImplementedError("A set_mlp_gate() function must be defined in the model container \ + in order to set the unfused inter up and gate tensors.") + + def mlp_inter_mp(self, mp_replace, reversed_dim=False): + # Only need to alter behavior if we can't do the normal destructive copy + if self.module.mlp.inter_w is None: + params = [ + (self.module.mlp.inter_up_w, self.inter_up_w), + (self.module.mlp.inter_up_b, self.inter_up_b), + (self.module.mlp.inter_gate_w, self.inter_gate_w), + (self.module.mlp.inter_gate_b, self.inter_gate_b), + ] + for dst, src in params: + dst = mp_replace.copy(dst[:self.inter_up_w.shape[0] // mp_replace.mp_size], + src, + int8=reversed_dim, + allocate_tensor=reversed_dim) if src is not None else None + else: + self.module.mlp.inter_w = mp_replace.strided_copy(self.module.mlp.inter_w, + self._h4h_w, + num_splits=2, + int8=reversed_dim) + self.module.mlp.inter_b = mp_replace.strided_copy(self.module.mlp.inter_b, + self._h4h_b, + num_splits=2, + int8=reversed_dim) + + def release_mlp(self): + super().release_mlp() + gated_mlp_params = [ + (self.module.mlp.inter_up_w, self.inter_up_w), + (self.module.mlp.inter_up_b, self.inter_up_b), + (self.module.mlp.inter_gate_w, self.inter_gate_w), + (self.module.mlp.inter_gate_b, self.inter_gate_b), + ] + + self._release_params(gated_mlp_params) + + def reset_mlp(self): + self._h4h_w.data[:self.inter_up_w.shape[0]] = self.inter_up_w.data + self._h4h_w.data[self.inter_up_w.shape[0]:] = self.inter_gate_w.data + + if self.inter_up_b is not None: + self._h4h_b.data[:self.inter_up_b.shape[0]] = self.inter_up_b.data + self._h4h_b.data[self.inter_up_b.shape[0]:] = self.inter_gate_b.data + + inter_data = [self.inter_up_w.data, self.inter_gate_w.data] + if self.inter_up_b is not None: + inter_data.extend([self.inter_up_b.data, self.inter_gate_b.data]) + + self.inter_up_w.data = self._h4h_w.data[:self.inter_up_w.shape[0]] + self.inter_gate_w.data = self._h4h_w.data[self.inter_up_w.shape[0]:] + + if self.inter_up_b is not None: + self.inter_up_b.data = self._h4h_b.data[:self.inter_up_b.shape[0]] + self.inter_gate_b.data = self._h4h_b.data[self.inter_up_b.shape[0]:] + + for data in inter_data: + del data + + def set_mlp_params_wo_copy(self, Z3_enabled=False): + self.module.mlp.output_w = self._4hh_w + self.module.mlp.output_b = self._4hh_b + + if not Z3_enabled: + # In initialize_tensors, we create a fused inter projection with the appropriate shape + # and copy the up projection and gate projection into it + self.module.mlp.inter_w = self._h4h_w + self.module.mlp.inter_b = self._h4h_b + + self.inter_up_w.data = self._h4h_w[:self.inter_up_w.shape[0], :] + self.inter_gate_w.data = self._h4h_w[self.inter_up_w.shape[0]:, :] + + if self.inter_up_b is not None: + self.inter_up_b.data = self._h4h_b[:self.inter_up_w.shape[0]] if self._h4h_b is not None else None + self.inter_gate_b.data = self._h4h_b[self.inter_up_w.shape[0]:] if self._h4h_b is not None else None + else: + self.module.mlp.inter_up_w = self.inter_up_w + self.module.mlp.inter_up_b = self.inter_up_b + self.module.mlp.inter_gate_w = self.inter_gate_w + self.module.mlp.inter_gate_b = self.inter_gate_b + + def get_mlp_params(self): + params = super().get_mlp_params() + params.extend([self.inter_up_w, self.inter_up_b, self.inter_gate_w, self.inter_gate_b]) + return params diff --git a/deepspeed/module_inject/containers/features/hybrid_engine.py b/deepspeed/module_inject/containers/features/hybrid_engine.py new file mode 100644 index 000000000000..3f702abcf319 --- /dev/null +++ b/deepspeed/module_inject/containers/features/hybrid_engine.py @@ -0,0 +1,212 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod +from typing import List, Tuple + +import torch + + +class HybridEngineContainer(ABC): + """ + This container identifies which methods need to be overridden in addition to + the base container to enable use in the RLHF pipeline. These methods are not + necessary for inference alone. + + NOTE: If you are using this feature with a container that + also inherits from `MetaTensorContainer`, ensure that `MetaTensorContainer` + is inherited before `HybridEngineContainer` in the class definition. + """ + + def initialize_tensors(self, enable_training=False): + """ + Same purposes as the base container, but also grabs the hooks for any LoRA + parameters. If it's necessary to override specific sub-components of the model, + it's best to augment the specific `set_[component]` itself rather than modifying + the `initialize_tensors` method. See the `HybridSplitQKVContainer` for an example. + """ + super().initialize_tensors(enable_training=enable_training) + self.set_lora_params() + + def transform_for_training(self): + """ + If the views on certain parameters are largely incompatible, it may be necessary to do + more substantial transformations to the parameters. This method should be overridden to + transform the inference format to what is necessary for training. + """ + pass + + def transform_for_inference(self): + """ + If the views on certain parameters are largely incompatible, it may be necessary to do + more substantial transformations to the parameters. This method should be overridden to + transform the training format to what is necessary for inference. + """ + pass + + @abstractmethod + def set_lora_params(self): + """ + If available, set the LoRA parameters for the module. An implementation + for this would iterate over all parameters of the model and use the `maybe_get_lora` helper + method to check if the parameter does in fact have any LoRA params. + """ + raise NotImplementedError("A set_lora_params() function must be defined for the relevant parameters.") + + @abstractmethod + def get_lora_matched_pair(self): + """Get the pair of lora params and its matched model parameters.""" + raise NotImplementedError("get_lora_matched_pair() must be defined for the relevant parameters.") + + def fuse_lora(self): + """Fuse the LoRA parameters for the inference mode.""" + for maybe_lora_param, param in self.get_lora_matched_pair(): + if len(maybe_lora_param) == 3: + lora_right_weight, \ + lora_left_weight, \ + lora_scaling = maybe_lora_param + param.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t()) + + def unfuse_lora(self): + """Unfuse the LoRA parameters for the training mode.""" + for maybe_lora_param, param in self.get_lora_matched_pair(): + if len(maybe_lora_param) == 3: + lora_right_weight, \ + lora_left_weight, \ + lora_scaling = maybe_lora_param + param.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t()) + + def apply_tensor_parallelism(self, mp_replace, reversed_dim=False): + """ + Add support for reversed dim in tensor parallelism. If necessary, override + the called methods to handle partitioned weights (i.e. if qkv is split, override + the `attention_qkv_mp` method). If the model component is not split, it should + be safe to use the default implementation. + """ + # Setup the new Attention module + self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim) + self.attention_o_mp(mp_replace, reversed_dim=reversed_dim) + + # Setup the new MLP module + self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim) + self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim) + + # Apply weight quantization + # TODO(cmikeh2): Re-enable this once verified + #self.apply_weight_quantization() + + def _release_params(self, param_pairs: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Helper for `release_[component]` methods. Accepts a list of tuples where the first + element is the module param that needs to be deleted, and the second is the reassignment + from the container. + """ + for module_param, container_param in param_pairs: + if module_param is not None: + del module_param + module_param = container_param + + def release_memory(self): + """ + Delete module parameters if they exist and point them back to the container. The primary + purpose of this is for TP-inference with ZeRO-3. In this scenario, we need to delete the + parameters we've created for inference to free their memory. + """ + general_params = [ + (self.module.attention.attn_ow, self.dense_w), + (self.module.attention.attn_ob, self.dense_b), + (self.module.mlp.attn_nw, self.attn_nw), + (self.module.mlp.attn_nb, self.attn_nb), + (self.module.norm_w, self.input_nw), + (self.module.norm_b, self.input_nb), + ] + + self._release_params(general_params) + + self.release_qkv() + self.release_mlp() + + def release_qkv(self): + """ + Release for QKV parameters (as well as any aliases). + """ + qkv_params = [ + (self.module.attention.attn_qkvw, self.qkvw), + (self.module.attention.attn_qkvb, self.qkvb), + ] + + self._release_params(qkv_params) + + def release_mlp(self): + """ + Release for MLP parameters (as well as any aliases). + """ + mlp_params = [ + (self.module.mlp.inter_w, self._h4h_w), + (self.module.mlp.inter_b, self._h4h_b), + (self.module.mlp.output_w, self._4hh_w), + (self.module.mlp.output_b, self._4hh_b), + ] + + self._release_params(mlp_params) + + def reset_params(self): + """ + The purpose of reset params is to get the weights from the FP16 training + copy of the model and copy to them to contiguous inference view. This only needs + to be performed when the container parameters cannot be used directly for inference. + """ + self.reset_qkv() + self.reset_mlp() + + def reset_qkv(self): + """ + Perform any necessary resets of the model parameters for the QKV components. + """ + pass + + def reset_mlp(self): + """ + Perform any necessary resets of the model parameters for the MLP components. + """ + pass + + def get_lora_params(self): + """ + Return a list of all parameters that would have LoRA for the module. + """ + if not hasattr(self, "lora_params"): + self.set_lora_params() + return self.lora_params + + def set_params_wo_copy(self, Z3_enabled=False): + """ + Rather than copying into, set the parameters directly. This is necessary to provide + an inexpensive (low-memory-overhead) view onto the FP16 forward weights. + """ + self.module.mlp.attn_nw = self.attn_nw + self.module.mlp.attn_nb = self.attn_nb + self.module.norm_w = self.input_nw + self.module.norm_b = self.input_nb + self.set_attn_params_wo_copy(Z3_enabled=Z3_enabled) + self.set_mlp_params_wo_copy(Z3_enabled=Z3_enabled) + + def set_attn_params_wo_copy(self, **kwargs): + """ + Narrower sub-method for finer grained overriding. + """ + self.module.attention.attn_ow = self.dense_w + self.module.attention.attn_ob = self.dense_b + self.module.attention.attn_qkvw = self.qkvw + self.module.attention.attn_qkvb = self.qkvb + + def set_mlp_params_wo_copy(self, **kwargs): + """ + Narrower sub-method for finer grained overriding. + """ + self.module.mlp.inter_w = self._h4h_w + self.module.mlp.inter_b = self._h4h_b + self.module.mlp.output_w = self._4hh_w + self.module.mlp.output_b = self._4hh_b diff --git a/deepspeed/module_inject/containers/features/hybrid_megatron.py b/deepspeed/module_inject/containers/features/hybrid_megatron.py new file mode 100644 index 000000000000..d40f2a6b57e8 --- /dev/null +++ b/deepspeed/module_inject/containers/features/hybrid_megatron.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .hybrid_engine import HybridEngineContainer +from .megatron import MegatronContainer + + +class HybridMegatronContainer(MegatronContainer, HybridEngineContainer): + + def _align_qkv(self, x: torch.Tensor): + """ + Internal helper for accepting the head-contiguous weight matrix and chunking + the query, key, and value components. + """ + attention_head_size = x.shape[0] // self.num_attention_heads + new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:] + x_1 = x.view(*new_x_shape) + div_dim = len(x_1.size()) - 2 if len(x.shape) == 2 else -1 + (q, k, v) = torch.split(x_1, (x_1.shape[div_dim] // 3), dim=div_dim) + if len(q.shape) > 2: + x.data.copy_( + torch.cat((q.reshape(-1, q.shape[-1]), k.reshape(-1, q.shape[-1]), v.reshape(-1, q.shape[-1])), + dim=0).reshape(x.shape)) + else: + x.data.copy_(torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)) + + def transform_for_inference(self) -> None: + """ + Overrides the HybridEngineContainer implementation. + + The alternative layout of the QKV matrix for Megatron is such that each head's Q, K, and V + are sequential in memory. This is different from the default layout in which all of the Qs + are sequential, followed by all of the Ks, and then all of the Vs. Here, we take the default + layout and transform it to the inference layout. + """ + if hasattr(self.qkvw, 'ds_id'): + from deepspeed.runtime.zero import GatheredParameters + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + param_list = [self.qkvw, self.qkvb] + non_active_params = [param for param in param_list if (hasattr(param, 'ds_id') and \ + param.ds_status == ZeroParamStatus.NOT_AVAILABLE)] + with GatheredParameters(non_active_params): + self._align_qkv(self.qkvw) + self._align_qkv(self.qkvb) + else: + self._align_qkv(self.qkvw) + self._align_qkv(self.qkvb) + + def _partition_qkv(self, x: torch.Tensor): + """ + Internal helper for taking contiguous QKV and partitioning it for contiguous + heads. + """ + q_k_v = torch.split(x, (x.shape[0] // 3), dim=0) + attention_head_size = q_k_v[0].shape[0] // self.num_attention_heads + new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:] + q, k, v = [data.view(*new_x_shape) for data in q_k_v] + if len(q.shape) > 2: + x.data.copy_(torch.cat((q, k, v), dim=-2).reshape(-1, q.shape[-1])) + else: + x.data.copy_(torch.cat((q, k, v), dim=-1).reshape(-1)) + + def transform_for_training(self): + """ + Overrides the HybridEngineContainer implementation. + + The alternative layout of the QKV matrix for Megatron is such that each head's Q, K, and V + are sequential in memory. This is different from the default layout in which all of the Qs + are sequential, followed by all of the Ks, and then all of the Vs. This function takes the inference format and reverts it back to the default format. + """ + # If parameter is distributed, handle gathering it + if hasattr(self.qkvw, 'ds_id'): + from deepspeed.runtime.zero import GatheredParameters + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + param_list = [self.qkvw, self.qkvb] + non_active_params = [param for param in param_list if (hasattr(param, 'ds_id') and \ + param.ds_status == ZeroParamStatus.NOT_AVAILABLE)] + with GatheredParameters(non_active_params): + self._partition_qkv(self.qkvw) + self._partition_qkv(self.qkvb) + else: + self._partition_qkv(self.qkvw) + self._partition_qkv(self.qkvb) diff --git a/deepspeed/module_inject/containers/features/megatron.py b/deepspeed/module_inject/containers/features/megatron.py index b223fb96231d..4daccf7d7c8d 100644 --- a/deepspeed/module_inject/containers/features/megatron.py +++ b/deepspeed/module_inject/containers/features/megatron.py @@ -13,7 +13,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.megatron_v2 = self.policy.is_megatron_v2 - def transpose_qkv_alignment(self, x): + def _align_qkv_transposed(self, x): attention_head_size = x.shape[-1] // self.num_attention_heads new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size) x_1 = x.view(*new_x_shape) @@ -27,5 +27,5 @@ def transpose_qkv_alignment(self, x): def transpose(self): super().transpose() if self.megatron_v2: - self.qkvw = torch.nn.parameter.Parameter(self.transpose_qkv_alignment(self.qkvw).contiguous()) - self.qkvb = torch.nn.parameter.Parameter(self.transpose_qkv_alignment(self.qkvb).contiguous()) + self.qkvw = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvw).contiguous()) + self.qkvb = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvb).contiguous()) diff --git a/deepspeed/module_inject/containers/features/meta_tensor.py b/deepspeed/module_inject/containers/features/meta_tensor.py index 7aa507ca2e44..57b136663be3 100644 --- a/deepspeed/module_inject/containers/features/meta_tensor.py +++ b/deepspeed/module_inject/containers/features/meta_tensor.py @@ -4,11 +4,20 @@ # DeepSpeed Team from abc import ABC, abstractmethod +from packaging import version as pkg_version +import torch class MetaTensorContainer(ABC): + """ + NOTE: If you are using this feature with a container that + also inherits from `HybridEngineContainer`, ensure that `MetaTensorContainer` + is inherited before `HybridEngineContainer` in the class definition. + """ def __init__(self, **kwargs): + if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__): + raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+") super().__init__(**kwargs) self.is_meta = False self.ckpt_load_enabled = True @@ -17,14 +26,14 @@ def initialize_tensors(self, enable_training=False): super().initialize_tensors(enable_training=enable_training) self.is_meta = self.qkvw.is_meta - def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None): + def apply_tensor_parallelism(self, mp_replace, **kwargs): if self.is_meta: if self.qkvb is None: self.module.attention.attn_qkvb = None if self.dense_b is None: self.module.attention.attn_ob = None else: - super().apply_tensor_parallelism(mp_replace, mp_group, tp_size) + super().apply_tensor_parallelism(mp_replace, **kwargs) def copy_data_to_new_module(self): if self.is_meta: @@ -43,7 +52,7 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): """ Load all the transformer parameter from the checkpoint file (sd). In addition to the parameter names, we require two - more parameters to help read the the data correctly + more parameters to help read the data correctly from the checkpoint and split the qkv heads in the right order: 1. `use_load_prefix` (Default: False): this specifies @@ -51,7 +60,7 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): layer of the model for searching the parameter's name in a checkpoint file. For more information of how this is used please see - https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py + https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py 2. `split_qkv` (Default: True): we use this flag when splitting the qkv parameter into heads. If it is False, it means the heads of q, k, and v are stored together and needs to split in the diff --git a/deepspeed/module_inject/containers/features/split_qkv.py b/deepspeed/module_inject/containers/features/split_qkv.py new file mode 100644 index 000000000000..f4c14d4e425a --- /dev/null +++ b/deepspeed/module_inject/containers/features/split_qkv.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +import torch + +from .hybrid_engine import HybridEngineContainer + + +class HybridSplitQKVContainer(HybridEngineContainer): + + def set_attention(self, qkvw, qkvb, dense_w, dense_b): + super().set_attention(qkvw, qkvb, dense_w, dense_b) + self.set_q_k_v() + + @abstractmethod + def set_q_k_v(self): + """ + In `set_q_k_v`, it is necessary to populate the following variables (where appropriate) + for the given model: + self.qw: q weight + self.qb: q bias + self.kw: k weight + self.kb: k bias + self.vw: v weight + self.vb: v bias + """ + raise NotImplementedError("A set_q_k_v() function must be defined in the model container \ + in order to set the unfused q, k, and v tensors.") + + def attention_qkv_mp(self, mp_replace, reversed_dim=False): + # Only need to alter + if self.module.attention.attn_qkvw is None: + params = [ + (self.module.attention.attn_qw, self.qw), + (self.module.attention.attn_qb, self.qb), + (self.module.attention.attn_kw, self.kw), + (self.module.attention.attn_kb, self.kb), + (self.module.attention.attn_vw, self.vw), + (self.module.attention.attn_vb, self.vb), + ] + for dst, src in params: + dst = mp_replace.copy( + dst[:self.qw.shape[0] // mp_replace.mp_size], src, int8=reversed_dim, + allocate_tensor=reversed_dim) if src is not None else None + else: + super().attention_qkv_mp(mp_replace) + + def release_qkv(self): + super().release_qkv() + split_qkv_params = [ + (self.module.attention.attn_qw, self.qw), + (self.module.attention.attn_qb, self.qb), + (self.module.attention.attn_kw, self.kw), + (self.module.attention.attn_kb, self.kb), + (self.module.attention.attn_vw, self.vw), + (self.module.attention.attn_vb, self.vb), + ] + + self._release_params(split_qkv_params) + + def reset_qkv(self): + self.qkvw.data[:self.qw.shape[0]] = self.qw.data + self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data + self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data + + qkv_data = [self.qw.data, self.kw.data, self.vw.data] + + self.qw.data = self.qkvw.data[:self.qw.shape[0]] + self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:] + + if self.qkvb is not None: + self.qkvb.data[:self.qw.shape[0]] = self.qb.data + self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data + self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data + + qkv_data.extend([self.qb.data, self.kb.data, self.vb.data]) + + self.qb.data = self.qkvb.data[:self.qw.shape[0]] + self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:] + + for data in qkv_data: + del data + + def reset_qkv_experimental(self): + """ + WIP - experimental and likely to be changed/improved. + Unused by keeping for now. + """ + if self.module.attention.attn_qkvw is None: + self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3, + self.qw.shape[0], + dtype=self.qw.dtype, + device=self.qw.device) + self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3, + dtype=self.qw.dtype, + device=self.qw.device) + self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data + self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data + self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data + self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data + self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data + self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data + + qkv_data = [self.qw.data, \ + self.qb.data, \ + self.kw.data, \ + self.kb.data, \ + self.vw.data, \ + self.vb.data] + + self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]] + self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]] + self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] + self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] + + for data in qkv_data: + del data + + def set_attn_params_wo_copy(self, Z3_enabled=False): + self.module.attention.attn_ow = self.dense_w + self.module.attention.attn_ob = self.dense_b + if not Z3_enabled: + # In initialize_tensors, we create a fused qkvw with the appropriate shape + # and copy the qw, qb, kw, kb, vw, vb into it + self.module.attention.attn_qkvw = self.qkvw + self.module.attention.attn_qkvb = self.qkvb + + # We reset the data for qw (which is the original model parameter) to point + # to the fused weight matrix we have created here + self.qw.data = self.qkvw[:self.qw.shape[0], :] + self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :] + self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :] + + # Assume if one of the biases is not None, then all of them are not None + if self.qb is not None: + self.qb.data = self.qkvb[:self.qw.shape[0]] + self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vb.data = self.qkvb[self.qw.shape[0] * 2:] + else: + # In ZeRO-3 this will be managed by ZeRO and handled separately in the + # forward of ds_attention + self.module.attention.attn_qw = self.qw + self.module.attention.attn_qb = self.qb + self.module.attention.attn_kw = self.kw + self.module.attention.attn_kb = self.kb + self.module.attention.attn_vw = self.vw + self.module.attention.attn_vb = self.vb + + def get_attn_params(self): + params = super().get_attn_params() + params.extend([self.qw, self.qb, self.kw, self.kb, self.vw, self.vb]) + return params diff --git a/deepspeed/module_inject/containers/gpt2.py b/deepspeed/module_inject/containers/gpt2.py index 3f6373897c58..f887e4dd67b0 100644 --- a/deepspeed/module_inject/containers/gpt2.py +++ b/deepspeed/module_inject/containers/gpt2.py @@ -32,16 +32,14 @@ def __init__(self, client_module, inference=True): try: import transformers HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block - except: + except Exception: HFGPT2LayerPolicy._orig_layer_class = None def get_hidden_heads(self): return self.client_module.attn.embed_dim, \ self.client_module.attn.num_heads, \ - self.client_module.ln_1.eps - - def get_q_k_v(self): - return None + self.client_module.ln_1.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): return self.client_module.attn.c_attn.weight, \ @@ -49,7 +47,7 @@ def attention(self, enable_training=False): self.client_module.attn.c_proj.weight, \ self.client_module.attn.c_proj.bias - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.mlp.c_fc.weight, \ self.client_module.mlp.c_fc.bias, \ self.client_module.mlp.c_proj.weight, \ @@ -60,6 +58,3 @@ def layernorm(self): self.client_module.ln_2.bias, \ self.client_module.ln_1.weight, \ self.client_module.ln_1.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/gptj.py b/deepspeed/module_inject/containers/gptj.py index e7883105dde9..1beb1616db53 100644 --- a/deepspeed/module_inject/containers/gptj.py +++ b/deepspeed/module_inject/containers/gptj.py @@ -5,6 +5,7 @@ from .base import * from .features.meta_tensor import MetaTensorContainer +from .features.split_qkv import HybridSplitQKVContainer from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference import torch from torch.nn.parameter import Parameter @@ -13,8 +14,10 @@ from ..policy import maybe_copy from ..policy import maybe_copy_qkv +from ..policy import maybe_get_lora -class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer): + +class DS_GPTJContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -27,6 +30,35 @@ def create_module(self, config=None): self.module.config.scale_attention = self.scale_attention return self.module + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.fc_in, self.policy.client_module.mlp.fc_out, + self.policy.client_module.attn.q_proj, self.policy.client_module.attn.k_proj, + self.policy.client_module.attn.v_proj, self.policy.client_module.attn.out_proj + ] + ] + + def get_lora_matched_pair(self): + fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params() + ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw), + (k_lora, self.kw), (v_lora, self.vw)] + return ret + + def set_q_k_v(self): + """ + Necessary to implement for `HybridSplitQKVContainer` + """ + self.qw = self.policy.client_module.attn.q_proj.weight + self.qb = None + self.kw = self.policy.client_module.attn.k_proj.weight + self.kb = None + self.vw = self.policy.client_module.attn.v_proj.weight + self.vb = None + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): param_names = ( 'attn.q_proj.weight', \ @@ -66,16 +98,14 @@ def __init__(self, client_module, inference=True): try: import transformers HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock - except: + except Exception: HFGPTJLayerPolicy._orig_layer_class = None def get_hidden_heads(self): - return self.client_module.attn.q_proj.weight.shape[1], \ + return self.client_module.attn.embed_dim, \ self.client_module.attn.num_attention_heads, \ - self.client_module.ln_1.eps - - def get_q_k_v(self): - return None + self.client_module.ln_1.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): qw = self.client_module.attn.q_proj.weight @@ -89,7 +119,7 @@ def attention(self, enable_training=False): self.client_module.attn.out_proj.weight, \ None, - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.mlp.fc_in.weight, \ self.client_module.mlp.fc_in.bias, \ self.client_module.mlp.fc_out.weight, \ @@ -100,6 +130,3 @@ def layernorm(self): None, \ self.client_module.ln_1.weight, \ self.client_module.ln_1.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/gptneo.py b/deepspeed/module_inject/containers/gptneo.py index b9261b8c0b3b..6f8c0f13a1a4 100644 --- a/deepspeed/module_inject/containers/gptneo.py +++ b/deepspeed/module_inject/containers/gptneo.py @@ -5,6 +5,7 @@ from .base import * from .features.meta_tensor import MetaTensorContainer +from .features.split_qkv import HybridSplitQKVContainer from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference import torch from torch.nn.parameter import Parameter @@ -13,8 +14,10 @@ from ..policy import maybe_copy from ..policy import maybe_copy_qkv +from ..policy import maybe_get_lora -class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer): + +class DS_GPTNEOContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -27,6 +30,38 @@ def create_module(self, config=None): self.module.config.scale_attention = self.scale_attention return self.module + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.c_fc, self.policy.client_module.mlp.c_proj, + self.policy.client_module.attn.attention.q_proj, self.policy.client_module.attn.attention.k_proj, + self.policy.client_module.attn.attention.v_proj, self.policy.client_module.attn.attention.out_proj + ] + ] + + def set_q_k_v(self): + """ + Necessary to implement for `HybridSplitQKVContainer` + """ + self.qw = self.policy.client_module.attn.attention.q_proj.weight + self.qb = None + self.kw = self.policy.client_module.attn.attention.k_proj.weight + self.kb = None + self.vw = self.policy.client_module.attn.attention.v_proj.weight + self.vb = None + + def get_lora_matched_pair(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params() + ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw), + (k_lora, self.kw), (v_lora, self.vw)] + return ret + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): param_names = ( 'attn.attention.q_proj.weight', \ @@ -68,16 +103,22 @@ def __init__(self, client_module, inference=True): try: import transformers HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock - except: + except Exception: HFGPTNEOLayerPolicy._orig_layer_class = None def get_hidden_heads(self): - return self.client_module.attn.attention.q_proj.weight.shape[1], \ + return self.client_module.attn.attention.embed_dim, \ self.client_module.attn.attention.num_heads, \ - self.client_module.ln_1.eps + self.client_module.ln_1.eps, \ + DEFAULT_INTERMEDIATE_SIZE def get_q_k_v(self): - return None + return self.client_module.attn.attention.q_proj.weight, \ + None, \ + self.client_module.attn.attention.k_proj.weight, \ + None, \ + self.client_module.attn.attention.v_proj.weight, \ + None def attention(self, enable_training=False): qw = self.client_module.attn.attention.q_proj.weight @@ -91,7 +132,7 @@ def attention(self, enable_training=False): self.client_module.attn.attention.out_proj.weight, \ self.client_module.attn.attention.out_proj.bias - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.mlp.c_fc.weight, \ self.client_module.mlp.c_fc.bias, \ self.client_module.mlp.c_proj.weight, \ @@ -102,6 +143,3 @@ def layernorm(self): self.client_module.ln_2.bias, \ self.client_module.ln_1.weight, \ self.client_module.ln_1.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/gptneox.py b/deepspeed/module_inject/containers/gptneox.py index 61fd2b14dab0..3978b166b329 100644 --- a/deepspeed/module_inject/containers/gptneox.py +++ b/deepspeed/module_inject/containers/gptneox.py @@ -5,7 +5,7 @@ from .base import * from .features.meta_tensor import MetaTensorContainer -from .features.megatron import MegatronContainer +from .features.hybrid_megatron import HybridMegatronContainer from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference import torch from ..policy import TransformerPolicy @@ -13,8 +13,10 @@ from ..policy import maybe_copy from packaging import version as pkg_version +from ..policy import maybe_get_lora -class DS_GPTNEOXContainer(MetaTensorContainer, MegatronContainer, BaseTransformerContainer): + +class DS_GPTNEOXContainer(MetaTensorContainer, HybridMegatronContainer, BaseTransformerContainer): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -32,6 +34,30 @@ def create_module(self, config=None): return self.module + def get_lora_matched_pair(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params() + ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)] + return ret + + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + if GPTNEOXLayerPolicy.version == 0: + attention = self.policy.client_module.attention + else: + attention = self.policy.client_module.self_attention + + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.dense_h_to_4h, self.policy.client_module.mlp.dense_4h_to_h, + attention.query_key_value, attention.dense + ] + ] + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): param_names = ( 'attention.query_key_value.weight', \ @@ -91,12 +117,10 @@ def get_hidden_heads(self): else: attention = self.client_module.self_attention - return self.client_module.attention.query_key_value.weight.shape[1], \ + return self.client_module.attention.hidden_size, \ self.client_module.attention.num_attention_heads, \ - self.client_module.input_layernorm.eps - - def get_q_k_v(self): - return None + self.client_module.input_layernorm.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): if GPTNEOXLayerPolicy.version == 0: @@ -109,7 +133,7 @@ def attention(self, enable_training=False): attention.dense.weight, \ attention.dense.bias - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.mlp.dense_h_to_4h.weight, \ self.client_module.mlp.dense_h_to_4h.bias, \ self.client_module.mlp.dense_4h_to_h.weight, \ @@ -120,6 +144,3 @@ def layernorm(self): self.client_module.post_attention_layernorm.bias, \ self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/internlm.py b/deepspeed/module_inject/containers/internlm.py new file mode 100644 index 000000000000..31255d4b3ca5 --- /dev/null +++ b/deepspeed/module_inject/containers/internlm.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import importlib + +import torch +from torch.nn.parameter import Parameter + +from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference +from deepspeed.utils.types import ActivationFuncType, NormType + +from ..policy import (TransformerPolicy, maybe_copy, maybe_copy_geglu, maybe_copy_qkv, maybe_get_lora, + transformer_param_names) +from .base import * +from .features import HybridGatedMLPContainer, HybridSplitQKVContainer + + +class DS_InternLMContainer(HybridGatedMLPContainer, HybridSplitQKVContainer, BaseTransformerContainer): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # All model specific things should be defined here instead of the base class. + + def create_module(self, config=None): + _config = config if config is not None else self.ds_model_config + + _config.rotate_half = True + _config.rotate_every_two = False + _config.rotary_dim = self.hidden_size // self.num_attention_heads + self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group) + + return self.module + + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.up_proj.weight, self.policy.client_module.mlp.gate_proj.weight, + self.policy.client_module.mlp.down_proj.weight, self.policy.client_module.self_attn.q_proj.weight, + self.policy.client_module.self_attn.k_proj.weight, self.policy.client_module.self_attn.v_proj.weight, + self.policy.client_module.self_attn.o_proj.weight + ] + ] + + def get_lora_matched_pair(self): + up_proj_lora, gate_proj_lora, down_proj_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params() + ret = [(up_proj_lora, self.inter_up_w), (gate_proj_lora, self.inter_gate_w), (down_proj_lora, self._4hh_w), + (out_lora, self.dense_w), (q_lora, self.qw), (k_lora, self.kw), (v_lora, self.vw)] + return ret + + def set_q_k_v(self): + """ + Necessary to implement for `HybridSplitQKVContainer` + """ + self.qw = self.policy.client_module.self_attn.q_proj.weight + self.qb = self.policy.client_module.self_attn.q_proj.bias + self.kw = self.policy.client_module.self_attn.k_proj.weight + self.kb = self.policy.client_module.self_attn.k_proj.bias + self.vw = self.policy.client_module.self_attn.v_proj.weight + self.vb = self.policy.client_module.self_attn.v_proj.bias + + def set_mlp_gate(self): + """ + Necessary to implement for `HybridGatedMLPContainer` + """ + self.inter_up_w = self.policy.client_module.mlp.up_proj.weight + self.inter_up_b = None + self.inter_gate_w = self.policy.client_module.mlp.gate_proj.weight + self.inter_gate_b = None + + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): + param_names = ( + 'self_attn.q_proj.weight', \ + 'self_attn.k_proj.weight', \ + 'self_attn.v_proj.weight', \ + 'self_attn.o_proj.weight', \ + 'mlp.up_proj.weight', \ + 'mlp.gate_proj.weight', \ + 'mlp.down_proj.weight', \ + 'input_layernorm.weight', \ + 'post_attention_layernorm.weight' + 'self_attn.q_proj.bias', \ + 'self_attn.k_proj.bias', \ + 'self_attn.v_proj.bias', \ + 'self_attn.o_proj.bias', \ + ) + + maybe_copy_qkv(module.attention, + sd, + weight_quantizer, + mp_replace, + 'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]], + split_qkv=self.policy.split_qkv) + maybe_copy_qkv(module.attention, + sd, + weight_quantizer, + mp_replace, + 'attn_qkvb', [prefix + param_names[9], prefix + param_names[10], prefix + param_names[11]], + split_qkv=self.policy.split_qkv) + maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[2], + prefix + param_names[3]) + maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[3], + prefix + param_names[12]) + maybe_copy_geglu(module.mlp, sd, weight_quantizer, mp_replace, 'inter_w', + [prefix + param_names[4], prefix + param_names[5]]) + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, 'output_w', prefix + param_names[6]) + + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7]) + maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8]) + + +class InternLMLayerPolicy(TransformerPolicy): + _orig_layer_class = [] + _orig_layer_class_inited = False + + def __init__(self, client_module, inference=True): + super().__init__( + inference, + mlp_act_func_type=ActivationFuncType.GATED_SILU, + norm_type=NormType.RMSNorm, + ) + self.client_module = client_module + + self._init_orig_layer_class_once() + + def _init_orig_layer_class_once(self): + if InternLMLayerPolicy._orig_layer_class_inited: + return + + for sub_pkg in ['', '.internlm-7b', '.internlm-chat-7b']: + try: + from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME + module = importlib.import_module(f"{TRANSFORMERS_DYNAMIC_MODULE_NAME}{sub_pkg}.modeling_internlm") + if module.InternLMDecoderLayer not in InternLMLayerPolicy._orig_layer_class: + InternLMLayerPolicy._orig_layer_class.append(module.InternLMDecoderLayer) + except ImportError: + continue + + InternLMLayerPolicy._orig_layer_class_inited = True + + def get_hidden_heads(self): + return self.client_module.self_attn.q_proj.weight.shape[1], \ + self.client_module.self_attn.num_heads, \ + self.client_module.input_layernorm.variance_epsilon, \ + self.client_module.mlp.gate_proj.weight.shape[0] + + def attention(self, enable_training=False): + qw = self.client_module.self_attn.q_proj.weight + kw = self.client_module.self_attn.k_proj.weight + vw = self.client_module.self_attn.v_proj.weight + qb = self.client_module.self_attn.q_proj.bias + kb = self.client_module.self_attn.k_proj.bias + vb = self.client_module.self_attn.v_proj.bias + + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training) + + return qkvw, \ + qkvb, \ + self.client_module.self_attn.o_proj.weight, \ + self.client_module.self_attn.o_proj.bias + + def mlp(self, enable_training=False): + mlp1_up = self.client_module.mlp.up_proj.weight + mlp1_gate = self.client_module.mlp.gate_proj.weight + mlp2 = self.client_module.mlp.down_proj.weight + + mlp1 = Parameter(torch.cat((mlp1_up, mlp1_gate), dim=0), requires_grad=enable_training) + + return mlp1, None, mlp2, None + + def layernorm(self): + return self.client_module.post_attention_layernorm.weight, \ + None, \ + self.client_module.input_layernorm.weight, \ + None diff --git a/deepspeed/module_inject/containers/llama.py b/deepspeed/module_inject/containers/llama.py new file mode 100644 index 000000000000..b63b6ec6cbfa --- /dev/null +++ b/deepspeed/module_inject/containers/llama.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .base import * +from .features import HybridSplitQKVContainer, HybridGatedMLPContainer, MetaTensorContainer +from deepspeed.utils.types import ActivationFuncType, NormType +from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference +import torch +from torch.nn.parameter import Parameter + +from ..policy import ( + TransformerPolicy, + transformer_param_names, + maybe_copy, + maybe_copy_qkv, + maybe_copy_geglu, + maybe_get_lora, +) + + +class DS_LLAMAContainer(MetaTensorContainer, HybridGatedMLPContainer, HybridSplitQKVContainer, + BaseTransformerContainer): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # All model specific things should be defined here instead of the base class. + + def create_module(self, config=None): + _config = config if config is not None else self.ds_model_config + + _config.rotate_half = True + _config.rotate_every_two = False + _config.rotary_dim = self.hidden_size // self.num_attention_heads + if hasattr(self.policy.client_module.self_attn, 'config'): + _config.rope_theta = self.policy.client_module.self_attn.config.rope_theta + else: + _config.rope_theta = self.policy.client_module.self_attn.rope_theta + self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group) + + return self.module + + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.mlp.up_proj.weight, self.policy.client_module.mlp.gate_proj.weight, + self.policy.client_module.mlp.down_proj.weight, self.policy.client_module.self_attn.q_proj.weight, + self.policy.client_module.self_attn.k_proj.weight, self.policy.client_module.self_attn.v_proj.weight, + self.policy.client_module.self_attn.o_proj.weight + ] + ] + + def get_lora_matched_pair(self): + up_proj_lora, gate_proj_lora, down_proj_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params() + ret = [(up_proj_lora, self.inter_up_w), (gate_proj_lora, self.inter_gate_w), (down_proj_lora, self._4hh_w), + (out_lora, self.dense_w), (q_lora, self.qw), (k_lora, self.kw), (v_lora, self.vw)] + return ret + + def set_q_k_v(self): + """ + Necessary to implement for `HybridSplitQKVContainer` + """ + self.qw = self.policy.client_module.self_attn.q_proj.weight + self.qb = None + self.kw = self.policy.client_module.self_attn.k_proj.weight + self.kb = None + self.vw = self.policy.client_module.self_attn.v_proj.weight + self.vb = None + + def set_mlp_gate(self): + """ + Necessary to implement for `HybridGatedMLPContainer` + """ + self.inter_up_w = self.policy.client_module.mlp.up_proj.weight + self.inter_up_b = None + self.inter_gate_w = self.policy.client_module.mlp.gate_proj.weight + self.inter_gate_b = None + + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): + param_names = ( + 'self_attn.q_proj.weight', \ + 'self_attn.k_proj.weight', \ + 'self_attn.v_proj.weight', \ + 'self_attn.o_proj.weight', \ + 'mlp.up_proj.weight', \ + 'mlp.gate_proj.weight', \ + 'mlp.down_proj.weight', \ + 'post_attention_layernorm.weight', \ + 'input_layernorm.weight', + ) + + maybe_copy_qkv(module.attention, + sd, + weight_quantizer, + mp_replace, + 'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]], + split_qkv=self.policy.split_qkv) + for i in range(3, 4): + maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1], + prefix + param_names[i]) + maybe_copy_geglu(module.mlp, sd, weight_quantizer, mp_replace, 'inter_w', + [prefix + param_names[4], prefix + param_names[5]]) + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, 'output_w', prefix + param_names[6]) + + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7]) + maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8]) + + # This line is necessary for proper output when kernels + meta tensors are used in Llama models + # TODO: Investigate root-cause and fix meta tensor loading + module.mlp.output_b = None + + +class LLAMALayerPolicy(TransformerPolicy): + + def __init__(self, client_module, inference=True): + super().__init__( + inference, + mlp_act_func_type=ActivationFuncType.GATED_SILU, + norm_type=NormType.RMSNorm, + ) + self.client_module = client_module + try: + import transformers + LLAMALayerPolicy._orig_layer_class = transformers.models.llama.modeling_llama.LlamaDecoderLayer # type: ignore + except Exception: + LLAMALayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + if hasattr(self.client_module.self_attn, 'config'): + num_heads = self.client_module.self_attn.config.num_attention_heads + else: + num_heads = self.client_module.self_attn.num_heads + hidden_heads = ( + self.client_module.self_attn.q_proj.in_features, + num_heads, + self.client_module.input_layernorm.variance_epsilon, + self.client_module.mlp.gate_proj.out_features, + ) + return hidden_heads + + def attention(self, enable_training=False): + qw = self.client_module.self_attn.q_proj.weight + kw = self.client_module.self_attn.k_proj.weight + vw = self.client_module.self_attn.v_proj.weight + + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + + return qkvw, \ + None, \ + self.client_module.self_attn.o_proj.weight, \ + None + + def mlp(self, enable_training=False): + mlp1_up = self.client_module.mlp.up_proj.weight + mlp1_gate = self.client_module.mlp.gate_proj.weight + mlp2 = self.client_module.mlp.down_proj.weight + + mlp1 = Parameter(torch.cat((mlp1_up, mlp1_gate), dim=0), requires_grad=enable_training) + + return mlp1, None, mlp2, None + + def layernorm(self): + return self.client_module.post_attention_layernorm.weight, \ + None, \ + self.client_module.input_layernorm.weight, \ + None diff --git a/deepspeed/module_inject/containers/llama2.py b/deepspeed/module_inject/containers/llama2.py new file mode 100644 index 000000000000..3b376b4017bb --- /dev/null +++ b/deepspeed/module_inject/containers/llama2.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .base import * +from .features import HybridSplitQKVContainer, HybridGatedMLPContainer, MetaTensorContainer +from deepspeed.utils.types import ActivationFuncType, NormType +from deepspeed.model_implementations.transformers.ds_llama2 import DeepSpeedLlama2Inference +import torch +from torch.nn.parameter import Parameter + +from ..policy import ( + TransformerPolicy, + transformer_param_names, + maybe_copy, + maybe_copy_qkv, + maybe_copy_geglu, + maybe_get_lora, +) + + +class DS_LLAMA2Container(MetaTensorContainer, HybridGatedMLPContainer, HybridSplitQKVContainer, + BaseTransformerContainer): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # All model specific things should be defined here instead of the base class. + + def create_module(self, config=None): + _config = config if config is not None else self.ds_model_config + + _config.rotate_half = False + _config.rotate_every_two = True + _config.rotary_dim = self.hidden_size // self.num_attention_heads + _config.num_kv = self.policy.client_module.attention.n_kv_heads + self.module = DeepSpeedLlama2Inference(_config, mp_group=self.mp_group) + + return self.module + + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.feed_forward.w3.weight, self.policy.client_module.feed_forward.w1.weight, + self.policy.client_module.feed_forward.w2.weight, self.policy.client_module.attention.wq.weight, + self.policy.client_module.attention.wk.weight, self.policy.client_module.attention.wv.weight, + self.policy.client_module.attention.wo.weight + ] + ] + + def get_lora_matched_pair(self): + up_proj_lora, gate_proj_lora, down_proj_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params() + ret = [(up_proj_lora, self.inter_up_w), (gate_proj_lora, self.inter_gate_w), (down_proj_lora, self._4hh_w), + (out_lora, self.dense_w), (q_lora, self.qw), (k_lora, self.kw), (v_lora, self.vw)] + return ret + + def set_q_k_v(self): + """ + Necessary to implement for `HybridSplitQKVContainer` + """ + self.qw = self.policy.client_module.attention.wq.weight + self.qb = None + self.kw = self.policy.client_module.attention.wk.weight + self.kb = None + self.vw = self.policy.client_module.attention.wv.weight + self.vb = None + + def set_mlp_gate(self): + """ + Necessary to implement for `HybridGatedMLPContainer` + """ + self.inter_up_w = self.policy.client_module.feed_forward.w2.weight + self.inter_up_b = None + self.inter_gate_w = self.policy.client_module.feed_forward.w1.weight + self.inter_gate_b = None + + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): + param_names = ( + 'attention.wq.weight', \ + 'attention.wk.weight', \ + 'attention.wv.weight', \ + 'attention.wo.weight', \ + 'feed_forward.w3.weight', \ + 'feed_forward.w1.weight', \ + 'feed_forward.w2.weight', \ + 'ffn_norm.weight', \ + 'attention_norm.weight' + ) + + maybe_copy_qkv(module.attention, + sd, + weight_quantizer, + mp_replace, + 'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]], + split_qkv=self.policy.split_qkv) + for i in range(3, 4): + maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1], + prefix + param_names[i]) + maybe_copy_geglu(module.mlp, sd, weight_quantizer, mp_replace, 'inter_w', + [prefix + param_names[4], prefix + param_names[5]]) + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, 'output_w', prefix + param_names[6]) + + maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7]) + maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8]) + + +class LLAMA2LayerPolicy(TransformerPolicy): + + def __init__(self, client_module, inference=True): + super().__init__( + inference, + mlp_act_func_type=ActivationFuncType.GATED_SILU, + norm_type=NormType.RMSNorm, + ) + self.client_module = client_module + try: + import llama + LLAMA2LayerPolicy._orig_layer_class = llama.model.TransformerBlock # type: ignore + except Exception: + LLAMA2LayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attention.wq.weight.shape[1], \ + self.client_module.n_heads, \ + self.client_module.ffn_norm.eps, \ + (self.client_module.feed_forward.w1.weight.shape[0] * \ + deepspeed.comm.get_world_size() if deepspeed.comm.is_initialized() else 1) # this is a hack to inject when model is already partitioned! + + def attention(self, enable_training=False): + qw = self.client_module.attention.wq.weight + kw = self.client_module.attention.wk.weight + vw = self.client_module.attention.wv.weight + + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + + return qkvw, \ + None, \ + self.client_module.attention.wo.weight, \ + None + + def mlp(self, enable_training=False): + mlp1_up = self.client_module.feed_forward.w3.weight + mlp1_gate = self.client_module.feed_forward.w1.weight + mlp2 = self.client_module.feed_forward.w2.weight + + mlp1 = Parameter(torch.cat((mlp1_up, mlp1_gate), dim=0), requires_grad=enable_training) + + return mlp1, None, mlp2, None + + def layernorm(self): + return self.client_module.ffn_norm.weight, \ + None, \ + self.client_module.attention_norm.weight, \ + None diff --git a/deepspeed/module_inject/containers/megatron_gpt.py b/deepspeed/module_inject/containers/megatron_gpt.py index 28b7df5b48d5..feaec866a3c5 100644 --- a/deepspeed/module_inject/containers/megatron_gpt.py +++ b/deepspeed/module_inject/containers/megatron_gpt.py @@ -51,16 +51,21 @@ def __init__(self, client_module, inference=True): try: from megatron.model.transformer import ParallelTransformerLayer MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer + MegatronLayerPolicy.version = 1 except ImportError: MegatronLayerPolicy._orig_layer_class = None def get_hidden_heads(self): - return self.client_module.attention.query_key_value.weight.shape[1], \ - self.client_module.attention.num_attention_heads, \ - self.client_module.input_layernorm.eps - - def get_q_k_v(self): - return None + if MegatronLayerPolicy.version == 0: + return self.client_module.attention.query_key_value.weight.shape[1], \ + self.client_module.attention.num_attention_heads, \ + self.client_module.input_layernorm.eps, \ + DEFAULT_INTERMEDIATE_SIZE + else: + return self.client_module.self_attention.query_key_value.weight.shape[1], \ + self.client_module.self_attention.num_attention_heads, \ + self.client_module.input_layernorm.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): if self.inference: @@ -68,13 +73,15 @@ def attention(self, enable_training=False): attention = self.client_module.attention else: attention = self.client_module.self_attention + else: + return None return attention.query_key_value.weight, \ attention.query_key_value.bias, \ attention.dense.weight, \ attention.dense.bias - def mlp(self, moe_type='standard'): + def mlp(self, moe_type='standard', enable_training=False): from deepspeed.moe.utils import has_moe_layers moe, _ = has_moe_layers(self.client_module) @@ -110,6 +117,3 @@ def layernorm(self): self.client_module.post_attention_layernorm.bias, \ self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias - - def get_lora_params(self): - return [] diff --git a/deepspeed/module_inject/containers/megatron_gpt_moe.py b/deepspeed/module_inject/containers/megatron_gpt_moe.py index 0d5248d8d4d4..c4063be05b6c 100644 --- a/deepspeed/module_inject/containers/megatron_gpt_moe.py +++ b/deepspeed/module_inject/containers/megatron_gpt_moe.py @@ -57,7 +57,7 @@ def __init__(self, client_module, inference=True): def get_num_experts(self): return self.num_experts - def mlp(self, moe_type='standard'): + def mlp(self, moe_type='standard', enable_training=False): # for now, all of this is tightly coupled to megatron-deepspeed moe implementation # todo: think and refactor this to be more general diff --git a/deepspeed/module_inject/containers/opt.py b/deepspeed/module_inject/containers/opt.py index 142b92744103..9ad78f1d73ec 100644 --- a/deepspeed/module_inject/containers/opt.py +++ b/deepspeed/module_inject/containers/opt.py @@ -4,7 +4,7 @@ # DeepSpeed Team from .base import * -from .features.meta_tensor import MetaTensorContainer +from .features import MetaTensorContainer, HybridSplitQKVContainer from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference import torch from torch.nn.parameter import Parameter @@ -16,7 +16,7 @@ from deepspeed.utils.types import ActivationFuncType -class DS_OPTContainer(MetaTensorContainer, BaseTransformerContainer): +class DS_OPTContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -29,6 +29,38 @@ def create_module(self, config=None): self.module.config.scale_attention = self.scale_attention return self.module + def set_lora_params(self): + """ + Necessary to implement for `HybridEngineContainer` + """ + self.lora_params = [ + maybe_get_lora(p) for p in [ + self.policy.client_module.fc1, + self.policy.client_module.fc2, + self.policy.client_module.self_attn.q_proj, + self.policy.client_module.self_attn.k_proj, + self.policy.client_module.self_attn.v_proj, + self.policy.client_module.self_attn.out_proj, + ] + ] + + def set_q_k_v(self): + """ + Necessary to implement for `HybridSplitQKVContainer` + """ + self.qw = self.policy.client_module.self_attn.q_proj.weight + self.qb = self.policy.client_module.self_attn.q_proj.bias + self.kw = self.policy.client_module.self_attn.k_proj.weight + self.kb = self.policy.client_module.self_attn.k_proj.bias + self.vw = self.policy.client_module.self_attn.v_proj.weight + self.vb = self.policy.client_module.self_attn.v_proj.bias + + def get_lora_matched_pair(self): + fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params() + ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw), + (k_lora, self.kw), (v_lora, self.vw)] + return ret + def load_params(self, module, sd, weight_quantizer, mp_replace, prefix): param_names = ( 'self_attn.q_proj.weight', \ @@ -72,30 +104,31 @@ class HFOPTLayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True, use_load_prefix=True): - super().__init__(inference, - linear_layer=True, - mlp_act_func_type=ActivationFuncType.ReLU, - pre_attn_norm=True, - use_load_prefix=use_load_prefix) + super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix) self.client_module = client_module try: import transformers HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer - except: + except Exception: HFOPTLayerPolicy._orig_layer_class = None + if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config, + "activation_function"): + if TransformerPolicy.hf_model_config.activation_function == "relu": + self.mlp_act_func_type = ActivationFuncType.ReLU + elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]: + self.mlp_act_func_type = ActivationFuncType.GELU + else: + raise ValueError("Unsupported activation function: {}".format( + TransformerPolicy.hf_model_config.activation_function)) + else: + self.mlp_act_func_type = ActivationFuncType.ReLU # default + def get_hidden_heads(self): return self.client_module.self_attn.embed_dim, \ self.client_module.self_attn.num_heads, \ - self.client_module.self_attn_layer_norm.eps - - def get_q_k_v(self): - return self.client_module.self_attn.q_proj.weight, \ - self.client_module.self_attn.q_proj.bias, \ - self.client_module.self_attn.k_proj.weight, \ - self.client_module.self_attn.k_proj.bias, \ - self.client_module.self_attn.v_proj.weight, \ - self.client_module.self_attn.v_proj.bias + self.client_module.self_attn_layer_norm.eps, \ + DEFAULT_INTERMEDIATE_SIZE def attention(self, enable_training=False): qw = self.client_module.self_attn.q_proj.weight @@ -114,7 +147,7 @@ def attention(self, enable_training=False): self.client_module.self_attn.out_proj.weight, \ self.client_module.self_attn.out_proj.bias - def mlp(self): + def mlp(self, enable_training=False): return self.client_module.fc1.weight, \ self.client_module.fc1.bias, \ self.client_module.fc2.weight, \ @@ -125,16 +158,3 @@ def layernorm(self): self.client_module.final_layer_norm.bias, \ self.client_module.self_attn_layer_norm.weight, \ self.client_module.self_attn_layer_norm.bias - - def get_lora_params(self): - all_lora_params = [] - for p in [ - self.client_module.fc1, \ - self.client_module.fc2, \ - self.client_module.self_attn.q_proj, \ - self.client_module.self_attn.k_proj, \ - self.client_module.self_attn.v_proj, \ - self.client_module.self_attn.out_proj, \ - ]: - all_lora_params.append(maybe_get_lora(p)) - return all_lora_params diff --git a/deepspeed/module_inject/containers/unet.py b/deepspeed/module_inject/containers/unet.py index 4e15699dc5a1..481792655531 100644 --- a/deepspeed/module_inject/containers/unet.py +++ b/deepspeed/module_inject/containers/unet.py @@ -17,6 +17,8 @@ def __init__(self): try: import diffusers self._orig_layer_class = diffusers.models.unet_2d_condition.UNet2DConditionModel + except AttributeError: + self._orig_layer_class = diffusers.models.unets.unet_2d_condition.UNet2DConditionModel except ImportError: self._orig_layer_class = None diff --git a/deepspeed/module_inject/containers/vae.py b/deepspeed/module_inject/containers/vae.py index 016e42c3dbb4..d26d0ef77ca9 100644 --- a/deepspeed/module_inject/containers/vae.py +++ b/deepspeed/module_inject/containers/vae.py @@ -13,10 +13,16 @@ def __init__(self): super().__init__() try: import diffusers - if hasattr(diffusers.models.vae, "AutoencoderKL"): + if hasattr(diffusers.models, "autoencoders"): + # Diffusers >= 0.25.0 + # Changes location to 'autoencoders' directory + self._orig_layer_class = diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL + elif hasattr(diffusers.models.vae, "AutoencoderKL"): + # Diffusers < 0.12.0 self._orig_layer_class = diffusers.models.vae.AutoencoderKL else: - # Diffusers >= 0.12.0 changes location of AutoencoderKL + # Diffusers >= 0.12.0 & < 0.25.0 + # Changes location of AutoencoderKL self._orig_layer_class = diffusers.models.autoencoder_kl.AutoencoderKL except ImportError: self._orig_layer_class = None @@ -32,5 +38,5 @@ def apply(self, module, enable_cuda_graph=True): return DSVAE(module, enable_cuda_graph=enable_cuda_graph) # NOTE (lekurile): Should we have a diffusers policy class? - def attention(self): + def attention(self, client_module): pass diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py new file mode 100644 index 000000000000..757dfc9abdaa --- /dev/null +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import torch +from deepspeed.utils.logging import warning_once +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads + + +def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0): + qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list] + tp_fusedqkv_list = [ + torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0])) + ] + return tp_fusedqkv_list + + +def require_tp_fused_qkvw(name, mp_size): + fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn'] + + if mp_size == 1: + return False + for fused_name in fused_qkvw_name_list: + if fused_name in name: + return True + return False + + +def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index): + + module_str = str(module).strip() + if src is None: + return + fused_type_dict = { + 'CodeGenBlock': 'codegentype', + 'BloomBlock': 'bloomtype', + 'GLMBlock': 'glmtype', + "MPTBlock": 'glmtype', + "MptBlock": 'glmtype', + "BaichuanLayer": 'glmtype', + "QWenBlock": 'qwentype', + "FalconDecoderLayer": 'bloomtype', + "GPTBigCodeBlock": 'bigcodetype', + "DecoderLayer": 'glmtype', + "Phi3DecoderLayer": "phi3type" + } + + def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): + # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py + assert get_num_kv_heads() % ( + mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0" + #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) + + shape = input.shape + dst_shape = get_shard_size(shape[0], mp_size) + num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1]) + + #num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :] + src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1)) + src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split] + + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1) + tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1) + + return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] + + def _glm_type_transpose(input, mp_size): + #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) + + # For chatglm2 & chatglm3(kv_heads=2), need to special handle. + if get_num_kv_heads() == 2: + shape = input.shape + hidden_dim = get_n_embd() + kv_dim = (shape[0] - hidden_dim) // get_num_kv_heads() + q = input[:hidden_dim] + k = input[hidden_dim:hidden_dim + kv_dim] + v = input[hidden_dim + kv_dim:] + q_split = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0) + k_split = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0) + v_split = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0) + return torch.cat((q_split[gpu_index], k_split[gpu_index], v_split[gpu_index]), dim=0) + else: + shape = input.shape + src_split = torch.split(input, shape[0] // 3, dim=0) + + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size)) + return split_fusedqkv[gpu_index] + + def _bloom_type_transpose(input, mp_size): + shape = input.shape + + split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0) + return split_fusedqkv[gpu_index] + + def _qwen_type_transpose(input, mp_size, module): + if not hasattr(module, "_ds_fusedqkv_entered"): + # Adjust splitting absolute value variables + setattr(module, "_ds_fusedqkv_entered", True) + module.attn.split_size = get_shard_size(module.attn.split_size, mp_size) + return _glm_type_transpose(input, mp_size) + + def _bigcode_type_transpose(input, mp_size): + n_embd = get_n_embd() + q = input[:n_embd] + kv = input[n_embd:] + shape = q.shape + split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0) + return torch.cat((split_q[gpu_index], kv), dim=0) + + def _phi3_type_transpose(input, mp_size): + num_kv_heads = get_num_kv_heads() + num_heads = get_num_attention_heads() + hidden_size = input.shape[1] + head_dim = hidden_size // num_heads + q_pos = input.shape[0] - 2 * num_kv_heads * head_dim + q = input[:q_pos] + k = input[q_pos:q_pos + num_kv_heads * head_dim] + v = input[q_pos + num_kv_heads * head_dim:] + split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0) + split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0) + split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0) + return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0) + + def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): + + # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following + # bloomtype: [q(1)_w,k(1)_w,v(1)_w,q(2)_w,k(2)_w,v(2)_w,...,q(n)_w,k(n)_w,v(n)_w] + # glmtype: [q(1)_w, q(2)_w,...,q(n)_w,k(1)_w,k(2)_w,...,k(n)_w,v(1)_w,v(2)_w,...,v(n)_w] + # codegentype: [q(1)_w,q(2)_w,...,q(n/t)_w,k(1)_w,k(2)_w,...,k(n/t)_w,v(1)_2,v(2)_w,...v(n/t)_w,q(n/t+1)_w,...], where t is a const defined in model file. + + if fused_qkv_type == 'bloomtype': + return _bloom_type_transpose(src, mp_size) + elif fused_qkv_type == 'codegentype': + return _codegen_type_transpose(src, mp_size) + elif fused_qkv_type == 'glmtype': + return _glm_type_transpose(src, mp_size) + elif fused_qkv_type == 'qwentype': + return _qwen_type_transpose(src, mp_size, module) + elif fused_qkv_type == 'bigcodetype': + return _bigcode_type_transpose(src, mp_size) + elif fused_qkv_type == 'phi3type': + return _phi3_type_transpose(src, mp_size) + + raise ValueError("unknown fused_qkv_type") + + module_name_matches = [k for k in fused_type_dict.keys() if k in module_str] + if module_name_matches: + # There can be overlap with matches (e.g., "DecoderLayer" and "FalconDecoderLayer"). + # We take the longest matching module_name + module_name = max(module_name_matches, key=len) + fused_type = fused_type_dict[module_name] + return _transpose_fused_qkvw(src, mp_size, fused_type, module) + warning_once("Unrecognized fusedkqv weight type, default to using bloom type," + "please check in prepare_tp_fused_qkvw() to avoid potential calculation errors") + return _bloom_type_transpose(src, mp_size) + + +# For share qk type: +# q = [q1,...,q_{n/4}, q_{n/2+1},...,q_{3n/4}, k1,...,k_{n/4}, k_{n/2+1},...,k_{3n/4}] +# k = [q_{n/4+1},...,q_{n/2}, q_{3n/4+1},...,qn, k_{n/4+1},...,k_{n/2}, k{3n/4+1},...,kn] +# Avoid modifying the modeling code. We adjust the value and oproj weight to fit this qk type. +def shard_value_with_share_qk( + weight, + bias, + rank, + world_size, + shard_value=True # True -> shard_value; False -> shard_oproj +): + if shard_value: + total_size = weight.shape[0] + weight_cat_dim = 0 + else: + total_size = weight.shape[1] + weight_cat_dim = 1 + num_heads = get_num_kv_heads() + head_dim = total_size // num_heads + assert (num_heads % world_size == 0) + if world_size > num_heads // 2: + RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}") + head_per_rank = num_heads // world_size + q_head_start = rank * head_per_rank + # mapping q_head to v_head + v_head_ids = [] + i = 0 + # mapping neighbor q_head to v_head + while i < head_per_rank: + v_head_ids.append(q_head_start // 2) + q_head_start += 2 + i = i + 2 + + # mapping neighbor k_head to v_head + v_head_ids.extend([i + num_heads // 2 for i in v_head_ids]) + sharded_weight = [] + sharded_bias = [] + for head_id in v_head_ids: + if shard_value: + sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim]) + if bias is not None: + sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim]) + else: + sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim]) + sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim) + if bias is not None: + if shard_value: + sharded_bias = torch.cat(sharded_bias, dim=0) + else: + bias = bias / float(world_size) + return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(sharded_bias) + else: + return torch.nn.Parameter(sharded_weight), None + + +# For phi3 with chunk mlp, adjust the weight order. +def shard_chunk_mlp( + weight, + bias, + rank, + world_size, +): + weight_gate, weight_states = weight.chunk(2, dim=0) + total_size = weight_gate.shape[0] + split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0) + if bias is not None: + bias_gate, bias_states = bias.chunk(2, dim=0) + split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0) + + return shard_weight, None diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 70dd1a3af0e1..0a8cda1d1daa 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -4,55 +4,977 @@ # DeepSpeed Team import torch +import re from deepspeed import comm as dist from torch import nn from torch.nn import functional as F - from torch.nn.parameter import Parameter from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list +from deepspeed.runtime.zero.utils import is_zero_param +from abc import ABC, abstractmethod +from typing import Iterable, Any, Optional, List, Tuple, Dict +from .fusedqkv_utils import shard_value_with_share_qk, shard_chunk_mlp, prepare_tp_fused_qkvw +from deepspeed.runtime.tensor_parallel import AUTOTP_MODE +from deepspeed.checkpoint.constants import DS_AUTOTP_UC_META +from copy import deepcopy +from typing import Union + +__all__ = [ + "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce", + "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer", + "SubParamLinearLayer", "SubParamLinearAllreduce" +] + +DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE +DS_IS_REPLACED_MODULE = 'ds_is_replaced_module' +DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel' + + +def _normalize_uc_shape(value): + return tuple(value) if value is not None else None + + +def _build_param_uc_conversion_meta(*, + partition_type, + partition_dim=None, + sub_param_shape=None, + original_shape=None, + is_bias=False, + replicated=False): + """Build the conversion-facing subset of parameter UC metadata. + + This is the only schema that should flow into model-level + `UNIVERSAL_CHECKPOINT_INFO` via `collect_autotp_universal_checkpoint_info()`. + """ + return { + 'partition_type': partition_type, + 'partition_dim': partition_dim, + 'sub_param_shape': _normalize_uc_shape(sub_param_shape), + 'original_shape': _normalize_uc_shape(original_shape), + 'is_bias': is_bias, + 'replicated': replicated, + } + + +def _build_param_uc_restore_meta(*, + partition_type, + partition_dim=None, + logical_shape=None, + output_shape=None, + sub_param_shape=None, + sub_param_sizes=None, + target_partition_shape=None, + original_shape=None, + is_bias=False, + replicated=False): + """Build the restore-facing parameter UC metadata. + + Restore metadata stays on the parameter object and may include details that + are intentionally omitted from model-level conversion schema. + """ + return { + 'partition_type': + partition_type, + 'partition_dim': + partition_dim, + 'logical_shape': + _normalize_uc_shape(logical_shape), + 'output_shape': + _normalize_uc_shape(output_shape), + 'sub_param_shape': + _normalize_uc_shape(sub_param_shape), + 'sub_param_sizes': + _normalize_uc_shape(sub_param_sizes), + 'target_partition_shape': + _normalize_uc_shape(target_partition_shape), + 'original_shape': + _normalize_uc_shape(original_shape), + 'is_bias': + is_bias, + 'replicated': + replicated, + 'conversion': + _build_param_uc_conversion_meta(partition_type=partition_type, + partition_dim=partition_dim, + sub_param_shape=sub_param_shape, + original_shape=original_shape, + is_bias=is_bias, + replicated=replicated), + } + + +def get_auto_tp_mode(): + global DEEPSPEED_AUTOTP_MODE + return DEEPSPEED_AUTOTP_MODE + + +def is_autotp_training_mode(): + global DEEPSPEED_AUTOTP_MODE + return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING + + +def set_autotp_mode(training=False): + """ + Set the DEEPSPEED_AUTOTP_MODE based on the training flag + """ + global DEEPSPEED_AUTOTP_MODE + if training: + DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.TRAINING + else: + DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE + + +def add_bias(input, bias): + if bias is None: + return input + if is_autotp_training_mode(): + # Training mode - avoid inplace to ensure correct autograd + input = input + bias + return input + else: + input += bias + return input + + +class RowParallel(torch.autograd.Function): + """ + A custom autograd function for performing row-wise parallelism. + """ + + @staticmethod + def symbolic(graph, input): + """Symbolic function for tracing.""" + return input + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, is_inference_mode: bool) -> torch.Tensor: + """ + Forward pass. + """ + ctx.group = group + if group == None: + return input + if is_inference_mode: + dist.inference_all_reduce(input, group=group) + else: + dist.all_reduce(input.contiguous(), group=group) + return input + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None]: + """ + Backward pass. + """ + return None, grad_output, None + + +class AsyncColumnParallel(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor: + """ + Forward pass. + """ + ctx.use_bias = bias is not None + ctx.group = group + output = torch.matmul(input, weight.transpose(-1, -2)) + if bias is not None: + output = add_bias(output, bias) + + ctx.save_for_backward(input, weight) + + return output + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]: + + input, weight = ctx.saved_tensors + grad_input = grad_output.matmul(weight) + handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True) + grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1])) + grad_bias = grad_output.sum(0) if ctx.use_bias else None + handle.wait() + return None, grad_input, grad_weight, grad_bias + + +class ColumnParallel(torch.autograd.Function): + """ + Custom autograd function for column-wise parallelism. + """ + + @staticmethod + def symbolic(graph, input): + """Symbolic function for tracing.""" + return dist.all_reduce(input.contiguous(), dist.get_tensor_model_parallel_group()) + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + """ + ctx.group = group + return input + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]: + """ + Backward pass. + """ + if ctx.group == None: + return None, grad_output -class LinearAllreduce(nn.Module): + dist.all_reduce(grad_output.contiguous(), group=ctx.group) + return None, grad_output - def __init__(self, weight, bias=None, mp_group=None): - super(LinearAllreduce, self).__init__() - self.weight = weight - self.bias = bias + +class TensorParallel_Layer(nn.Module, ABC): + """ + A base class for model layers with tensor parallelism support. + This class is designed to be extended by specific layers that require distributed + operations and parameter gather/partitioning during inference or training. + + Attributes: + mode (str): The mode of operation[INFERENCE or TRAINING], default is "INFERENCE". + mp_group (Optional[dist.ProcessGroup]): The process group used for model parallelism. + tp_world_size (int): The world size of tensor parallelism, i.e., the number of parallel workers. + tp_index (int): The rank (ID) of the current worker in tensor parallelism. + support_training (bool): Flag indicating whether the layer supports training (default: False). + name (Optional[str]): The name of the layer, if provided. + """ + ##### Initialize Parameter List ##### + + # keep_module_on_host determines whether to keep the module on the host. + # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory), + # so an additional copy is unnecessary. + keep_module_on_host: bool = False + + ##### Runtime Parameter List ##### + tp_overlap_comm: bool = False + """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """ + + def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any): + """ + Initializes the TensorParallel_Layer with optional model parallelism group and layer name. + + Args: + mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism. + If None, no model parallelism is set. + """ + super().__init__() + self.support_training: bool = False self.mp_group = mp_group + if mp_group is not None: + self.tp_world_size: int = dist.get_world_size(self.mp_group) + self.tp_index: int = dist.get_rank(self.mp_group) + else: + self.tp_world_size: int = 1 + self.tp_index: int = 0 + + # backward compatibility + self.world_size = self.tp_world_size + self.rank = self.tp_index + + self.name = getattr(self, 'name', None) + if kwargs.get('name') is not None: + self.name = kwargs.get('name') # Set the layer name if provided. + + @classmethod + def set_keep_module_on_host(cls, value: bool): + """ + Set the static variable keep_module_on_host. + + Args: + value (bool): The new value for keep_module_on_host. + """ + cls.keep_module_on_host = value + + @abstractmethod + def forward(self, input): + """ + Forward pass method. Must be implemented by subclasses to define layer-specific operations. + """ + pass + + @abstractmethod + def gather_params(self, params_list): + """ + Gathers parameters across devices for distributed training. Must be implemented by subclasses in "TRAINING" mode. + """ + pass + + @abstractmethod + def _tp_partition(self, params_list: List[torch.Tensor]): + """ + Partitions the parameters for tensor parallelism. + It is necessary to ensure that this function only involves the logic of params partitioning. + """ + pass + + def config_requires_grad(self, weight): + if weight is not None: + if self.is_training_mode(): + if weight.requires_grad is None: + weight.requires_grad = True + else: + weight.requires_grad = False + + def config_tp_params(self, weight): + """ + Configures the weight tensor for training with tensor parallelism. This includes enabling gradients + and associating necessary methods for parameter gathering and partitioning. + + Args: + weight (Optional[torch.Tensor]): The weight tensor to configure for tensor parallelism. + If None, no action is taken. + """ + # # The RNG states have already been synchronized in init_inference. + if self.is_training_mode(): + assert self.support_training, "No implementation of backward." + if weight is not None: + self.config_requires_grad(weight) + weight.gather_params = self.gather_params + weight._tp_partition = self._tp_partition + setattr(weight, DS_TENSOR_MODEL_PARALLEL, True) + setattr(weight, DS_IS_REPLACED_MODULE, True) + + def _set_param_uc_meta(self, + param, + *, + partition_type, + partition_dim=None, + logical_shape=None, + output_shape=None, + sub_param_shape=None, + sub_param_sizes=None, + target_partition_shape=None, + original_shape=None, + is_bias=False, + replicated=False): + if param is None: + return + setattr( + param, DS_AUTOTP_UC_META, + _build_param_uc_restore_meta(partition_type=partition_type, + partition_dim=partition_dim, + logical_shape=logical_shape, + output_shape=output_shape, + sub_param_shape=sub_param_shape, + sub_param_sizes=sub_param_sizes, + target_partition_shape=target_partition_shape, + original_shape=original_shape, + is_bias=is_bias, + replicated=replicated)) + + def _mark_uc_metadata(self): + return + + def _should_materialize_tp_partition(self): + # AutoTP partitioning should only materialize parameters when an actual + # TP process group is present. Metadata-only construction with + # mp_group=None should not touch device placement. + return self.mp_group is not None + + def is_training_mode(self): + global DEEPSPEED_AUTOTP_MODE + return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING + + def __deepcopy__(self, memo): + # This function is designed for + # 'mp_group' (a 'ProcessGroup') cannot be pickled during deepcopy in some usage. + cls = self.__class__ + new_obj = cls.__new__(cls) + + for key, value in vars(self).items(): + if key == 'mp_group': + new_obj.mp_group = self.mp_group + else: + setattr(new_obj, key, deepcopy(value, memo)) + + memo[id(self)] = new_obj + return new_obj + + def extra_repr(self): + out_features, in_features = None, None + if self.weight is not None: + out_features, in_features = self.weight.ds_shape[-2:] if is_zero_param( + self.weight) else self.weight.shape[-2:] + dtype = self.weight.dtype if self.weight is not None else None + return "in_features={}, out_features={}, bias={}, dtype={}".format(in_features, out_features, self.bias + is not None, dtype) + + def move(self, tensor): + # TODO: consider the timing of deletion + # to save host resources when DP > 1。 + + # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some + # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy. + if tensor.is_meta: + # Keep tensor in meta device if tensor is meta. + return tensor + else: + device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name() + return_new_copy = not self.__class__.keep_module_on_host + + # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). + # Using copy=True instead of clone() will help in case of cpu --> cpu. + # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. + cloned_tensor = tensor.to(device, copy=return_new_copy) + + if return_new_copy: + # free the memory of the original tensor to reduce memory peak + # Equivalent to directly deleting the tensor reference outside the function. + # see https://github.com/microsoft/DeepSpeed/pull/4353 + tensor.data = torch.empty(0, device=tensor.device) + return cloned_tensor + + +def configure_tensor_parallel_runtime(config): + runtime_keys = ['tp_overlap_comm'] + for key in runtime_keys: + if hasattr(config, key): + setattr(TensorParallel_Layer, key, getattr(config, key)) + + +def _get_param_uc_conversion_meta(param: torch.Tensor) -> Optional[Dict[str, Any]]: + """Return the conversion-facing view of AutoTP UC metadata for a parameter. + + AutoTP keeps a single parameter-level metadata object with two roles: + - top-level fields: restore-time details consumed by `universal_checkpoint.py` + - `conversion`: conversion-time details consumed by + `collect_autotp_universal_checkpoint_info()` and then aggregated into + model-level `UNIVERSAL_CHECKPOINT_INFO` for `ds_to_universal.py` + """ + meta = getattr(param, DS_AUTOTP_UC_META, None) + if not meta: + return None + return meta.get('conversion', None) + + +def collect_autotp_universal_checkpoint_info(model: nn.Module) -> Dict[str, Any]: + """Collect the model-level conversion schema for AutoTP universal checkpoints. + + The returned `UNIVERSAL_CHECKPOINT_INFO` is intentionally limited to the + pattern/schema data needed during checkpoint conversion. It does not include + restore-time per-parameter details such as `sub_param_sizes` or + `target_partition_shape`, which stay on the parameter metadata object. + """ + from deepspeed.checkpoint.constants import (ORIGINAL_VOCAB_SIZE, PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, + PARAMETER_WITH_SUB_PARAMS, TP_REPLICATED_PARAMETER_PATTERNS, + UNIVERSAL_CHECKPOINT_VERSION_KEY, UNIVERSAL_CHECKPOINT_VERSION_VALUE, + VOCABULARY_PARAMETER_PATTERNS) + + row_parallel_patterns = [] + replicated_patterns = [] + vocabulary_patterns = [] + parameter_with_sub_params = [] + original_vocab_size = None + + for module_name, module in model.named_modules(): + marker = getattr(module, "_mark_uc_metadata", None) + if marker is not None: + marker() + + for param_name, param in module.named_parameters(recurse=False): + conversion_meta = _get_param_uc_conversion_meta(param) + if not conversion_meta: + continue + + full_name = f"{module_name}.{param_name}" if module_name else param_name + pattern = rf"^{re.escape(full_name)}$" + + if conversion_meta.get('replicated'): + replicated_patterns.append(pattern) + + if conversion_meta.get('partition_type') == 'row' and not conversion_meta.get('is_bias', False): + row_parallel_patterns.append(pattern) + + original_shape = conversion_meta.get('original_shape') + if original_shape and len(original_shape) == 2 and ('embed' in full_name or 'lm_head' in full_name): + vocabulary_patterns.append(pattern) + if original_vocab_size is None: + original_vocab_size = original_shape[0] + + sub_param_shape = conversion_meta.get('sub_param_shape') + partition_dim = conversion_meta.get('partition_dim') + if sub_param_shape is not None and partition_dim is not None and not conversion_meta.get('is_bias', False): + parameter_with_sub_params.append({ + 'patterns': [pattern], + 'shape': list(sub_param_shape), + 'partition_dim': partition_dim, + }) + + uc_info = { + UNIVERSAL_CHECKPOINT_VERSION_KEY: UNIVERSAL_CHECKPOINT_VERSION_VALUE, + PARAMETER_WITH_ROW_PARALLELISM_PATTERNS: sorted(set(row_parallel_patterns)), + TP_REPLICATED_PARAMETER_PATTERNS: sorted(set(replicated_patterns)), + VOCABULARY_PARAMETER_PATTERNS: sorted(set(vocabulary_patterns)), + PARAMETER_WITH_SUB_PARAMS: parameter_with_sub_params, + } + if original_vocab_size is not None: + uc_info[ORIGINAL_VOCAB_SIZE] = original_vocab_size + return uc_info + + +class GatherReplacedLayerParams: + """ + A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality + based on the configuration of the model. + """ + + def __init__(self, + params: Union[Iterable[torch.Tensor], torch.Tensor], + module: torch.nn.Module, + enabled: bool = True): + """ + Initialize the context manager to handle parameter gathering and partitioning for a replaced layer. + + Args: + params (Iterable or torch.Tensor): A collection or single parameter to manage. + module (torch.nn.Module): The module that these parameters belong to. + enabled (bool): Flag indicating whether the parameter management is enabled (default: True). + """ + self.enabled = enabled + self.module = module + if not enabled: + return + + # Ensure params is a list, whether it's a single param or iterable (e.g., model.parameters()) + if isinstance(params, Iterable) and not isinstance(params, torch.Tensor): + self.params: List[torch.Tensor] = list(params) # Convert generators to a list for multiple iterations + else: + self.params: List[torch.Tensor] = [params] # Wrap single parameter in a list for uniform processing + + # Check if the parameters belong to a replaced layer (indicated by a specific attribute) + if not any(self._is_replaced_module_weight(p) for p in params): + self.enabled = False + return + + def _is_replaced_module_weight(self, param: torch.Tensor) -> bool: + """ + Helper function to determine if a parameter belongs to a replaced module. + + Args: + param (torch.Tensor): The parameter to check. + + Returns: + bool: True if the parameter belongs to a replaced module, False otherwise. + """ + return getattr(param, DS_IS_REPLACED_MODULE, False) + + def __enter__(self) -> None: + """ + Enter the context manager. If enabled, gather parameters for the replaced module. + """ + if self.enabled: + self.params[0].gather_params(self.params) + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Exit the context manager. If enabled, partition the parameters for the replaced module. + """ + #TODO : Check whether there are any missing attributes. + if self.enabled: + self.params[0]._tp_partition(self.params) + + +class LinearAllreduce(TensorParallel_Layer): + + def __init__(self, module, mp_group, **kwargs): + super(LinearAllreduce, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + + if self._should_materialize_tp_partition(): + self._tp_partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + # bias here is not tp params + self.config_requires_grad(self.bias) + self._mark_uc_metadata() def forward(self, input): output = torch.matmul(input, self.weight.transpose(-1, -2)) - if self.mp_group is not None: - dist.all_reduce(output, group=self.mp_group) + output = RowParallel.apply(self.mp_group, output, not self.is_training_mode()) if self.bias is not None: - output += self.bias + output = add_bias(output, self.bias) return output + @torch.no_grad() + def gather_params(self, params_list): -class LinearLayer(nn.Module): + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't gather bias + return + params_list[idx].data_partition = param.data + param = param.transpose(0, 1).contiguous() - def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None): - super(LinearLayer, self).__init__() + output_param = torch.empty(self.tp_world_size * param.shape[0], + param.shape[1], + dtype=param.dtype, + device=param.device) + dist.all_gather_into_tensor(output_param, param, group=self.mp_group) + params_list[idx].data = output_param.transpose(0, 1).contiguous() + return + + @torch.no_grad() + def _tp_partition(self, params_list): + + if not self.is_training_mode(): + self.uneven_partition(params_list) + return + + else: + for idx, param in enumerate(params_list): + if param is None: + # don't slipt bias + return + if idx > 0: # move bias to device at initialization + _partition = self.move(param).detach() + params_list[idx].data = _partition + return + + _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index] + + _partition = self.move(_partition).detach() + + params_list[idx].data = _partition + + def uneven_partition(self, params_list): + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't slipt bias + return + assert self.name is not None, "The module name must be provided in the initialization." + _partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[1], self.tp_world_size, + self.name), + dim=1)[self.tp_index] + + _partition = self.move(_partition).detach() + params_list[idx].data = _partition + + def _mark_uc_metadata(self): + original_weight_shape = (self.weight.shape[0], self.weight.shape[1] * self.tp_world_size) + self._set_param_uc_meta(self.weight, + partition_type='row', + partition_dim=1, + logical_shape=original_weight_shape, + output_shape=(original_weight_shape[0], ), + original_shape=original_weight_shape) + if self.bias is not None: + self._set_param_uc_meta(self.bias, + partition_type='row', + partition_dim=None, + logical_shape=tuple(self.bias.shape), + output_shape=tuple(self.bias.shape), + original_shape=tuple(self.bias.shape), + is_bias=True, + replicated=True) + + +#remove kwargs from partition. +class LinearLayer(TensorParallel_Layer): + + def __init__(self, module, mp_group=None, skip_partition=False, **kwargs): + super(LinearLayer, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + if not skip_partition and self._should_materialize_tp_partition(): + self._tp_partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_tp_params(self.bias) + self._mark_uc_metadata() + + def forward(self, input): + if not self.__class__.tp_overlap_comm: + if getattr(self, 'mp_group', None) is not None: + input = ColumnParallel.apply(self.mp_group, input) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.bias is not None: + output = add_bias(output, self.bias) + else: + output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias) + + return output + + @torch.no_grad() + def gather_params(self, params_list): + # Does not support uneven shard. + for idx, param in enumerate(params_list): + + params_list[idx].data_partition = param.data + output_param = torch.empty((self.tp_world_size * param.shape[0], *param.shape[1:]), + dtype=param.dtype, + device=param.device) + dist.all_gather_into_tensor(output_param, param, group=self.mp_group) + params_list[idx].data = output_param.contiguous() + + @torch.no_grad() + def _tp_partition(self, params_list): + + if not self.is_training_mode(): + self.uneven_partition(params_list) + return + for idx, param in enumerate(params_list): + if param is None: + return + #split bias if provide + _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index] + + _partition = self.move(_partition).detach() + + params_list[idx].data = _partition + + def uneven_partition(self, params_list): + + for idx, param in enumerate(params_list): + if param is None: + #split bias if provide + return + assert self.name is not None, "The module name must be provided in the initialization." + _partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[0], self.tp_world_size, + self.name), + dim=0)[self.tp_index] + + _partition = self.move(_partition).detach() + + params_list[idx].data = _partition + + def _mark_uc_metadata(self): + original_out_dim = self.weight.shape[0] * self.tp_world_size + original_weight_shape = (original_out_dim, self.weight.shape[1]) + self._set_param_uc_meta(self.weight, + partition_type='column', + partition_dim=0, + logical_shape=original_weight_shape, + output_shape=(original_out_dim, ), + original_shape=original_weight_shape) + if self.bias is not None: + original_bias_shape = (self.bias.shape[0] * self.tp_world_size, ) + self._set_param_uc_meta(self.bias, + partition_type='column', + partition_dim=0, + logical_shape=original_bias_shape, + output_shape=original_bias_shape, + original_shape=original_bias_shape, + is_bias=True) + + # for bwc + @classmethod + def from_weights(cls, weight_shape=None, dtype=torch.half, weight=None, bias=None): if weight is not None: - self.weight = weight - self.bias = bias + in_features = weight.shape[1] + out_features = weight.shape[0] + linear = nn.Linear(in_features, out_features, bias=(bias is not None)) + linear.weight.data = weight + if bias is not None: + linear.bias.data = bias else: - self.weight = Parameter( - torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name())) + in_features = weight_shape[1] + out_features = weight_shape[0] + linear = nn.Linear(in_features, out_features, bias=(bias is not None)) + return cls(linear, skip_partition=True) - self.bias = Parameter( - torch.empty(weight_shape[0], - dtype=dtype, - device=get_accelerator().current_device_name())) \ - if bias is not None else None + +class FusedModuleWrapper: + + def __init__(self, fused_module: nn.Module): + self.fused_module = fused_module + + def __getattr__(self, module): + return self.fused_module + + +class fused_LinearLayer(LinearLayer): + + def __init__(self, module, mp_group, skip_partition=False, **kwargs): + assert kwargs.get('fused_module') is not None, "'fused_module' is required but not provided" + # Use the warp class to avoid module circular references. + self.fused_module = FusedModuleWrapper(kwargs.get('fused_module')) + super().__init__(module, mp_group, skip_partition, **kwargs) + + @torch.no_grad() + def _tp_partition(self, params_list): + for idx, param in enumerate(params_list): + if param is None: + return + + _partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index) + + _partition = self.move(_partition).detach() + + params_list[idx].data = _partition + + +class conv_LinearLayer(LinearLayer): + + @torch.no_grad() + def _tp_partition(self, params_list): + weight = None + bias = None + if len(params_list) == 1: + weight = params_list[0] + elif len(params_list) == 2: + weight, bias = params_list[0], params_list[1] + _partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name), + dim=1)[self.tp_index] + _partition = self.move(_partition).detach() + weight.data = _partition + + if bias is not None: + _partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name), + dim=0)[self.tp_index] + _partition = self.move(_partition).detach() + + bias.data = _partition + + +#override the subclasses related to weight splitting. +class Yuan_LinearAllreduce(LinearAllreduce): + + #Yuan2 + @torch.no_grad() + def _tp_partition(self, params_list): + weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, + self.tp_world_size, False) + params_list[0].data = weight + if bias is not None: + params_list[1].data = bias + + +class Yuan_LinearLayer(LinearLayer): + #Yuan2 + @torch.no_grad() + def _tp_partition(self, params_list): + weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, + self.tp_world_size, True) + params_list[0].data = self.move(weight).detach() + if bias is not None: + params_list[1].data = self.move(bias).detach() + + +class GateUpPack_LinearLayer(LinearLayer): + # chatGLM2, chatGLM2 + @torch.no_grad() + def _tp_partition(self, params_list): + weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size) + params_list[0].data = self.move(weight).detach() + if bias is not None: + params_list[1].data = self.move(bias).detach() + + +class Conv_LinearALlreduce(LinearAllreduce): + + @torch.no_grad() + def _tp_partition(self, params_list): + for idx, param in enumerate(params_list): + if param is None: + return + param.data = param.data.transpose(-1, -2).contiguous() + + _partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name), + dim=1)[self.tp_index] + + _partition = self.move(_partition).detach() + + params_list[idx].data = _partition + + +#override the subclasses related to fwd/bwd. +class LmHeadLinearAllreduce(LinearAllreduce): + + def __init__(self, module, mp_group, **kwargs): + # set the fixed name before partition + self.name = "lm_head" + + # In some tied_embedding cases, only the lm head is sharded, while the word embedding is not. + # Reinitialization is used to decouple them and prevent the word embedding from being sharded. + # This should also be effective for cases where both are sharded in tied_embedding scenarios. + + # TODO: Training scenario-related tests, is it necessary to re-implement the vocab parallel module? + module.weight = nn.Parameter(module.weight.clone().detach()) + if hasattr(module, 'bias') and module.bias is not None: + module.bias = nn.Parameter(module.bias.clone().detach()) + super().__init__(module, mp_group, **kwargs) def forward(self, input): - output = torch.matmul(input, self.weight.transpose(-1, -2)) + input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head") + input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index]) + output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size], + self.weight.transpose(-1, -2)) + if self.mp_group is not None: + dist.inference_all_reduce(output, group=self.mp_group) if self.bias is not None: - output += self.bias + output = add_bias(output, self.bias) return output +class TensorParallelConv2d(nn.Module): + + def __init__(self, conv, rank, world_size, shard_by_oc): + super().__init__() + self.rank = rank + self.world_size = world_size + self.shard_by_oc = shard_by_oc + self.shard_weights(conv) + + # Split along the input/output channel depending on whether it is the last conv layer. + def shard_weights(self, conv): + if self.shard_by_oc: + total_size = conv.weight.shape[0] + else: + total_size = conv.weight.shape[1] + bias_data = None + cols_per_rank = [0] + for i in range(self.world_size - 1, -1, -1): + cols = total_size // self.world_size + if i < total_size % self.world_size: + cols += 1 + cols_per_rank.append(cols_per_rank[-1] + cols) + weight_data = conv.weight.data + if self.shard_by_oc: + # not last conv layer, split output channel + weight_data = weight_data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]] + if conv.bias is not None: + bias_data = conv.bias.data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]] + else: + # last conv layer, split input channel + weight_data = weight_data[:, cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]] + if conv.bias is not None: + bias_data = conv.bias.data / float(self.world_size) + self.conv = nn.Conv2d(weight_data.shape[1], weight_data.shape[0], conv.kernel_size, conv.stride, conv.padding, + conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode) + self.conv.weight = torch.nn.Parameter(weight_data) + if conv.bias is not None: + self.conv.bias = torch.nn.Parameter(bias_data) + del conv + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.conv(input) + + +class TensorParallelOcShardConv2d(TensorParallelConv2d): + + def __init__(self, conv, rank, world_size): + super().__init__(conv, rank, world_size, True) + + +class TensorParallelIcShardConv2d(TensorParallelConv2d): + + def __init__(self, conv, rank, world_size): + super().__init__(conv, rank, world_size, False) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = self.conv(input) + if self.world_size > 1: + dist.inference_all_reduce(out) + return out + + class Normalize(nn.Module): def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None): @@ -99,7 +1021,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None): self.offset = 2 super().__init__(weight_shape, weight=weight) - def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0): """`input_ids_shape` is expected to be [bsz x seqlen].""" attention_mask = attention_mask.long() @@ -110,3 +1032,422 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int positions = positions[:, past_key_values_length:] return super().forward(positions + self.offset) + + +def _shape_prod(values): + result = 1 + for val in values: + result *= val + return result + + +def _normalize_shape_spec(shape): + if isinstance(shape, list): + return tuple(_normalize_shape_spec(item) for item in shape) + if isinstance(shape, tuple): + return tuple(_normalize_shape_spec(item) if isinstance(item, list) else item for item in shape) + return shape + + +def _infer_subparam_logical_shapes(weight_shape, shape, partition_dim, name=None): + shape = _normalize_shape_spec(shape) + if not isinstance(shape, tuple): + raise ValueError("AutoTP shape must be a tuple for sub-parameter partitioning.") + if partition_dim < 0 or partition_dim >= len(shape): + raise ValueError(f"AutoTP partition_dim {partition_dim} is out of range for shape length {len(shape)}.") + + layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer" + partition_elem = shape[partition_dim] + subparam_sizes = None + num_subparams = None + + if isinstance(partition_elem, tuple): + if len(partition_elem) == 0: + raise ValueError(f"{layer_label} sub-parameter size tuple cannot be empty.") + if any(isinstance(val, tuple) for val in partition_elem): + raise ValueError(f"{layer_label} supports only 1-level nesting at partition_dim.") + if any((not isinstance(val, int)) or val <= 0 for val in partition_elem): + raise ValueError(f"{layer_label} sub-parameter sizes must be positive integers.") + subparam_sizes = tuple(int(val) for val in partition_elem) + partition_dim_size = sum(subparam_sizes) + elif isinstance(partition_elem, int): + if partition_elem == -1: + partition_dim_size = None + elif partition_elem > 0: + num_subparams = partition_elem + partition_dim_size = None + else: + raise ValueError(f"{layer_label} partition_dim spec must be positive integer or -1.") + else: + raise ValueError(f"{layer_label} partition_dim spec must be int or tuple.") + + logical_dims = [] + for idx, dim in enumerate(shape): + if idx == partition_dim: + logical_dims.append(partition_dim_size) + continue + if isinstance(dim, tuple): + raise ValueError(f"{layer_label} nested tuple only allowed at partition_dim={partition_dim}.") + if isinstance(dim, int): + if dim == -1: + logical_dims.append(None) + elif dim > 0: + logical_dims.append(dim) + else: + raise ValueError(f"{layer_label} shape dimensions must be positive integers or -1.") + else: + raise ValueError(f"{layer_label} shape dimensions must be integers.") + + total_numel = _shape_prod(weight_shape) + known_product = _shape_prod([dim for dim in logical_dims if dim is not None]) + unknown_indices = [idx for idx, dim in enumerate(logical_dims) if dim is None] + + if len(unknown_indices) == 0: + if known_product != total_numel: + raise ValueError(f"{layer_label} shape product {known_product} != weight numel {total_numel}.") + elif len(unknown_indices) == 1: + inferred = total_numel // known_product + if inferred * known_product != total_numel: + raise ValueError(f"{layer_label} cannot infer shape for weight with numel {total_numel}.") + logical_dims[unknown_indices[0]] = inferred + else: + if len(shape) == len(weight_shape): + for idx in unknown_indices: + logical_dims[idx] = weight_shape[idx] + if _shape_prod(logical_dims) != total_numel: + raise ValueError( + f"{layer_label} shape product {_shape_prod(logical_dims)} != weight numel {total_numel}.") + else: + raise ValueError(f"{layer_label} shape has multiple inferred dims and is ambiguous for weight.") + + logical_shape = tuple(logical_dims) + if logical_shape[-1] != weight_shape[-1]: + raise ValueError( + f"{layer_label} shape last dim {logical_shape[-1]} must match weight input dim {weight_shape[-1]}.") + + output_shape = logical_shape[:-1] + if len(output_shape) == 0: + raise ValueError(f"{layer_label} shape must include at least one output dimension.") + if _shape_prod(output_shape) != weight_shape[0]: + raise ValueError( + f"{layer_label} output shape product {_shape_prod(output_shape)} != weight output dim {weight_shape[0]}.") + + partition_dim_size = logical_shape[partition_dim] + if partition_dim_size is None or partition_dim_size <= 0: + raise ValueError(f"{layer_label} partition_dim size must be a positive integer.") + + if num_subparams is not None: + if partition_dim_size % num_subparams != 0: + raise ValueError( + f"{layer_label} partition_dim size {partition_dim_size} not divisible by sub-param count {num_subparams}." + ) + subparam_sizes = tuple([partition_dim_size // num_subparams] * num_subparams) + + if subparam_sizes is not None and sum(subparam_sizes) != partition_dim_size: + raise ValueError( + f"{layer_label} sub-parameter sizes sum {sum(subparam_sizes)} != partition_dim size {partition_dim_size}.") + + bias_partition_dim = partition_dim if partition_dim < len(output_shape) else None + return logical_shape, output_shape, subparam_sizes, bias_partition_dim + + +def _partition_logical_tensor(tensor, partition_dim, tp_world_size, tp_index, name=None, subparam_sizes=None): + if tp_world_size == 1: + return tensor + layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer" + if subparam_sizes: + for size in subparam_sizes: + if size % tp_world_size != 0: + raise ValueError(f"{layer_label} sub-parameter size {size} not divisible by tp_size {tp_world_size}.") + sub_params = torch.split(tensor, subparam_sizes, dim=partition_dim) + partitioned_sub_params = [torch.chunk(sp, tp_world_size, dim=partition_dim)[tp_index] for sp in sub_params] + return torch.cat(partitioned_sub_params, dim=partition_dim) + if tensor.shape[partition_dim] % tp_world_size != 0: + raise ValueError( + f"{layer_label} partition_dim size {tensor.shape[partition_dim]} not divisible by tp_size {tp_world_size}." + ) + return torch.chunk(tensor, tp_world_size, dim=partition_dim)[tp_index] + + +def _all_gather_along_dim(tensor, partition_dim, mp_group, tp_world_size): + if mp_group is None or tp_world_size == 1: + return tensor + perm = [partition_dim] + [idx for idx in range(tensor.dim()) if idx != partition_dim] + inv_perm = [0] * len(perm) + for idx, dim in enumerate(perm): + inv_perm[dim] = idx + tensor_perm = tensor.permute(perm).contiguous() + output = torch.empty((tp_world_size * tensor_perm.shape[0], *tensor_perm.shape[1:]), + dtype=tensor.dtype, + device=tensor.device) + dist.all_gather_into_tensor(output, tensor_perm, group=mp_group) + return output.permute(inv_perm).contiguous() + + +def _gather_logical_tensor(tensor, + logical_shape, + partition_dim, + mp_group, + tp_world_size, + name=None, + subparam_sizes=None): + if mp_group is None or tp_world_size == 1: + return tensor.reshape(logical_shape) + layer_label = f"AutoTP layer '{name}'" if name else "AutoTP layer" + if logical_shape[partition_dim] % tp_world_size != 0: + raise ValueError( + f"{layer_label} partition_dim size {logical_shape[partition_dim]} not divisible by tp_size {tp_world_size}." + ) + partitioned_shape = list(logical_shape) + partitioned_shape[partition_dim] = logical_shape[partition_dim] // tp_world_size + tensor_view = tensor.reshape(partitioned_shape) + + if subparam_sizes: + for size in subparam_sizes: + if size % tp_world_size != 0: + raise ValueError(f"{layer_label} sub-parameter size {size} not divisible by tp_size {tp_world_size}.") + partitioned_sizes = [size // tp_world_size for size in subparam_sizes] + sub_params = torch.split(tensor_view, partitioned_sizes, dim=partition_dim) + gathered_sub_params = [_all_gather_along_dim(sp, partition_dim, mp_group, tp_world_size) for sp in sub_params] + return torch.cat(gathered_sub_params, dim=partition_dim) + return _all_gather_along_dim(tensor_view, partition_dim, mp_group, tp_world_size) + + +class SubParamLinearLayer(TensorParallel_Layer): + """ + Column-parallel linear layer with sub-parameter support. + + Handles cases where weights contain multiple logical sub-parameters + that need to be partitioned separately (e.g., fused QKV, chunked MLP, GQA). + + The `shape` parameter controls how the weight is viewed and partitioned: + - (3, -1) with partition_dim=0: 3 equal sub-params, partition each at dim 0 + - ((q, k, v), -1) with partition_dim=0: 3 unequal sub-params (1-level nesting) + """ + + def __init__(self, module, mp_group, shape, partition_dim=0, **kwargs): + super(SubParamLinearLayer, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + self.shape = shape + self.partition_dim = partition_dim + + self._orig_weight_shape = tuple(module.weight.shape) + self._orig_bias_shape = tuple(module.bias.shape) if self.bias is not None else None + (self._logical_shape, self._output_shape, self._subparam_sizes, + self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape, + self.partition_dim, self.name) + if self.bias is not None and self.bias.numel() != _shape_prod(self._output_shape): + raise ValueError(f"AutoTP layer '{self.name}' bias size {self.bias.numel()} does not match output shape " + f"{self._output_shape}.") + + if self._should_materialize_tp_partition(): + self._tp_partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_tp_params(self.bias) + self._mark_uc_metadata() + + def forward(self, input): + if getattr(self, 'mp_group', None) is not None: + input = ColumnParallel.apply(self.mp_group, input) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.bias is not None: + output = add_bias(output, self.bias) + return output + + @torch.no_grad() + def gather_params(self, params_list): + """Gather partitioned parameters back to full size.""" + for idx, param in enumerate(params_list): + if param is None: + continue + params_list[idx].data_partition = param.data + if idx == 0: + full_view = _gather_logical_tensor(param, + self._logical_shape, + self.partition_dim, + self.mp_group, + self.tp_world_size, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[idx].data = full_view.reshape(self._orig_weight_shape) + else: + if self._bias_partition_dim is None: + params_list[idx].data = param.data + else: + full_bias_view = _gather_logical_tensor(param, + self._output_shape, + self._bias_partition_dim, + self.mp_group, + self.tp_world_size, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[idx].data = full_bias_view.reshape(self._orig_bias_shape) + + @torch.no_grad() + def _tp_partition(self, params_list): + weight = params_list[0] + if weight is None: + return + + weight_view = weight.reshape(self._logical_shape) + partitioned_view = _partition_logical_tensor(weight_view, + self.partition_dim, + self.tp_world_size, + self.tp_index, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[0].data = self.move(partitioned_view.reshape(-1, partitioned_view.shape[-1])).detach() + + if params_list[1] is not None: + if self._bias_partition_dim is None: + params_list[1].data = self.move(params_list[1]).detach() + else: + bias_view = params_list[1].reshape(self._output_shape) + bias_partitioned = _partition_logical_tensor(bias_view, + self._bias_partition_dim, + self.tp_world_size, + self.tp_index, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[1].data = self.move(bias_partitioned.reshape(-1)).detach() + + def _mark_uc_metadata(self): + self._set_param_uc_meta(self.weight, + partition_type='column', + partition_dim=self.partition_dim, + logical_shape=self._logical_shape, + output_shape=self._output_shape, + sub_param_shape=self.shape, + sub_param_sizes=self._subparam_sizes, + target_partition_shape=self.weight.shape, + original_shape=self._orig_weight_shape) + if self.bias is not None: + self._set_param_uc_meta( + self.bias, + partition_type='column', + partition_dim=self._bias_partition_dim, + logical_shape=self._output_shape, + output_shape=self._output_shape, + sub_param_shape=self.shape if self._bias_partition_dim is not None else None, + sub_param_sizes=self._subparam_sizes if self._bias_partition_dim is not None else None, + target_partition_shape=self.bias.shape, + original_shape=self._orig_bias_shape, + is_bias=True, + replicated=self._bias_partition_dim is None) + + +class SubParamLinearAllreduce(TensorParallel_Layer): + """ + Row-parallel linear layer with sub-parameter support (AllReduce after forward). + + Handles cases where weights contain multiple logical sub-parameters + that need to be partitioned separately. + """ + + def __init__(self, module, mp_group, shape, partition_dim=1, **kwargs): + super(SubParamLinearAllreduce, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + self.shape = shape + self.partition_dim = partition_dim + + self._orig_weight_shape = tuple(module.weight.shape) + self._orig_bias_shape = tuple(module.bias.shape) if self.bias is not None else None + (self._logical_shape, self._output_shape, self._subparam_sizes, + self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape, + self.partition_dim, self.name) + + if self._should_materialize_tp_partition(): + self._tp_partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_requires_grad(self.bias) + self._mark_uc_metadata() + + def forward(self, input): + output = torch.matmul(input, self.weight.transpose(-1, -2)) + output = RowParallel.apply(self.mp_group, output, not self.is_training_mode()) + if self.bias is not None: + output = add_bias(output, self.bias) + return output + + @torch.no_grad() + def gather_params(self, params_list): + """Gather partitioned parameters back to full size.""" + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't gather bias for row parallel + return + params_list[idx].data_partition = param.data + full_view = _gather_logical_tensor(param, + self._logical_shape, + self.partition_dim, + self.mp_group, + self.tp_world_size, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[idx].data = full_view.reshape(self._orig_weight_shape) + + @torch.no_grad() + def _tp_partition(self, params_list): + weight = params_list[0] + if weight is None: + return + + weight_view = weight.reshape(self._logical_shape) + partitioned_view = _partition_logical_tensor(weight_view, + self.partition_dim, + self.tp_world_size, + self.tp_index, + name=self.name, + subparam_sizes=self._subparam_sizes) + params_list[0].data = self.move(partitioned_view.reshape(-1, partitioned_view.shape[-1])).detach() + + # Bias is not partitioned for row parallel (it's applied after all-reduce) + if params_list[1] is not None: + params_list[1].data = self.move(params_list[1]).detach() + + def _mark_uc_metadata(self): + self._set_param_uc_meta(self.weight, + partition_type='row', + partition_dim=self.partition_dim, + logical_shape=self._logical_shape, + output_shape=self._output_shape, + sub_param_shape=self.shape, + sub_param_sizes=self._subparam_sizes, + target_partition_shape=self.weight.shape, + original_shape=self._orig_weight_shape) + if self.bias is not None: + self._set_param_uc_meta(self.bias, + partition_type='row', + partition_dim=None, + logical_shape=self._orig_bias_shape, + output_shape=self._orig_bias_shape, + original_shape=self._orig_bias_shape, + target_partition_shape=self.bias.shape, + is_bias=True, + replicated=True) + + +class RMSNormalize(nn.Module): + + def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None): + super(RMSNormalize, self).__init__() + if weight is not None: + self.weight = weight + else: + self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=get_accelerator().current_device_name())) + + self.eps = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return hidden_states * self.weight diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index aee47e77bbe9..280897e3617a 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -9,13 +9,15 @@ from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference +from deepspeed.model_implementations.transformers.ds_llama2 import DeepSpeedLlama2Inference import deepspeed.ops.transformer as transformer_inference -from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding +from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding, RMSNormalize import torch import gc from deepspeed.accelerator import get_accelerator import re +from .utils import transpose def load_model_with_checkpoint(r_module, @@ -29,22 +31,18 @@ def load_model_with_checkpoint(r_module, error_msgs = [] def prefix_check(): - # if keys start with 'model.', don't skip level 0 prefix + # if keys start with 'model.' or 'transformer.', don't skip level 0 prefix for key in sd[0].keys(): + # OPT models if re.match("^model[.]", key): return False + # BLOOM models + if re.match("^transformer[.]", key): + return False return True skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix - def transpose(data): - with torch.no_grad(): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None - return data.reshape(data.shape[-1], data.shape[-2]) - def load(module, prefix): args = (sd[0], prefix, {}, True, [], [], error_msgs) @@ -175,8 +173,26 @@ def load_parameters(module, prefix): try: import transformers OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding - except: + if hasattr(transformers.models, "llama"): + LlamaRMSNorm = transformers.models.llama.modeling_llama.LlamaRMSNorm + else: + LlamaRMSNorm = None + except Exception: OPTLearnedPositionalEmbedding = None + try: + from fairscale.nn.model_parallel.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, + ) + except Exception: + ColumnParallelLinear = None + ParallelEmbedding = None + RowParallelLinear = None + try: + from llama.model import RMSNorm + except Exception: + RMSNorm = None layer_policies = { nn.Linear: load, nn.Embedding: load, @@ -190,8 +206,15 @@ def load_parameters(module, prefix): DeepSpeedBERTInference: load_transformer_layer, DeepSpeedMegatronGPTInference: load_transformer_layer, DeepSpeedOPTInference: load_transformer_layer, + DeepSpeedLlama2Inference: load_transformer_layer, OPTLearnedPositionalEmbedding: load, - OPTEmbedding: load + OPTEmbedding: load, + LlamaRMSNorm: load, + RMSNormalize: load, + ColumnParallelLinear: load, + ParallelEmbedding: load, + RowParallelLinear: load, + RMSNorm: load } all_ds_ids = {} @@ -206,7 +229,7 @@ def load_module_recursive(module, prefix='', level=0): child.weight.ds_id in all_ds_ids): prefix1 = all_ds_ids[child.weight.ds_id] if child.__class__ is nn.Linear: - child = LinearLayer(weight=all_ds_ids[child.weight.ds_id]) + child = LinearLayer.from_weights(weight=all_ds_ids[child.weight.ds_id]) setattr(module, name, child) continue child_params = list(child.parameters()) @@ -218,12 +241,19 @@ def load_module_recursive(module, prefix='', level=0): if child.__class__ is nn.LayerNorm: child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps) setattr(module, name, child) - elif child.__class__ is nn.Linear: - child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias) + elif child.__class__ in [nn.Linear, ColumnParallelLinear, RowParallelLinear]: + child = LinearLayer.from_weights(weight_shape=child.weight.shape, + dtype=child.weight.dtype, + bias=child.bias) setattr(module, name, child) elif child.__class__ is OPTLearnedPositionalEmbedding: child = OPTEmbedding(weight_shape=ds_shape) setattr(module, name, child) + elif child.__class__ in [LlamaRMSNorm, RMSNorm]: + child = RMSNormalize(dim=ds_shape[-1], + dtype=child.weight.dtype, + eps=child.eps if hasattr(child, 'eps') else child.variance_epsilon) + setattr(module, name, child) else: ds_id = None if hasattr(child.weight, 'ds_id'): @@ -242,13 +272,6 @@ def load_module_recursive(module, prefix='', level=0): load_module_recursive(r_module) - embedding_weight = None - - for n, p in r_module.named_parameters(): - if "word_embeddings." in n or "embed_tokens." in n or "wte." in n: - embedding_weight = p - if embedding_weight is not None and r_module.lm_head.weight.is_meta: - r_module.lm_head.weight = embedding_weight for sd_ in sd: del sd_ sd = None diff --git a/deepspeed/module_inject/module_quantize.py b/deepspeed/module_inject/module_quantize.py index 765a7e96bd54..1f5b2f8a1d28 100755 --- a/deepspeed/module_inject/module_quantize.py +++ b/deepspeed/module_inject/module_quantize.py @@ -10,7 +10,7 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=Fal """ Quantize bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, - e.g., transformers.modeling_bert.BertLayer. + e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer model (torch.nn.Module): user's nn.module representing their model megatron (bool): megatron model-parallel implementation (this is supported for inference only) diff --git a/deepspeed/module_inject/policy.py b/deepspeed/module_inject/policy.py index 87b34e5aab5a..dff12b6c64b3 100644 --- a/deepspeed/module_inject/policy.py +++ b/deepspeed/module_inject/policy.py @@ -4,9 +4,10 @@ # DeepSpeed Team from abc import ABC, abstractmethod -from deepspeed.utils.types import ActivationFuncType +from deepspeed.utils.types import ActivationFuncType, NormType import torch from deepspeed.accelerator import get_accelerator +from .utils import transpose transformer_param_names = ( 'attn_qkvw', \ @@ -58,7 +59,9 @@ def __init__( # this flag shows whether or not using prefix in loading the checkpoint use_load_prefix=False, # whether or not the qkv is stored in the split-format - split_qkv=True): + split_qkv=True, + # Type of normalization to perform + norm_type=NormType.LayerNorm): super().__init__() self.cuda_graph_supported = False self.inference = inference @@ -70,9 +73,10 @@ def __init__( self.pre_attn_norm = pre_attn_norm self.use_load_prefix = use_load_prefix self.split_qkv = split_qkv + self.norm_type = norm_type @abstractmethod - def attention(self, enable_training=False): + def attention(self): """ Returns attention qkv and dense parameters weight: (3*hidden, hidden) and (hidden, hidden) @@ -80,13 +84,6 @@ def attention(self, enable_training=False): """ raise NotImplementedError - @abstractmethod - def get_q_k_v(self): - """ - return all q,k,v parameters without merging them together - """ - raise NotImplementedError - @abstractmethod def get_hidden_heads(self): """ @@ -112,28 +109,10 @@ def layernorm(self): """ raise NotImplementedError - @abstractmethod - def get_lora_params(self): - """ - Returns lora parameters used in transformer layer - - """ - raise NotImplementedError - - -# TODO (lekurile): This function exists in base container as well, consolidate as some point -def transpose(data): - with torch.no_grad(): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None - return data.reshape(data.shape[-1], data.shape[-2]) - # TODO (lekurile): This function exists in megatron feature container as well, consolidate as some point def _transpose(x, heads=1, mp_replace=None): - heads = heads // mp_replace.mp_size + heads = heads // mp_replace.mp_size # type: ignore outer_dim = -1 attention_head_size = x.shape[outer_dim] // heads new_x_shape = x.size()[:outer_dim] + (heads, attention_head_size) @@ -164,15 +143,15 @@ def maybe_copy(module, tmp = sd[src_name] if len(dst.shape) == 1: if split_qkv: - dst = mp_replace.qkv_copy(dst, tmp) + dst = mp_replace.strided_copy(dst, tmp, num_splits=3) else: dst = mp_replace.copy(dst, tmp) if qkv and megatron_v2: dst = torch.nn.parameter.Parameter(_transpose(dst, heads=heads, mp_replace=mp_replace).contiguous()) else: if split_qkv: - dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \ - (transpose(tmp).contiguous())), int8=weight_quantizer.q_int8) + dst = mp_replace.strided_copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \ + (transpose(tmp).contiguous())), num_splits=3, int8=weight_quantizer.q_int8) else: if qkv and megatron_v2: tmp = _transpose(transpose(tmp), heads=heads, mp_replace=mp_replace).contiguous() @@ -193,19 +172,33 @@ def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names dst = getattr(module, dst_name) if len(dst.shape) == 1: if split_qkv: - dst = mp_replace.qkv_copy(dst, qkv_data.contiguous()) + dst = mp_replace.strided_copy(dst, qkv_data.contiguous(), num_splits=3) else: dst = mp_replace.copy(dst, qkv_data) else: if split_qkv: - dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \ - ((transpose(qkv_data)).contiguous())), int8=weight_quantizer.q_int8) + dst = mp_replace.strided_copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \ + ((transpose(qkv_data)).contiguous())), num_splits=3, int8=weight_quantizer.q_int8) else: dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \ transpose(qkv_data)), int8=weight_quantizer.q_int8) setattr(module, dst_name, dst) +# Extending the `maybe_copy` function for when mlp1 is in separate parameters for GeGLU +def maybe_copy_geglu(module, sd, weight_quantizer, mp_replace, dst_name, src_names): + if src_names[0] in sd: + reg_proj = sd[src_names[0]] + gate_proj = sd[src_names[1]] + + mlp1_data = torch.cat((reg_proj, gate_proj), dim=0) + dst = getattr(module, dst_name) + + dst = mp_replace.strided_copy(dst, weight_quantizer.quantize(mlp1_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \ + transpose(mlp1_data)), num_splits=2, int8=weight_quantizer.q_int8) + setattr(module, dst_name, dst) + + def pack_lora_weights(p): return [ p.lora_right_weight, \ diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index b6f20845dda0..263369fc0484 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -12,114 +12,18 @@ from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig from deepspeed.accelerator import get_accelerator -from .replace_policy import HFGPT2LayerPolicy from .replace_policy import replace_policies, generic_policies - +from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading +from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d +from deepspeed.module_inject.layers import is_autotp_training_mode from deepspeed import comm as dist -from torch import nn +from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size -from .layers import LinearAllreduce, LinearLayer from .load_checkpoint import load_model_with_checkpoint import time from .utils import policy_to_ds_container - - -class ReplaceWithTensorSlicing: - - def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0): - if mp_group is not None: - self.gpu_index = dist.get_rank(group=mp_group) - else: - self.gpu_index = 0 - self.out_dim = out_dim - self.in_dim = in_dim - self.mp_size = mp_size - - def merge_assert(self, dim1, dim2): - assert dim1 > dim2, \ - 'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\ - for merging your checkpoints before replacing the transformer layer with\ - inference-kernels' - - def qkv_copy(self, dst, src, int8=False): - if src is None: - return src - src_shape = src.shape - dst_shape = dst.shape - - outer_dim = 0 if int8 else -1 - inner_dim = -1 if int8 else 0 - - src_split = torch.split(src.data, src.shape[outer_dim] // 3, dim=outer_dim) - if (len(src_shape) == 2 and len(dst_shape) == 2): - if src_shape[outer_dim] == dst_shape[self.out_dim]: - dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape) - dst = torch.nn.parameter.Parameter(dst, requires_grad=False) - if hasattr(src, 'scale'): - dst.scale = src.scale - return dst - if self.out_dim == 1: - self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim]) - qkv_size = dst_shape[self.out_dim] // 3 - qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split] - - weight_split = [ - torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0])) - ] - dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape( - weight_split[self.gpu_index].shape) - else: - dst.data.copy_(src_split[self.gpu_index].to(get_accelerator().current_device_name()).contiguous()) - else: - if src_shape[0] == dst_shape[0]: - return torch.nn.parameter.Parameter(src) - if self.out_dim == 1: - qkv_size = dst_shape[0] // 3 - qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split] - bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))] - dst.data.copy_(bias_split[self.gpu_index].contiguous()) - else: - dst.data.copy_(src_split[self.gpu_index].contiguous()) - - dst = torch.nn.parameter.Parameter(dst, requires_grad=False) - if hasattr(src, 'scale'): - dst.scale = src.scale - return dst - - def copy(self, dst, src, int8=False, allocat_tensor=False): - if src is None: - return src - assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors - if allocat_tensor: - dst = torch.empty_like(dst) - outer_dim = 0 if int8 else 1 - inner_dim = 1 if int8 else 0 - src_shape = src.shape - dst_shape = dst.shape - if (len(src_shape) == 2 and len(dst_shape) == 2): - - if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]: - dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape) - else: - if src_shape[inner_dim] != dst_shape[self.in_dim]: - self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim]) - dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \ - src[self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim], :]) - else: - self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim]) - dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \ - src[self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim], :]) - else: - if src_shape[0] == dst_shape[0]: - dst = src - else: - dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]]) - dst = torch.nn.parameter.Parameter(dst, requires_grad=False) - if hasattr(src, 'scale'): - dst.scale = src.scale - - return dst +import gc def get_transformer_name(replaced_module): @@ -181,7 +85,7 @@ def _module_match(module): return None -def generic_injection(module, fp16=False, enable_cuda_graph=True): +def generic_injection(module, dtype=None, enable_cuda_graph=True): def replace_attn(child, policy): policy_attn = policy.attention(child) @@ -189,13 +93,15 @@ def replace_attn(child, policy): return child if len(policy_attn) == 5: qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn + qw, kw, vw = torch.empty(0), torch.empty(0), torch.empty(0) else: qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn + qkvw = torch.empty(0) config = transformer_inference.DeepSpeedInferenceConfig( hidden_size=hidden_size, heads=heads, - fp16=fp16, + dtype=dtype, triangular_masking=False, max_out_tokens=4096, ) @@ -209,11 +115,15 @@ def transpose(data): return data if len(policy_attn) == 5: + assert qkvw is not None and qkvw.data is not None, "qkvw can't be None" attn_module.attn_qkvw.data = transpose(qkvw.data) else: attn_module.attn_qkvw = None + assert qw is not None and qw.data is not None, "qw can't be None" attn_module.attn_qw.data = transpose(qw.data) + assert kw is not None and kw.data is not None, "kw can't be None" attn_module.attn_kw.data = transpose(kw.data) + assert vw is not None and vw.data is not None, "vw can't be None" attn_module.attn_vw.data = transpose(vw.data) attn_module.attn_qkvb = None @@ -228,12 +138,15 @@ def replace_attn_block(child, policy): if isinstance(module, torch.nn.Module): pass else: - if fp16 is False: + if dtype not in [torch.float16, torch.half]: raise ValueError("Generic injection only supported with FP16") try: import diffusers - cross_attention = diffusers.models.attention.CrossAttention + if hasattr(diffusers.models.attention, 'CrossAttention'): + cross_attention = diffusers.models.attention.CrossAttention + else: + cross_attention = diffusers.models.attention_processor.Attention attention_block = diffusers.models.attention.BasicTransformerBlock new_policies = { cross_attention: replace_attn, @@ -277,7 +190,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, - e.g., transformers.modeling_bert.BertLayer. + e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer model (torch.nn.Module): user's nn.module representing their model checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine config: top-level DS Inference config defined in inference/config.py @@ -286,7 +199,6 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m Updated nn.module with replaced transformer layers """ # defining globals as internally defined functions inherit these everywhere - fp16 = (config.dtype == torch.float16 or config.dtype == torch.int8) quantize = (config.dtype == torch.int8) # todo: Refactor later. In future, let's minimize the style used above and use config.** instead @@ -319,7 +231,6 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, model_config=model_config, layer_id=layer_id, child=child) - _container.set_dtype(fp16) _container.set_moe(moe) # 2. Set the tensor parallelism config @@ -329,12 +240,12 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, _container.initialize_tensors() # 4. deal with data types -- needs refactor to use dtype instead of fp16 - if fp16: - _container.convert_to_required_dtype(dtype=torch.half) + if config.dtype in [torch.float16, torch.bfloat16, torch.int8]: + _container.convert_to_required_dtype() # 5. Set the quantization config quantizer = GroupQuantizer(q_int8=quantize) - _container.set_quantization_config(quantize, quantizer) + _container.set_quantization_config(quantizer) # 6. create a DS Inference config object _container.create_ds_model_config() @@ -359,148 +270,156 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module - def replace_wo_policy(module, all_reduce_linears): - mp_size = config.tensor_parallel.tp_size - mp_group = config.tensor_parallel.tp_group - - def _replace(child, name, conv_linear_layer): - mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) - weight_shape = child.weight.shape - if name in all_reduce_linears: - new_weight = torch.empty(( - weight_shape[1] if conv_linear_layer else weight_shape[0], - (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size, - ), - device=child.weight.device, - dtype=child.weight.dtype) - if conv_linear_layer: - child.weight.data = child.weight.data.transpose(-1, -2).contiguous() - data = mp_replace.copy(new_weight, child.weight.data) - new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype) - if child.bias is not None: - new_bias.data.copy_(child.bias.data) - return LinearAllreduce(data, child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group) - else: - new_weight = torch.empty(( - (weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size, - weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1], - ), - device=child.weight.device, - dtype=child.weight.dtype) - if conv_linear_layer: - child.weight.data = child.weight.data.transpose(-1, -2).contiguous() - data = mp_replace.copy(new_weight, child.weight.data) - - new_bias = torch.empty((weight_shape[0] // mp_size), - device=child.weight.device, - dtype=child.weight.dtype) - bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to( - get_accelerator().current_device_name()) - return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data) - - def _slice_embedding(child, name, conv_linear_layer): - mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) - new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size), - device=child.weight.device, - dtype=child.weight.dtype) - data = mp_replace.copy(new_weight, - child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \ - child.weight.data) - new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size) - new_embedding.weight.data.copy_(data) - return new_embedding - - def update_mp_params(child): - if hasattr(child, 'n_heads'): - assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format( - child.n_heads, mp_size) - child.n_heads = child.n_heads // mp_size - if hasattr(child, 'inner_dim'): - assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format( - child.inner_dim, mp_size) - child.inner_dim = child.inner_dim // mp_size - if hasattr(child, 'num_heads'): - assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format( - child.num_heads, mp_size) - child.num_heads = child.num_heads // mp_size - if hasattr(child, 'num_attention_heads'): - assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format( - child.num_attention_heads, mp_size) - child.num_attention_heads = child.num_attention_heads // mp_size - if hasattr(child, 'num_attn_heads'): - assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format( - child.num_attn_heads, mp_size) - child.num_attn_heads = child.num_attn_heads // mp_size - if hasattr(child, 'all_head_size'): - assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format( - child.all_head_size, mp_size) - child.all_head_size = child.all_head_size // mp_size - if hasattr(child, 'embed_dim'): - assert child.embed_dim % mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format( - child.embed_dim, mp_size) - child.embed_dim = child.embed_dim // mp_size - if hasattr(child, 'hidden_size'): - assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format( - child.hidden_size, mp_size) - child.hidden_size = child.hidden_size // mp_size - - conv_linear_layer = False - if linear_layer_setting is not None: - linear_policies = {linear_layer_setting[0]: _replace} - if len(linear_layer_setting) == 2: - linear_policies.update({linear_layer_setting[1]: _slice_embedding}) - else: - if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class: - try: - import transformers - conv_linear_layer = True - linear_policies = {transformers.model_utils.Conv1D: _replace} - except ImportError: - linear_policies = {nn.Linear: _replace} - else: - linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding} - - def _replace_module(r_module, prev_name=''): - for name, child in r_module.named_children(): - if child.__class__ in linear_policies: - setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name, - conv_linear_layer)) - else: - update_mp_params(child) - _replace_module(child, name) - return r_module + def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): + #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group) - return _replace_module(module) + # Get the configurable partition_config if available + partition_config = None + if hasattr(config, 'get_partition_config_object'): + partition_config = config.get_partition_config_object() - def replace_fn(child, _policy, layer_id=0): - training = False # todo: refactor this part to go in the config - if training: - # copy relevant state from child -> new module - new_module = replace_with_policy(child, _policy, config.triangular_masking) + # 1. Create AutoTP object + _autotp = AutoTP(module, + all_reduce_linears, + prefix, + state_dict, + linear_layer_setting, + orig_layer_impl, + config.keep_module_on_host, + partition_config=partition_config) - else: - # copy relevant state from child -> new module - if config.replace_with_kernel_inject: - new_module = replace_with_policy(child, - _policy, - config.triangular_masking, - inference=True, - layer_id=layer_id) + # 2. Set the tensor parallelism config + _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) + + # 3. Try to get num_key_heads from model_config.num_key_value_heads + if hasattr(model_config, "vision_config"): + if "MllamaVisionEncoderLayer" in str(module): + num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config) + elif hasattr(model_config, "text_config"): + num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config) else: - new_module = replace_wo_policy(child, _policy) + num_kv_heads = _autotp.get_model_num_kv_heads(model_config) + else: + num_kv_heads = _autotp.get_model_num_kv_heads(model_config) + + # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + set_num_kv_heads(num_kv_heads) + + # 4.1 Get n_embd + n_embd = None + multi_query_n_embd_names = ['n_embd', 'hidden_size'] + for name in multi_query_n_embd_names: + if hasattr(model_config, name): + n_embd = getattr(model_config, name) + if n_embd != None: + break + + # 4.2 set n_embd + set_n_embd(n_embd) + + # 4.3 set attention_heads + if hasattr(model_config, 'num_attention_heads'): + set_num_attention_heads(getattr(model_config, 'num_attention_heads')) + + # 4.4 set tp_grain_size + set_tp_grain_size(config.tensor_parallel.tp_grain_size) + + # 5. Set linear policies + _autotp.update_linear_policies() + + # 6. Replace modules + if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears: + return _autotp._replace_last_linear_module(module) + return _autotp._replace_module(module) + + def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): + # copy relevant state from child -> new module + if not is_autotp_training_mode() and config.replace_with_kernel_inject: + new_module = replace_with_policy(child, + _policy, + config.triangular_masking, + inference=True, + layer_id=layer_id) + else: + new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict) return new_module - replaced_module = replace_module(model=model, - orig_class=orig_layer_impl, - replace_fn=replace_fn, - _replace_policy=config.injection_policy_tuple) + def set_lm_head(module): + if is_autotp_training_mode(): + # we need to handle autoTP training mode separately. + return + + embedding_weight = None + for n, p in module.named_parameters(): + if "word_embeddings." in n or "embed_tokens." in n or "wte." in n: + embedding_weight = p + if embedding_weight is not None and hasattr(module, "lm_head") and hasattr( + module.lm_head, "weight") and module.lm_head.weight.is_meta: + module.lm_head.weight = embedding_weight + # enable tensor parallel for the last linear + if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance( + module.lm_head, torch.nn.Linear): + module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head") + elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance( + module.embed_out, torch.nn.Linear): + module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out") + elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"): + module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head") + return module + + def conv2d_parallel_shard_weights(model, rank, world_size): + # add conv policy + shard_oc_name = ["conv1"] + shard_ic_name = ["conv2"] + for name, sub_m in model.named_children(): + for l_name, l_sub_m in sub_m.named_children(): + if l_name in shard_oc_name: + TPConv2d = TensorParallelOcShardConv2d( + l_sub_m, + rank, + world_size, + ) + setattr(sub_m, l_name, TPConv2d) + if l_name in shard_ic_name: + TPConv2d = TensorParallelIcShardConv2d( + l_sub_m, + rank, + world_size, + ) + setattr(sub_m, l_name, TPConv2d) + conv2d_parallel_shard_weights(sub_m, rank, world_size) + + if checkpoint_dict is not None and not config.replace_with_kernel_inject: + # AutoTP shard loading + checkpoint = checkpoint_dict["checkpoints"] + pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") + for i in range(len(checkpoint)): + checkpoint_file = os.path.join(config.base_dir, checkpoint[i]) + replaced_module = replace_module(model=model, + orig_class=orig_layer_impl, + replace_fn=replace_fn, + _replace_policy=config.injection_policy_tuple, + checkpoint=checkpoint_file) + pbar.update(1) + gc.collect() + # conv2d tp module replace + # Now is for yuan model. Add model list and conv policy to decide whether to replace conv. + if 'Yuan' in str(replaced_module): + conv2d_parallel_shard_weights(replaced_module, dist.get_rank(), dist.get_world_size()) + else: + replaced_module = replace_module(model=model, + orig_class=orig_layer_impl, + replace_fn=replace_fn, + _replace_policy=config.injection_policy_tuple) + # AutoTP default set lm_head tp + if not config.replace_with_kernel_inject: + replaced_module = set_lm_head(replaced_module) quantizer = GroupQuantizer(q_int8=quantize) world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 - if checkpoint_dict is not None: + if checkpoint_dict is not None and config.replace_with_kernel_inject: assert container_g.ckpt_load_enabled, \ f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container" start_time = time.time() @@ -515,7 +434,7 @@ def replace_fn(child, _policy, layer_id=0): pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") for i in range(len(checkpoint)): - sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')] + sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)] load_model_with_checkpoint(replaced_module, sd, mp_replace, @@ -525,7 +444,6 @@ def replace_fn(child, _policy, layer_id=0): container=container_g) pbar.update(1) else: - import gc num_checkpoints = len(ckpt_list) // ckpt_mp_size tp_split_size = (world_size / ckpt_mp_size) sd_offset = int(rank / tp_split_size) @@ -538,7 +456,7 @@ def replace_fn(child, _policy, layer_id=0): os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j] for j in range(sd_count) ] - sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files] + sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files] load_model_with_checkpoint(replaced_module, sds, mp_replace, @@ -558,7 +476,7 @@ def replace_fn(child, _policy, layer_id=0): pbar.update(1) ckpt_file = os.path.join(base_dir1, checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i] - sds = [torch.load(ckpt_file, map_location='cpu')] + sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)] load_model_with_checkpoint(replaced_module, sds, mp_replace, @@ -569,9 +487,10 @@ def replace_fn(child, _policy, layer_id=0): container=container_g) sds = [None for _ in sds] gc.collect() + set_lm_head(replaced_module) print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") - if config.save_mp_checkpoint_path is not None: + if not is_autotp_training_mode() and config.save_mp_checkpoint_path is not None: from collections import OrderedDict import json num_partitions = 8 @@ -589,16 +508,25 @@ def replace_fn(child, _policy, layer_id=0): if dist.is_initialized(): dist.barrier() transformer_name = get_transformer_name(replaced_module) - non_tp_ckpt_name = f'non-tp.pt' + non_tp_ckpt_name = 'non-tp.pt' ckpt_files = [non_tp_ckpt_name] os.makedirs(config.save_mp_checkpoint_path, exist_ok=True) if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( - OrderedDict({k: v - for k, v in dict(replaced_module.state_dict()).items() - if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') + OrderedDict({ + k: v + for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k + }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') + + dtype_reprs = { + torch.float32: 'float32', + torch.float16: 'float16', + torch.int8: 'int8', + torch.bfloat16: 'bfloat16' + } + ckpt_config = json.dumps({ 'type': ckpt_name, 'base_dir': f'{config.save_mp_checkpoint_path}', @@ -609,7 +537,7 @@ def replace_fn(child, _policy, layer_id=0): 'version': 1.0, 'parallelization': 'tp', 'tp_size': world_size, - 'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32') + 'dtype': dtype_reprs[config.dtype] }) with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg: cfg.write(ckpt_config) @@ -634,7 +562,7 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False): """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced, - e.g., transformers.modeling_bert.BertLayer. + e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer model (torch.nn.Module): user's nn.module representing their model config (dict): model config containing hidden size, attention heads, etc. Returns: @@ -699,7 +627,7 @@ def replace_fn(child, _replace_policy, layer_id): _replace_policy=None) -def replace_module(model, orig_class, replace_fn, _replace_policy): +def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None): """ Scan the model for instances of ``orig_clas:`` to replace using ``replace_fn``. Arguments: model (torch.nn.Module): the model to augment @@ -709,6 +637,14 @@ def replace_module(model, orig_class, replace_fn, _replace_policy): Returns: A modified ``model``. """ + sd = None + if checkpoint is not None: + if checkpoint.endswith(".safetensors"): + from safetensors.torch import load_file + sd = load_file(checkpoint) + else: + sd = torch.load(checkpoint, map_location='cpu', weights_only=False) + policy = {} if orig_class is not None: policy.update({orig_class: (replace_fn, _replace_policy)}) @@ -723,16 +659,35 @@ def replace_module(model, orig_class, replace_fn, _replace_policy): policy.update({plcy._orig_layer_class: (replace_fn, plcy)}) assert len(policy.items()) > 0,\ "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\ - "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py" + "You can find some samples here: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py" - replaced_module, _ = _replace_module(model, policy) + replaced_module, _ = _replace_module(model, policy, state_dict=sd) return replaced_module from ..pipe import PipelineModule +import re -def _replace_module(model, policies, layer_id=0): + +def skip_level_0_prefix(model, state_dict): + model = str(model) + key = re.search(r": (.*?)Model", model) + if key is None: + key = re.search(r": (.*?)Stack", model) + if key is None: + key = re.match(r"(.*?)Model", model) + # if keys start with 'model.', don't skip level 0 prefix + if state_dict is not None: + for item in state_dict.keys(): + if re.match("^model[.]", item): + return False + if key is not None and key.group(1).lower() in ["bloom", "opt"]: + return True + return False + + +def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None): """ Traverse model's children recursively and apply any transformations in ``policies``. Arguments: model (torch.nn.Module): model to augment @@ -742,7 +697,11 @@ def _replace_module(model, policies, layer_id=0): """ for name, child in model.named_children(): if child.__class__ in policies: - replaced_module = policies[child.__class__][0](child, policies[child.__class__][-1], layer_id) + replaced_module = policies[child.__class__][0](child, + policies[child.__class__][-1], + layer_id, + prefix=prefix + name, + state_dict=state_dict) setattr(model, name, replaced_module) if isinstance(model, PipelineModule): assert hasattr(model, 'forward_funcs'),\ @@ -750,7 +709,25 @@ def _replace_module(model, policies, layer_id=0): model.forward_funcs[model.fwd_map[name]] = replaced_module layer_id += 1 else: - _, layer_id = _replace_module(child, policies, layer_id=layer_id) + checking_key = prefix + name + '.' + if Loading.is_load_module(child) and state_dict is not None: + if any(checking_key in item for item in state_dict): + Loading.load( + child, + state_dict, + checking_key, + ) + else: + continue + if len(child._buffers) != 0 and state_dict is not None: + Loading.load_buffer(child, state_dict, checking_key) + _, layer_id = _replace_module(child, + policies, + prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \ + prefix + name + '.', + layer_id=layer_id, + level_id=level_id + 1, + state_dict=state_dict) # Add the reset_cache func to the model, so that it can be called in the beginning of text-generation. model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index af58d3d8d2d7..2c06e31aaa41 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -13,13 +13,17 @@ from .containers import MegatronLayerPolicy from .containers import HFDistilBertLayerPolicy from .containers import HFCLIPLayerPolicy +from .containers import LLAMALayerPolicy from .containers import UNetPolicy from .containers import VAEPolicy +from .containers import LLAMA2LayerPolicy +from .containers import InternLMLayerPolicy # transformer-based policies replace_policies = [ HFBertLayerPolicy, HFGPTNEOLayerPolicy, GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy, - HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy + HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy, + LLAMALayerPolicy, LLAMA2LayerPolicy, InternLMLayerPolicy ] # non-transformer-based policies diff --git a/deepspeed/module_inject/tp_plan_converter.py b/deepspeed/module_inject/tp_plan_converter.py new file mode 100644 index 000000000000..bed0567e4445 --- /dev/null +++ b/deepspeed/module_inject/tp_plan_converter.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import logging +from typing import List, Dict, Optional +from .autotp_config import TPLayerSpec, PartitionType + +logger = logging.getLogger(__name__) + +SUPPORTED_STYLES = {"colwise", "rowwise"} + + +class TPPlanConverter: + """Convert HuggingFace tp_plan format to DeepSpeed TPLayerSpec format.""" + + @staticmethod + def convert(hf_tp_plan: Dict[str, str]) -> Optional[List[TPLayerSpec]]: + """Convert HF tp_plan to DeepSpeed layer specs. + + Returns None if the plan contains any unsupported partition styles, + allowing the caller to fall back to the existing AutoTP path. + """ + unsupported = {style for style in hf_tp_plan.values() if style.lower() not in SUPPORTED_STYLES} + if unsupported: + logger.warning( + "HuggingFace tp_plan contains unsupported partition style(s): %s. " + "Falling back to AutoTP preset-based partitioning.", sorted(unsupported)) + return None + + layer_specs = [] + + for pattern, partition in hf_tp_plan.items(): + regex_pattern = TPPlanConverter._wildcard_to_regex(pattern) + + if partition.lower() == "colwise": + partition_type = PartitionType.COLUMN + elif partition.lower() == "rowwise": + partition_type = PartitionType.ROW + + # Only add .weight suffix if not already present + if not regex_pattern.endswith(r"\.weight"): + regex_pattern += r"\.weight$" + else: + regex_pattern += r"$" + + layer_specs.append(TPLayerSpec( + patterns=[regex_pattern], + partition_type=partition_type, + )) + + return layer_specs + + @staticmethod + def _wildcard_to_regex(pattern: str) -> str: + regex = pattern.replace('.', r'\.') + regex = regex.replace('*', r'.*') + return ".*" + regex diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py new file mode 100644 index 000000000000..f1dbaae43ec9 --- /dev/null +++ b/deepspeed/module_inject/tp_shard.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed import comm as dist + +# Defaults for optional TP globals. These can be overridden by setters. +num_kv_heads = None +num_attention_heads = None +n_embd = None +tp_grain_size = 1 + + +def set_num_kv_heads(num): + global num_kv_heads + num_kv_heads = num + + +def set_num_attention_heads(num): + global num_attention_heads + num_attention_heads = num + + +def set_n_embd(num): + global n_embd + n_embd = num + + +def set_tp_grain_size(num): + global tp_grain_size + tp_grain_size = num + + +def get_num_kv_heads(): + global num_kv_heads + if 'num_kv_heads' in globals(): + return num_kv_heads + return None + + +def get_num_attention_heads(): + global num_attention_heads + return num_attention_heads + + +def get_shard_size(total_size, mp_size, name=None, rank=None): + global num_kv_heads + last_linear = ["lm_head", "embed_out"] + # MoE MLP layer use near even division will get better perf. + moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"] + not_moe_mlp_layer = True + if name != None and any(s in str(name) for s in moe_mlp_layer): + not_moe_mlp_layer = False + # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division + if rank == None: + rank = dist.get_rank() + if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str( + name) not in last_linear and not_moe_mlp_layer: + my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) + return total_size * my_slices // num_kv_heads + else: + if total_size >= tp_grain_size: + grain_size = total_size // tp_grain_size + return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size + else: + return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0) + + +def get_n_embd(): + global n_embd + return n_embd + + +def get_shard_size_list(total_size, mp_size, name=None): + shard_sizes = [] + for i in range(mp_size): + shard_sizes.append(get_shard_size(total_size, mp_size, name, i)) + return shard_sizes diff --git a/deepspeed/module_inject/utils.py b/deepspeed/module_inject/utils.py index ad60e225fcea..1837063ec63f 100644 --- a/deepspeed/module_inject/utils.py +++ b/deepspeed/module_inject/utils.py @@ -3,9 +3,19 @@ # DeepSpeed Team +import torch from deepspeed.utils import log_dist +def transpose(data): + with torch.no_grad(): + data = data.contiguous() + data1 = data.transpose(-1, -2).reshape(-1) + data.reshape(-1).copy_(data1) + data1 = None + return data.reshape(data.shape[-1], data.shape[-2]) + + # helper function to map between DS policies and DS containers def policy_to_ds_container(**kwargs): from .containers import HFGPT2LayerPolicy, DS_GPT2Container @@ -17,6 +27,9 @@ def policy_to_ds_container(**kwargs): from .containers import HFOPTLayerPolicy, DS_OPTContainer from .containers import MegatronLayerPolicy, DS_MegatronGPTContainer from .containers import HFDistilBertLayerPolicy, DS_DistilBERTContainer + from .containers import LLAMALayerPolicy, DS_LLAMAContainer + from .containers import LLAMA2LayerPolicy, DS_LLAMA2Container + from .containers import InternLMLayerPolicy, DS_InternLMContainer policy_to_container = { HFGPT2LayerPolicy: DS_GPT2Container, @@ -28,6 +41,9 @@ def policy_to_ds_container(**kwargs): HFOPTLayerPolicy: DS_OPTContainer, MegatronLayerPolicy: DS_MegatronGPTContainer, HFDistilBertLayerPolicy: DS_DistilBERTContainer, + LLAMALayerPolicy: DS_LLAMAContainer, + LLAMA2LayerPolicy: DS_LLAMA2Container, + InternLMLayerPolicy: DS_InternLMContainer } container = None diff --git a/deepspeed/moe/experts.py b/deepspeed/moe/experts.py index 8cadb0c387fa..0863221d7edf 100644 --- a/deepspeed/moe/experts.py +++ b/deepspeed/moe/experts.py @@ -3,33 +3,36 @@ # DeepSpeed Team -import torch import copy +from typing import List, Optional + +import torch +from torch import nn -class Experts(torch.nn.Module): +class Experts(nn.Module): - def __init__(self, expert, num_local_experts=1, expert_group_name=None): + def __init__(self, expert: nn.Module, num_local_experts: int = 1, expert_group_name: Optional[str] = None) -> None: super(Experts, self).__init__() - self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)]) + self.deepspeed_experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(num_local_experts)]) self.num_local_experts = num_local_experts # TODO: revisit allreduce for moe.gate... for expert in self.deepspeed_experts: # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group) - for name, param in expert.named_parameters(): + for param in expert.parameters(): param.allreduce = False param.group_name = expert_group_name - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: chunks = inputs.chunk(self.num_local_experts, dim=1) - expert_outputs = [] + expert_outputs: List[torch.Tensor] = [] + for chunk, expert in zip(chunks, self.deepspeed_experts): out = expert(chunk) - if type(out) is tuple: + if isinstance(out, tuple): out = out[0] # Ignore the bias term for now expert_outputs += [out] - expert_output = torch.cat(expert_outputs, dim=1) - return expert_output + return torch.cat(expert_outputs, dim=1) diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py index 89fe2bb46c3c..6777788ab885 100644 --- a/deepspeed/moe/layer.py +++ b/deepspeed/moe/layer.py @@ -3,22 +3,23 @@ # DeepSpeed Team -import torch +from typing import Optional, Tuple -from deepspeed.utils import log_dist +import torch +from torch import nn +from torch.nn import functional as F -from deepspeed.utils import groups -from .sharded_moe import MOELayer, TopKGate +from deepspeed.utils import groups, log_dist from .experts import Experts -import typing +from .sharded_moe import MOELayer, TopKGate -class MoE(torch.nn.Module): +class MoE(nn.Module): """Initialize an MoE layer. Arguments: hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension. - expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear). + expert (nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear). num_experts (int, optional): default=1, the total number of experts per layer. ep_size (int, optional): default=1, number of ranks in the expert parallel world or group. k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. @@ -31,23 +32,25 @@ class MoE(torch.nn.Module): use_rts (bool, optional): default=True, whether to use Random Token Selection. use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed). enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts + top2_2nd_expert_sampling (bool, optional): default=True, whether to perform sampling for 2nd expert """ def __init__(self, - hidden_size, - expert, - num_experts=1, - ep_size=1, - k=1, - capacity_factor=1., - eval_capacity_factor=1., - min_capacity=4, - use_residual=False, - noisy_gate_policy: typing.Optional[str] = None, + hidden_size: int, + expert: nn.Module, + num_experts: int = 1, + ep_size: int = 1, + k: int = 1, + capacity_factor: float = 1.0, + eval_capacity_factor: float = 1.0, + min_capacity: int = 4, + use_residual: bool = False, + noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, - use_rts=True, + use_rts: bool = True, use_tutel: bool = False, - enable_expert_tensor_parallelism: bool = False): + enable_expert_tensor_parallelism: bool = False, + top2_2nd_expert_sampling: bool = True) -> None: super(MoE, self).__init__() @@ -68,7 +71,8 @@ def __init__(self, experts = Experts(expert, self.num_local_experts, self.expert_group_name) self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor, - min_capacity, noisy_gate_policy, drop_tokens, use_rts), + min_capacity, noisy_gate_policy, drop_tokens, use_rts, None, + top2_2nd_expert_sampling), experts, self.expert_group_name, self.ep_size, @@ -77,26 +81,30 @@ def __init__(self, if self.use_residual: self.mlp = expert # coefficient is used for weighted sum of the output of expert and mlp - self.coefficient = torch.nn.Linear(hidden_size, 2) + self.coefficient = nn.Linear(hidden_size, 2) - def set_deepspeed_parallelism(self): - self._create_process_groups() + def set_deepspeed_parallelism(self, use_data_before_expert_parallel_: bool = False) -> None: + self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_) - def _create_process_groups(self): + def _create_process_groups(self, use_data_before_expert_parallel_: bool = False) -> None: # Create process group for a layer if needed if self.expert_group_name not in groups._get_expert_parallel_group_dict(): print(f"No existing process group found, creating a new group named: {self.expert_group_name}") if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism): # Condition 1 - no groups.mpu means no tensor parallelism # Condition 2 - disabling expert tensor parallelism on purpose - groups._create_expert_and_data_parallel(self.ep_size) + groups._create_expert_and_data_parallel( + self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_) else: # expert tensor parallelism is enabled - groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu) + groups._create_expert_data_and_model_parallel( + self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_) # Set the group handle for the MOELayer (deepspeed_moe) object self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name)) - def forward(self, hidden_states, used_token=None): + def forward(self, + hidden_states: torch.Tensor, + used_token: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ MoE forward Arguments: @@ -110,15 +118,15 @@ def forward(self, hidden_states, used_token=None): * l_aux (Tensor): gate loss value - * exp_counts (int): expert count + * exp_counts (Tensor): expert count """ output = self.deepspeed_moe(hidden_states, used_token) if self.use_residual: # Residual MoE output_mlp = self.mlp(hidden_states) - if type(output_mlp) is tuple: + if isinstance(output_mlp, tuple): output_mlp = output_mlp[0] # Ignore the bias term for now coef = self.coefficient(hidden_states) - coef = torch.nn.functional.softmax(coef, dim=-1) + coef = F.softmax(coef, dim=-1) output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts diff --git a/deepspeed/moe/mappings.py b/deepspeed/moe/mappings.py index 6c501ea6503a..e57f66b85193 100644 --- a/deepspeed/moe/mappings.py +++ b/deepspeed/moe/mappings.py @@ -23,6 +23,8 @@ import torch import deepspeed +from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank, + bwc_tensor_model_parallel_group) def _gather_tokens(input_, dim=0): @@ -30,15 +32,23 @@ def _gather_tokens(input_, dim=0): mpu = deepspeed.utils.groups.mpu input_ = input_.contiguous() - # Size and dimension. - rank = mpu.get_tensor_model_parallel_rank() - - tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())] - tensor_list[rank] = input_ - deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group()) + world_size = bwc_tensor_model_parallel_world_size(mpu) + if world_size == 1: + return input_ - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=dim).contiguous() + gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device) + deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu)) + if dim == 0: + shape = list(input_.size()) + shape[0] = shape[0] * world_size + output = gather_buffer.view(shape) + else: + tensor_list = [ + gather_buffer.narrow(0, + input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size) + ] + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim).contiguous() return output @@ -47,8 +57,10 @@ def _drop_tokens(input_, dim=0): """Divide a tensor among the tensor parallel ranks""" mpu = deepspeed.utils.groups.mpu - total_chunks = mpu.get_tensor_model_parallel_world_size() - this_chunk = mpu.get_tensor_model_parallel_rank() + total_chunks = bwc_tensor_model_parallel_world_size(mpu) + if total_chunks == 1: + return input_ + this_chunk = bwc_tensor_model_parallel_rank(mpu) assert input_.shape[ dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})" chunk_size = input_.shape[dim] // total_chunks @@ -92,7 +104,7 @@ def backward(ctx, input_): def gather_tokens(input_, dim=0): mpu = deepspeed.utils.groups.mpu - if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1: + if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1: # no tensor parallelism for non-experts return input_ return _GatherTokens.apply(input_, dim) @@ -100,7 +112,7 @@ def gather_tokens(input_, dim=0): def drop_tokens(input_, dim=0): mpu = deepspeed.utils.groups.mpu - if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1: + if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1: # no tensor parallelism for non-experts return input_ return _DropTokens.apply(input_, dim) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index 93dff21ea702..d2a6c089e8e7 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -17,7 +17,9 @@ from deepspeed.utils.timer import SynchronizedWallClockTimer from deepspeed.utils import logger -from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple +from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size +from deepspeed.utils.torch import jit_script_compat +from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union import torch from torch import Tensor @@ -31,16 +33,21 @@ else: Base = Module +TOPK_GATE_TIMER = 'topk_gate' +MOE_TIMER = 'moe' +FIRST_ALLTOALL_TIMER = '1st_a2a' +SECOND_ALLTOALL_TIMER = '2nd_a2a' + uniform_map: Dict[torch.device, Callable] = {} gumbel_map: Dict[torch.device, Callable] = {} exp_selection_uniform_map: Dict[torch.device, Callable] = {} try: # To enable Tutel MoE optimizations: - # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x + # python3 -m pip install --user --upgrade git+https://github.com/deepspeedai/tutel@v0.1.x from tutel import moe as tutel_moe TUTEL_INSTALLED = True -except: +except Exception: # Fail silently so we don't spam logs unnecessarily if user isn't using tutel TUTEL_INSTALLED = False pass @@ -90,11 +97,7 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: class _AllToAll(torch.autograd.Function): @staticmethod - def forward( - ctx: Any, - # TODO: replace with DS process group - group: torch.distributed.ProcessGroup, - input: Tensor) -> Tensor: # type: ignore + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore ctx.group = group input = input.contiguous() output = torch.empty_like(input) @@ -122,6 +125,8 @@ def einsum(rule, a, b): return a.unsqueeze(2) * b.unsqueeze(1) elif rule == 'se,se->s': return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == 'se,sec->sec': + return a.unsqueeze(2) * b elif rule == 'sec,sm->ecm': s = a.shape[0] e = a.shape[1] @@ -153,7 +158,7 @@ def einsum(rule, a, b): # includes stateful caching logic which is incompatible with ONNX. -@torch.jit.script +@jit_script_compat def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor: # gates has shape of SE num_tokens = gates.shape[0] @@ -166,12 +171,12 @@ def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> T return capacity -@torch.jit.script +@jit_script_compat def _top_idx(source, k): return torch.topk(source, k=k, dim=0)[1] -@torch.jit.script +@jit_script_compat def _one_hot_to_float(x, num_classes): return F.one_hot(x, num_classes=num_classes).float() @@ -183,13 +188,14 @@ def top1gating(logits: Tensor, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top1Gating on logits.""" if noisy_gate_policy == 'RSample': logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) # everything is in fp32 in this function - gates = F.softmax(logits, dim=1) + gates = F.softmax(logits, dim=1) capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity)) # Create a mask for 1st's expert per token @@ -203,13 +209,21 @@ def top1gating(logits: Tensor, mask1 = einsum("s,se->se", used_token, mask1) # gating decisions - exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device) # if we don't want to drop any tokens if not drop_tokens: new_capacity = torch.max(exp_counts).to(logits.device) - dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group()) - capacity = new_capacity + # Communicate across expert processes to pick the maximum capacity. + if ep_group is not None and dist.get_world_size(group=ep_group) > 1: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) + if groups._get_expert_model_parallel_world_size() == 1: + # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. + # This is since we are going to activate drop_tokens() to drop duplicate tokens. + tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) + new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) + # Make sure the capacity value does not exceed the number of tokens. + capacity = min(new_capacity, torch.tensor(mask1.size(0)).to(new_capacity.device)) # Compute l_aux me = torch.mean(gates, dim=0) @@ -274,23 +288,28 @@ def top1gating(logits: Tensor, return l_aux, combine_weights, dispatch_mask, exp_counts -def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +def top2gating(logits: Tensor, + capacity_factor: float, + min_capacity: int, + drop_tokens: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, + top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top2Gating on logits.""" # everything is in fp32 in this function gates = F.softmax(logits, dim=1) - capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity)) - # Create a mask for 1st's expert per token indices1_s = torch.argmax(gates, dim=1) num_experts = int(gates.shape[1]) mask1 = F.one_hot(indices1_s, num_classes=num_experts) - # Create a mask for 2nd's expert per token using Gumbel-max trick - # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ - logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + if top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += gumbel_rsample(logits.shape, device=logits.device) + # Replace top-expert with min value - logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf")) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) indices2_s = torch.argmax(logits_except1, dim=1) mask2 = F.one_hot(indices2_s, num_classes=num_experts) @@ -300,17 +319,30 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup # Update 2nd's location by accounting for locations of 1st locations2 += torch.sum(mask1, dim=0, keepdim=True) - # gating decisions - exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') - # Compute l_aux me = torch.mean(gates, dim=0) ce = torch.mean(mask1.float(), dim=0) l_aux = torch.mean(me * ce) * num_experts * num_experts - # Remove locations outside capacity from mask - mask1 *= torch.lt(locations1, capacity) - mask2 *= torch.lt(locations2, capacity) + # gating decisions + exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) + + if drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity)) + mask1 *= torch.lt(locations1, capacity) + mask2 *= torch.lt(locations2, capacity) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = torch.max(exp_counts) + if ep_group is not None and dist.get_world_size(group=ep_group) > 1: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) + if groups._get_expert_model_parallel_world_size() == 1: + # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. + # This is since we are going to activate drop_tokens() to drop duplicate tokens. + tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) + new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) + capacity = new_capacity # Store the capacity location for each token locations1_s = torch.sum(locations1 * mask1, dim=1) @@ -340,6 +372,83 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup return l_aux, combine_weights, dispatch_mask, exp_counts +def topkgating( + logits: Tensor, + k: int, + capacity_factor: float, + min_capacity: int, + drop_tokens: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, + drop_policy: str = "probs", +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements TopKGating on logits.""" + + # everything is in fp32 in this function + # gating decisions + gates = F.softmax(logits, dim=1) + num_experts = int(gates.shape[1]) + + # get topk gates + top_gate, top_idx = torch.topk(gates, k=k, dim=1) + + mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1) + + exp_counts = torch.sum(mask, dim=0).detach().to(logits.device) + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = torch.mean(me * ce) * num_experts * num_experts / k + locations = None + if drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity)) + # update mask and locations by capacity + + if drop_policy == 'probs': + topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate) + _, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False) + capacity_mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(0, capacity_indices, True) + mask &= capacity_mask + locations = torch.cumsum(mask, dim=0) - 1 + + elif drop_policy == "position": + locations = torch.cumsum(mask, dim=0) - 1 + mask *= torch.lt(locations, capacity) + else: + raise ValueError(f"Invalid drop_policy: {drop_policy}") + + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = torch.max(exp_counts) + if ep_group is not None and dist.get_world_size(group=ep_group) > 1: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) + if groups._get_expert_model_parallel_world_size() == 1: + # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. + # This is since we are going to activate drop_tokens() to drop duplicate tokens. + tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) + new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) + capacity = new_capacity + locations = torch.cumsum(mask, dim=0) - 1 + + # normalize gates + gates_masked = gates * mask + gates_s = torch.sum(gates_masked, dim=-1, keepdim=True) + denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps) + gates_masked = gates_masked / denom_s + + if locations is None: + raise ValueError(f"Locations is not set: {locations}") + # dispatch_mask + locations_sc = _one_hot_to_float((locations * mask), capacity) + + combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc) + + dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts + + class TopKGate(Module): """Gate module which implements Top2Gating as described in Gshard_. :: @@ -352,7 +461,7 @@ class TopKGate(Module): Args: model_dim (int): size of model embedding dimension - num_experts (ints): + num_experts (int): number of experts in model """ @@ -367,13 +476,13 @@ def __init__(self, min_capacity: int = 8, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, - use_rts: bool = True) -> None: + use_rts: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, + top2_2nd_expert_sampling: bool = True) -> None: super().__init__() - # Only top-1 and top-2 are supported at the moment. - if k != 1 and k != 2: - raise ValueError('Only top-1 and top-2 gatings are supported.') - self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() + self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) + self.ep_group = ep_group self.k = k self.capacity_factor = capacity_factor self.eval_capacity_factor = eval_capacity_factor @@ -384,6 +493,11 @@ def __init__(self, self.gate_time = 0.0 self.drop_tokens = drop_tokens self.use_rts = use_rts + self.top2_2nd_expert_sampling = top2_2nd_expert_sampling + + def _set_ep_group(self, ep_group): + assert self.ep_group is None, 'Attempting to override an existing ep_group' + self.ep_group = ep_group def forward(self, input: torch.Tensor, @@ -391,28 +505,30 @@ def forward(self, use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore if self.wall_clock_breakdown: - self.timers('TopKGate').start() + self.timers(TOPK_GATE_TIMER).start() - if self.wg.weight.dtype != torch.float32: - self.wg = self.wg.float() input_fp32 = input.float() # input jittering if self.noisy_gate_policy == 'Jitter' and self.training: input_fp32 = multiplicative_jitter(input_fp32, device=input.device) - logits = self.wg(input_fp32) + logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None) if self.k == 1: gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, - self.drop_tokens, self.use_rts, use_tutel) + self.drop_tokens, self.use_rts, self.ep_group, use_tutel) - else: + elif self.k == 2: gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity) + self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling) + else: + gate_output = topkgating(logits, self.k, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, self.drop_tokens, self.ep_group) if self.wall_clock_breakdown: - self.timers('TopKGate').stop() - self.gate_time = self.timers('TopKGate').elapsed(reset=False) + self.timers(TOPK_GATE_TIMER).stop() + self.gate_time = self.timers(TOPK_GATE_TIMER).elapsed(reset=False) return gate_output @@ -468,11 +584,12 @@ def __init__(self, def _set_ep_group(self, ep_group): self.ep_group = ep_group + self.gate._set_ep_group(ep_group) def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: if self.wall_clock_breakdown: - self.timers('moe').start() + self.timers(MOE_TIMER).start() # Implement Algorithm 2 from GShard paper. d_model = input[0].shape[-1] @@ -495,41 +612,61 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) if self.wall_clock_breakdown: - self.timers('falltoall').start() + self.timers(FIRST_ALLTOALL_TIMER).start() - if groups._get_expert_model_parallel_world_size() == 1: - # If the non-expert is tensor-parallel, it will create + tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu) + if tensor_model_world_size > 1: + # If the non-expert is tensor-parallel, + # Whether expert is tensor-parallel or not , it will create # duplicate tokens on the tensor-parallel ranks. - # Since our experts are not tensor-parallel, these duplicates - # need to be dropped to ensure correctness. - # this also doubles up as a communication optimization as we are - # reducing the all-to-all communication volume. + # drop duplicate tokens also doubles up as a communication + # optimization as we are reducing the all-to-all communication volume. + # 1: for not tensor-parallel expert,drop duplicate tokens to ensure + # both correctness and reduce all-to-all communication. + # 2: for tensor-parallel expert,drop duplicate tokens to reduce all-to-all + # communication volume,before expert execution, it is necessary to perform + # an allgather to ensure correctness, dispatched_input = drop_tokens(dispatched_input, dim=1) - dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) + if self.ep_size > 1: + dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) + else: + dispatched_input = dispatched_input.contiguous() if self.wall_clock_breakdown: - self.timers('falltoall').stop() - self.time_falltoall = self.timers('falltoall').elapsed(reset=False) + self.timers(FIRST_ALLTOALL_TIMER).stop() + self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False) + + if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1: + # if both expert and non-expert are tensor-parallel + # the dropped duplicate tokens need to be gathered on each + # tensor parallel rank again to ensure correctness + dispatched_input = gather_tokens(dispatched_input, dim=1) # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model) - expert_output = self.experts(dispatched_input) + # Re-shape before drop_tokens: gecm -> ecm + expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model) + if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1: + # if both expert and non-expert are tensor-parallel + # drop duplicate tokens to ensure both correctness + # and reduce all-to-all communication. + expert_output = drop_tokens(expert_output, dim=1) if self.wall_clock_breakdown: - self.timers('salltoall').start() + self.timers(SECOND_ALLTOALL_TIMER).start() - expert_output = _AllToAll.apply(self.ep_group, expert_output) + if self.ep_size > 1: + expert_output = _AllToAll.apply(self.ep_group, expert_output) + else: + expert_output = expert_output.contiguous() if self.wall_clock_breakdown: - self.timers('salltoall').stop() - self.time_salltoall = self.timers('salltoall').elapsed(reset=False) + self.timers(SECOND_ALLTOALL_TIMER).stop() + self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False) - # Re-shape back: gecm -> ecm - expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model) - - if groups._get_expert_model_parallel_world_size() == 1: + if tensor_model_world_size > 1: # the dropped duplicate tokens need to be gathered on each # tensor parallel rank again for the tensor-parallel # non-expert of the next layer. @@ -543,7 +680,7 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: a = combined_output.reshape(input[0].shape) if self.wall_clock_breakdown: - self.timers('moe').stop() - self.time_moe = self.timers('moe').elapsed(reset=False) + self.timers(MOE_TIMER).stop() + self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False) return a diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py index da31f550aabc..20866378efac 100644 --- a/deepspeed/moe/utils.py +++ b/deepspeed/moe/utils.py @@ -3,16 +3,20 @@ # DeepSpeed Team -from typing import List, Tuple, Dict +from collections import defaultdict +from typing import Any, Dict, List, Set, Tuple, Union, cast + import torch +from torch import nn + from .layer import MoE -def has_moe_layers(m): +def has_moe_layers(m: nn.Module) -> Tuple[bool, int]: has_moe = False num_experts = 0 - for _, module in m.named_modules(): + for module in m.modules(): if isinstance(module, MoE): has_moe = True num_experts = module.num_experts @@ -27,8 +31,10 @@ def is_moe_param(param: torch.Tensor) -> bool: def split_params_into_shared_and_expert_params( - params: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: - shared_params, expert_params = [], [] + params: List[torch.nn.Parameter]) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]: + shared_params: List[nn.Parameter] = [] + expert_params: List[nn.Parameter] = [] + for p in params: if is_moe_param(p): expert_params.append(p) @@ -38,7 +44,7 @@ def split_params_into_shared_and_expert_params( def split_params_grads_into_shared_and_expert_params( - group: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: + group: List[torch.nn.Parameter]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Split grad of parameters into grads of non-expert params and grads of expert params. This is useful while computing grad-norms for clipping and overflow detection @@ -48,11 +54,12 @@ def split_params_grads_into_shared_and_expert_params( The group of parameters to split Returns: - Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]: + Tuple[List[torch.Tensor], List[torch.Tensor]]: list of gradients for non MoE params, list of gradients of MoE params """ - expert_grads = [] - shared_grads = [] + expert_grads: List[torch.Tensor] = [] + shared_grads: List[torch.Tensor] = [] + for p in group: if p.grad is not None: if is_moe_param(p): @@ -62,16 +69,17 @@ def split_params_grads_into_shared_and_expert_params( return shared_grads, expert_grads -def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], - max_group_size=178956971) -> Tuple[Dict]: +def split_params_into_different_moe_groups_for_optimizer( + param_groups: Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]], + max_group_size: Union[int, float] = 178956971) -> List[Dict[str, Any]]: """Split parameters into different MoE groups for optimizer Args: - param_groups (Tuple[Dict]): + param_groups (Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]]) The list of parameter groups to split Returns: - Tuple[Dict]: + List[Dict[str, Any]]: list of MoE/non-MoE groups for optimizer """ if isinstance(param_groups, tuple): @@ -82,45 +90,43 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic raise ValueError(f"Unknown param group type of {type(param_groups)}") # gather all data parallel group names - data_parallel_group_names = set() + data_parallel_group_names: Set[str] = set() for param_group in param_groups: - for param in param_group["params"]: + for param in cast(List[nn.Parameter], param_group["params"]): if is_moe_param(param): data_parallel_group_names.add(param.group_name) - data_parallel_group_names = list(data_parallel_group_names) - group_moe = {} + # Create the param MoE groups, leave param assign to next step + group_moe: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) for param_group in param_groups: - group_moe[param_group['name']] = {} for key in data_parallel_group_names: - group_moe[param_group['name']][key] = {} - group_moe[param_group['name']][key]['name'] = key - group_moe[param_group['name']][key]['moe'] = True - for ori_key in param_group.keys(): - if ori_key != 'name': - if ori_key == 'params': - group_moe[param_group['name']][key][ori_key] = [] - else: - group_moe[param_group['name']][key][ori_key] = param_group[ori_key] + group_moe[param_group['name']][key] = { + **param_group, + 'name': key, + 'moe': True, + 'params': [], + } + # Assign param for param_group in param_groups: - new_params = [] - for param in param_group['params']: + new_params: List[nn.Parameter] = [] + + for param in cast(List[nn.Parameter], param_group['params']): if is_moe_param(param): group_moe[param_group['name']][param.group_name]['params'].append(param) - # param_group['params'].remove(param) else: new_params.append(param) param_group['params'] = new_params # Flatten the moe groups if max_group_size is not None: - for k, v in group_moe.items(): - for k1, v1 in v.items(): - cur_group = [] - all_groups = [] + for moe_group in group_moe.values(): + for param_group in moe_group.values(): + cur_group: List[nn.Parameter] = [] + all_groups: List[List[nn.Parameter]] = [] size_of_cur_group = 0 - for param in v1['params']: + + for param in cast(List[nn.Parameter], param_group['params']): if size_of_cur_group + param.numel() <= max_group_size: cur_group.append(param) size_of_cur_group += param.numel() @@ -128,18 +134,49 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic all_groups.append(cur_group) cur_group = [param] size_of_cur_group = param.numel() + if cur_group: all_groups.append(cur_group) + for group in all_groups: - new_dict = {} - for key, val in v1.items(): - if key != 'params': - new_dict[key] = val - new_dict['params'] = group - param_groups.append(new_dict) + param_groups.append({**param_group, 'params': group}) else: - for k, v in group_moe.items(): - for k1, v1 in v.items(): - param_groups.append(v1) + for moe_group in group_moe.values(): + for param_group in moe_group.values(): + param_groups.append(param_group) + + return param_groups + + +def is_moe_param_group(param_group): + return param_group.get('moe', False) + - return tuple(param_groups) +def configure_moe_param_groups(model_parameters: List): + assert isinstance(model_parameters, list), "model_parameters must be a list" + + for p in model_parameters: + # match torch.optim.Optimizer expectations, + # see: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/torch/optim/optimizer.py#L270-L272 + if not isinstance(p, (torch.Tensor, dict)): + raise TypeError("param argument that would be given to the optimizer should be " + f"an iterable of Tensors or dicts, but got {type(p)}") + + # peak at the first element to determine how to proceed + first = model_parameters[0] + + # Case 1: model_parameters is a list of torch.nn.Parameter + # -> need to create moe compatible param groups + if isinstance(first, torch.nn.Parameter): + param_group = {'params': model_parameters, 'name': 'dense-params'} + return split_params_into_different_moe_groups_for_optimizer(param_group) + + # Case 2: model_parameters is a list of param groups List[dict] + # -> moe compatible param groups might already exist, if not create them + elif isinstance(first, dict): + #there are no moe groups created + if not any(['moe' in param_group for param_group in model_parameters]): + return split_params_into_different_moe_groups_for_optimizer(model_parameters) + else: + # moe groups exist, nothing to do + return model_parameters diff --git a/deepspeed/monitor/comet.py b/deepspeed/monitor/comet.py new file mode 100644 index 000000000000..d8bc4017800f --- /dev/null +++ b/deepspeed/monitor/comet.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional + +from .utils import check_comet_availability +from .monitor import Monitor + +import deepspeed.comm as dist + +if TYPE_CHECKING: + import comet_ml + from .config import CometConfig + +Name = str +Value = Any +GlobalSamples = int +Event = Tuple[Name, Value, GlobalSamples] + + +class CometMonitor(Monitor): + + def __init__(self, comet_config: "CometConfig"): + super().__init__(comet_config) + check_comet_availability() + import comet_ml + + self.enabled = comet_config.enabled + self._samples_log_interval = comet_config.samples_log_interval + self._experiment: Optional["comet_ml.ExperimentBase"] = None + + if self.enabled and dist.get_rank() == 0: + self._experiment = comet_ml.start( + api_key=comet_config.api_key, + project=comet_config.project, + workspace=comet_config.workspace, + experiment_key=comet_config.experiment_key, + mode=comet_config.mode, + online=comet_config.online, + ) + + if comet_config.experiment_name is not None: + self._experiment.set_name(comet_config.experiment_name) + + self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval) + + @property + def experiment(self) -> Optional["comet_ml.ExperimentBase"]: + return self._experiment + + @property + def samples_log_interval(self) -> int: + return self._samples_log_interval + + def write_events(self, event_list: List[Event]) -> None: + if not self.enabled or dist.get_rank() != 0: + return None + + for event in event_list: + name = event[0] + value = event[1] + engine_global_samples = event[2] + + if self._events_log_scheduler.needs_logging(name, engine_global_samples): + self._experiment.__internal_api__log_metric__( + name=name, + value=value, + step=engine_global_samples, + ) + + +class EventsLogScheduler: + + def __init__(self, samples_log_interval: int): + self._samples_log_interval = samples_log_interval + self._last_logged_events_samples: Dict[str, int] = {} + + def needs_logging(self, name: str, current_sample: int) -> bool: + if name not in self._last_logged_events_samples: + self._last_logged_events_samples[name] = current_sample + return True + + last_logged_sample = self._last_logged_events_samples[name] + samples_delta = current_sample - last_logged_sample + + if samples_delta >= self._samples_log_interval: + self._last_logged_events_samples[name] = current_sample + return True + + return False diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py index 0cd02603bd35..960ce1ba997a 100644 --- a/deepspeed/monitor/config.py +++ b/deepspeed/monitor/config.py @@ -3,12 +3,14 @@ # DeepSpeed Team -from pydantic import root_validator +from typing import Optional + +from pydantic import model_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel def get_monitor_config(param_dict): - monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")} + monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor", "comet")} return DeepSpeedMonitorConfig(**monitor_dict) @@ -34,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel): enabled: bool = False """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """ - group: str = None + group: Optional[str] = None """ Name for the WandB group. This can be used to group together runs. """ - team: str = None + team: Optional[str] = None """ Name for the WandB team. """ project: str = "deepspeed" @@ -60,21 +62,83 @@ class CSVConfig(DeepSpeedConfigModel): """ Name for the current job. This will become a new directory inside `output_path`. """ +class CometConfig(DeepSpeedConfigModel): + """ + Sets parameters for Comet monitor. For logging data Comet uses + experiment object. + https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/ + """ + + enabled: bool = False + """ Whether logging to Comet is enabled. Requires `comet_ml` package is installed. """ + + samples_log_interval: int = 100 + """ Metrics will be submitted to Comet after processing every `samples_log_intervas` samples""" + + project: Optional[str] = None + """ + Comet project name. Can be set through .comet.config file or environment variable COMET_PROJECT_NAME + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + workspace: Optional[str] = None + """ + Comet workspace name. Can be set through .comet.config file or environment variable COMET_WORKSPACE + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + api_key: Optional[str] = None + """ + Comet API key. Can be set through .comet.config file or environment variable COMET_API_KEY + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + experiment_name: Optional[str] = None + """ + The name for comet experiment to be used for logging. + Can be set through .comet.config file or environment variable COMET_EXPERIMENT_NAME + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + experiment_key: Optional[str] = None + """ + The key for comet experiment to be used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters. + Can be set through .comet.config or environment variable COMET_EXPERIMENT_KEY + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + online: Optional[bool] = None + """ + If True, the data will be logged to Comet server, otherwise it will be stored locally in offline experiment + Defaults to True. + """ + + mode: Optional[str] = None + """ + Control how the Comet experiment is started, 3 options are possible.: + - "get": Continue logging to an existing experiment identified by the `experiment_key` value. + - "create": Always creates of a new experiment, useful for HPO sweeps. + - "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one. + """ + + class DeepSpeedMonitorConfig(DeepSpeedConfigModel): """Sets parameters for various monitoring methods.""" tensorboard: TensorBoardConfig = {} """ TensorBoard monitor, requires `tensorboard` package is installed. """ + comet: CometConfig = {} + """ Comet monitor, requires `comet_ml` package is installed """ + wandb: WandbConfig = {} """ WandB monitor, requires `wandb` package is installed. """ csv_monitor: CSVConfig = {} """ Local CSV output of monitoring data. """ - @root_validator - def check_enabled(cls, values): - values["enabled"] = False - if (values.get("tensorboard").enabled or values.get("wandb").enabled or values.get("csv_monitor").enabled): - values["enabled"] = True - return values + @model_validator(mode="after") + def check_enabled(self): + enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled + self.__dict__["enabled"] = enabled + return self diff --git a/deepspeed/monitor/monitor.py b/deepspeed/monitor/monitor.py index 5a32b8bbcadd..e7e26dc483d9 100644 --- a/deepspeed/monitor/monitor.py +++ b/deepspeed/monitor/monitor.py @@ -24,6 +24,7 @@ def write_events(self, event_list): from .wandb import WandbMonitor from .tensorboard import TensorBoardMonitor from .csv_monitor import csvMonitor +from .comet import CometMonitor class MonitorMaster(Monitor): @@ -33,6 +34,7 @@ def __init__(self, monitor_config): self.tb_monitor = None self.wandb_monitor = None self.csv_monitor = None + self.comet_monitor = None self.enabled = monitor_config.enabled if dist.get_rank() == 0: @@ -42,6 +44,8 @@ def __init__(self, monitor_config): self.wandb_monitor = WandbMonitor(monitor_config.wandb) if monitor_config.csv_monitor.enabled: self.csv_monitor = csvMonitor(monitor_config.csv_monitor) + if monitor_config.comet.enabled: + self.comet_monitor = CometMonitor(monitor_config.comet) def write_events(self, event_list): if dist.get_rank() == 0: @@ -51,3 +55,5 @@ def write_events(self, event_list): self.wandb_monitor.write_events(event_list) if self.csv_monitor is not None: self.csv_monitor.write_events(event_list) + if self.comet_monitor is not None: + self.comet_monitor.write_events(event_list) diff --git a/deepspeed/monitor/utils.py b/deepspeed/monitor/utils.py index a9bd915f43f3..f5530e8532e1 100644 --- a/deepspeed/monitor/utils.py +++ b/deepspeed/monitor/utils.py @@ -3,12 +3,14 @@ # DeepSpeed Team +from packaging import version as pkg_version + def check_tb_availability(): try: # torch.utils.tensorboard will fail if `tensorboard` is not available, # see their docs for more details: https://pytorch.org/docs/1.8.0/tensorboard.html - import tensorboard # noqa: F401 + import tensorboard # noqa: F401 # type: ignore except ImportError: print('If you want to use tensorboard logging, please `pip install tensorboard`') raise @@ -16,9 +18,20 @@ def check_tb_availability(): def check_wandb_availability(): try: - import wandb # noqa: F401 + import wandb # noqa: F401 # type: ignore except ImportError: print( 'If you want to use wandb logging, please `pip install wandb` and follow the instructions at https://docs.wandb.ai/quickstart' ) raise + + +def check_comet_availability(): + try: + import comet_ml + comet_version = pkg_version.parse(comet_ml.__version__) + if comet_version < pkg_version.Version("3.41.0"): + raise ImportError("`comet_ml` must have at least version 3.41.0") + except ImportError: + print('If you want to use comet logging, please `pip install "comet_ml>=3.41.0"`') + raise diff --git a/deepspeed/monitor/wandb.py b/deepspeed/monitor/wandb.py index 30209191171a..174a2eb2d3b7 100644 --- a/deepspeed/monitor/wandb.py +++ b/deepspeed/monitor/wandb.py @@ -24,10 +24,10 @@ def __init__(self, wandb_config): if self.enabled and dist.get_rank() == 0: wandb.init(project=self.project, group=self.group, entity=self.team) - def log(self, data, step=None, commit=None, sync=None): + def log(self, data, step=None, commit=None): if self.enabled and dist.get_rank() == 0: import wandb - return wandb.log(data, step=step, commit=commit, sync=sync) + return wandb.log(data, step=step, commit=commit) def write_events(self, event_list): if self.enabled and dist.get_rank() == 0: diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py index dcc23681bbab..9fa5769b5597 100644 --- a/deepspeed/nebula/constants.py +++ b/deepspeed/nebula/constants.py @@ -29,8 +29,8 @@ # There is a case where customer want to load the checkpoint saved # by raw torch. Because nebula cannot load torch checkpoint directly # as they have different folder structures to bring the gap for -# loading(the data are totaly same in bytes for torch and enbula s -# aving). +# loading(the data are totally same in bytes for torch and nebula +# saving). # In this case, we must disable nebula load to use raw torch load. # Customer can just set NEBULA_ENABLE_NEBULA_LOAD to False. Then use # original way of deepspeed to load, i.e. set the value of "--load". @@ -60,7 +60,7 @@ NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention" NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2 -# Neubla envs +# Nebula envs NEBULA_EXPORT_ENVS = [ 'DLTS_JOB_ID', 'DLTS_NUM_WORKER', 'NEBULA_PERSISTENT_STORAGE_PATH', 'NEBULA_PERSISTENT_TIME_INTERVAL', 'AML_RUN_ID', 'AZUREML_RUN_TOKEN', 'AZUREML_WORKSPACE_SCOPE', 'AZUREML_EXPERIMENT_SCOPE', diff --git a/deepspeed/nvme/__init__.py b/deepspeed/nvme/__init__.py new file mode 100644 index 000000000000..6d0de857cbd3 --- /dev/null +++ b/deepspeed/nvme/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .perf_run_sweep import sweep_main, parse_sweep_arguments +from .perf_generate_param import generate_main +from .test_ds_aio import ds_io_main diff --git a/deepspeed/nvme/ds_aio_args.py b/deepspeed/nvme/ds_aio_args.py new file mode 100644 index 000000000000..210f21b7c4d6 --- /dev/null +++ b/deepspeed/nvme/ds_aio_args.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import argparse +import os +from .test_ds_aio_utils import refine_integer_value +from .ds_aio_constants import AIO_HANDLE, AIO_BASIC, TORCH_FAST_IO, TORCH_IO, VALID_ENGINES +from deepspeed.accelerator import get_accelerator + +MAPPING_DELIMITER = ':' + + +def refine_args(args): + if args.io_size and type(args.io_size) == str: + args.io_size = refine_integer_value(args.io_size) + + if args.block_size and type(args.block_size) == str: + args.block_size = refine_integer_value(args.block_size) + + if args.fast_io_size and type(args.fast_io_size) == str: + args.fast_io_size = refine_integer_value(args.fast_io_size) + + return args + + +def _get_mapping_dict(args): + if args.folder is not None: + d = {i: args.folder for i in range(args.multi_process)} + else: + d = {} + for m in args.folder_to_device_mapping: + fields = m.split(MAPPING_DELIMITER) + d[fields[1]] = fields[0] + + return d + + +def _validate_folder_mapping(args): + no_error = True + error_messages = [] + invalid_mappings = [m for m in args.folder_to_device_mapping if MAPPING_DELIMITER not in m] + if len(invalid_mappings) > 0: + error_messages.append( + f'Missing delimiter ({MAPPING_DELIMITER}) in folder_to_device_mapping {invalid_mappings}') + no_error = False + + folder_list = [m.split(MAPPING_DELIMITER)[0] for m in args.folder_to_device_mapping] + invalid_folders = [d for d in folder_list if not os.path.exists(d)] + if len(invalid_folders) > 0: + error_messages.append(f'Invalid folders in folder_to_device_mapping: {invalid_folders}') + no_error = False + + if args.gpu: + device_list = [int(m.split(MAPPING_DELIMITER)[1]) for m in args.folder_to_device_mapping] + invalid_device_list = [dev_id for dev_id in device_list if not dev_id < get_accelerator().device_count()] + if len(invalid_device_list) > 0: + error_messages.append(f'Invalid device ids in folder_to_device_mapping: {invalid_device_list}') + no_error = False + + return no_error, error_messages + + +def validate_args(args): + no_error = True + error_messages = [] + + if args.folder is not None and len(args.folder_to_device_mapping) > 0: + error_messages.append('--folder and --folder_to_device_mapping cannot be specified together.') + no_error = False + elif args.folder is None and len(args.folder_to_device_mapping) == 0: + error_messages.append('At least one of --folder or --folder_to_device_mapping must be specified.') + no_error = False + + # Validate --folder + if args.folder is not None and not os.path.exists(args.folder): + no_error = False + error_messages.append(f'Invalid folder in --folder: {args.folder} ') + + # Validate --folder_mapping_to_device + if len(args.folder_to_device_mapping) > 0: + no_mapping_error, mapping_error_messages = _validate_folder_mapping(args) + no_error = no_error and no_mapping_error + error_messages += mapping_error_messages + + # Validate --engine + if args.engine not in VALID_ENGINES: + no_error = False + error_messages.append(f'Invalid engine {args.engine}. Valid options = {VALID_ENGINES}') + + # Validate --engine=torch_io + if args.engine == TORCH_IO: + if args.read: + no_error = False + error_messages.append(f'Read not currently supported for --engine={TORCH_IO}') + + if not no_error: + print(f'Found {len(error_messages)} validation error(s)') + # Validate --gpu, --use_gds + if args.use_gds and not args.gpu: + error_messages.append('--gpu must be set to transfer with --use_gds') + no_error = False + + if not no_error: + print(f'Found {len(error_messages)} validation errors') + for i, msg in enumerate(error_messages): + print(f'{i+1}: {msg}') + + return no_error + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--folder', default=None, type=str, help='Folder to use for I/O.') + + parser.add_argument('--folder_to_device_mapping', + default=[], + nargs='+', + help='Specification of mapping of folder to (gpu) device id, (ignored for cpu accesses).' + 'Can be specified multiple times for multi-process runs,' + 'e.g. --folder_to_device_mapping /mnt/nvme0:0 --folder_to_device_mapping /mnt/nvme1:15 --gpu' + 'means access /mnt/nvme0 with gpu 0 and /mnt/nvme1 with gpu 15') + + parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.') + + parser.add_argument('--fast_io_size', type=str, default='64M', help='Size of fast_io pinned buffer (bytes).') + + parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)') + + parser.add_argument('--multi_process', + type=int, + default=1, + help='Number of parallel processes doing I/O (default 1).') + + parser.add_argument('--block_size', + type=str, + default='1M', + help='I/O block size. Can use K, M, or G suffix (default 1M for 1 megabytes).') + + parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth (default 32).') + + parser.add_argument('--single_submit', + action='store_true', + help='Submit I/O requests in singles (default is submit queue_depth amount at once.).') + + parser.add_argument( + '--sequential_requests', + action='store_true', + help= + 'Delay I/O request submission until completion of prior requests (default is overlap I/O submission and completion requests.).' + ) + + parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.') + + parser.add_argument( + '--engine', + type=str, + default=AIO_HANDLE, + help= + f'Engine to perform I/O. Options are [{AIO_HANDLE}, {AIO_BASIC}, {TORCH_IO}, {TORCH_FAST_IO}]. Default is aio_handle' + ) + + parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions') + + parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism') + + parser.add_argument('--gpu', action='store_true', help='Use GPU memory') + + parser.add_argument('--use_gds', action='store_true', help='Enable GDS AIO') + + parser.add_argument('--slow_bounce_buffer', + action='store_true', + help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.') + + parser.add_argument('--torch_legacy_save', action='store_true', help='Use torch legacy save approach') + + parser.add_argument('--use_accelerator_pin_memory', + action='store_true', + help='Obtain pinned (CPU page-locked) tensors from accelerator') + + parser.add_argument('--warmup_loops', type=int, default=1, help='Count of operation warmup repetitions') + + parser.add_argument('--include_warmup_time', action='store_true', help='Include warmup latency in results') + + parser.add_argument('--different_file_each_iteration', + action='store_true', + help='Read/write a different file on each iteration.') + + args = parser.parse_args() + print(f'args = {args}') + return args + + +def get_validated_args(): + args = parse_arguments() + args = refine_args(args) + if not validate_args(args): + quit() + print('Successful validation of command line arguments') + args.total_loops = args.warmup_loops + args.loops + peer_tag = 'gpu' if args.gpu else 'process' + args.mapping_dict = _get_mapping_dict(args) + args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()] + assert len(args.mapping_dict) == len(args.mapping_list) + print(f'Configuring {len(args.mapping_list)} {peer_tag} to folder mapping') + for i, (device_id, folder) in enumerate(args.mapping_list): + print(f'[{i}]: {peer_tag} {device_id} <----> {folder}') + + return args diff --git a/deepspeed/nvme/ds_aio_basic.py b/deepspeed/nvme/ds_aio_basic.py new file mode 100755 index 000000000000..a640d562691b --- /dev/null +++ b/deepspeed/nvme/ds_aio_basic.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +import time +from deepspeed.ops.aio import AsyncIOBuilder +from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from .ds_aio_constants import * + + +class AIOBasic_Engine(object): + + def __init__(self, args, tid, read_op): + self.ctxt = self._create_context(args, tid, read_op) + + def fini(self): + self.ctxt[BUFFER].detach() + self.ctxt[BUFFER] = None + + def read(self, args, tid, loop_id): + start_time = time.time() + AsyncIOBuilder().load().aio_read(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth, + args.single_submit, not args.sequential_requests, args.validate) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def write(self, args, tid, loop_id): + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(self.ctxt[FILE]): + os.remove(self.ctxt[FILE]) + + start_time = time.time() + AsyncIOBuilder().load().aio_write(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth, + args.single_submit, not args.sequential_requests, args.validate) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + buffer = create_page_locked_tensor(args.io_size, True) + + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}') + + task_log(tid, 'created deepspeed aio basic engine') + + ctxt = {} + ctxt[FILE] = filename + ctxt[NUM_BYTES] = args.io_size + ctxt[BUFFER] = buffer + ctxt[ELAPSED_SEC] = 0 + return ctxt diff --git a/deepspeed/nvme/ds_aio_constants.py b/deepspeed/nvme/ds_aio_constants.py new file mode 100644 index 000000000000..1b07ed8672ef --- /dev/null +++ b/deepspeed/nvme/ds_aio_constants.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +AIO_HANDLE = 'aio_handle' +AIO_BASIC = 'aio_basic' +TORCH_IO = 'torch_io' +TORCH_FAST_IO = 'torch_fastio' +VALID_ENGINES = [AIO_HANDLE, AIO_BASIC, TORCH_IO, TORCH_FAST_IO] + +BUFFER = 'buffer' +BOUNCE_BUFFER = 'bounce_buffer' +NUM_BYTES = 'num_bytes' +FILE = 'file' +HANDLE = 'handle' +ELAPSED_SEC = 'elapsed_sec' +FAST_IO_BUFFER = 'fast_io_buffer' +USE_CPU_LOCKED_TENSOR = 'cpu_locked_tensor' +USE_GDS = 'gds' diff --git a/deepspeed/nvme/ds_aio_handle.py b/deepspeed/nvme/ds_aio_handle.py new file mode 100755 index 000000000000..19edd04a71b5 --- /dev/null +++ b/deepspeed/nvme/ds_aio_handle.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import torch +import os +import time +from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.ops.op_builder import GDSBuilder +from deepspeed.accelerator import get_accelerator +from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from .ds_aio_constants import * + + +class AIOHandle_Engine(object): + + def __init__(self, args, tid, read_op): + self.ctxt = self._create_context(args, tid, read_op) + + def fini(self): + for buf in [BUFFER, BOUNCE_BUFFER]: + if self.ctxt[buf] is not None: + if self.ctxt[USE_CPU_LOCKED_TENSOR]: + self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf]) + + self.ctxt[buf].detach() + self.ctxt[buf] = None + + def read(self, args, tid, loop_id): + handle = self.ctxt[HANDLE] + + start_time = time.time() + dest_buffer = BOUNCE_BUFFER if self.ctxt[BOUNCE_BUFFER] is not None else BUFFER + ret = handle.pread(self.ctxt[dest_buffer], self.ctxt[FILE][loop_id], args.validate, True) + assert ret != -1 + handle.wait() + if dest_buffer == BOUNCE_BUFFER: + self.ctxt[BUFFER].data.copy_(self.ctxt[BOUNCE_BUFFER].data) + end_time = time.time() + self.ctxt[ELAPSED_SEC].append(end_time - start_time) + + def write(self, args, tid, loop_id): + # Avoid overwriting existing files as it could be artificially faster + # if os.path.isfile(self.ctxt[FILE]): + # os.remove(self.ctxt[FILE]) + + handle = self.ctxt[HANDLE] + start_time = time.time() + if self.ctxt[BOUNCE_BUFFER] is not None: + source_buffer = BOUNCE_BUFFER + self.ctxt[BOUNCE_BUFFER].data.copy_(self.ctxt[BUFFER].data) + else: + source_buffer = BUFFER + ret = handle.pwrite(self.ctxt[source_buffer], self.ctxt[FILE][loop_id], args.validate, True) + assert ret != -1 + handle.wait() + end_time = time.time() + self.ctxt[ELAPSED_SEC].append(end_time - start_time) + + def _create_files(self, args, folder, tid): + if args.different_file_each_iteration: + filenames = [ + create_filename(folder, args.read, args.io_size, f'{tid}_{l}') for l in range(args.total_loops) + ] + else: + filenames = [ + create_filename(folder, args.read, args.io_size, f'{tid}_{0}') for _ in range(args.total_loops) + ] + + if args.read: + for f in filenames: + if not (os.path.isfile(f) and os.path.getsize(f) == args.io_size): + create_file(f, args.io_size) + else: + for f in filenames: + if os.path.isfile(f): + os.remove(f) + + return filenames + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filenames = self._create_files(args, folder, tid) + + gds = True if args.use_gds else False + io_parallel = args.io_parallel if args.io_parallel else 1 + if gds: + handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + else: + handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + task_log(tid, 'Created DeepNVMe handle engine') + + bounce_buffer = None + if args.gpu: + device_name = get_accelerator().device_name(device_id) + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name) + if gds: + handle.pin_device_tensor(buffer) + elif not args.slow_bounce_buffer: + bounce_buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle) + else: + buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle) + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + ctxt = {} + ctxt[FILE] = filenames + ctxt[NUM_BYTES] = args.io_size + ctxt[HANDLE] = handle + ctxt[USE_GDS] = gds + ctxt[BUFFER] = buffer + ctxt[BOUNCE_BUFFER] = bounce_buffer + ctxt[ELAPSED_SEC] = [] + ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory + + task_log(tid, + f'{io_string} file {filenames} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) + + return ctxt diff --git a/deepspeed/nvme/ds_aio_job.py b/deepspeed/nvme/ds_aio_job.py new file mode 100644 index 000000000000..0f9c8b5f1bcc --- /dev/null +++ b/deepspeed/nvme/ds_aio_job.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping tensors to/from (NVMe) storage devices. +""" +import subprocess +import shlex + + +class Job(object): + + def __init__(self, cmd_line, output_file=None, work_dir=None): + self.cmd_line = cmd_line + self.output_file = output_file + self.work_dir = work_dir + self.output_fd = None + + def cmd(self): + return self.cmd_line + + def get_stdout(self): + return self.output_fd + + def get_stderr(self): + return self.output_fd + + def get_cwd(self): + return self.work_dir + + def open_output_file(self): + if self.output_file is not None: + self.output_fd = open(self.output_file, 'w') + + def close_output_file(self): + if self.output_fd is not None: + self.output_fd.close() + self.output_fd = None + + +def run_job(job, verbose=False): + args = shlex.split(' '.join(job.cmd())) + if verbose: + print(f'args = {args}') + job.open_output_file() + proc = subprocess.run(args=args, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd()) + job.close_output_file() + assert proc.returncode == 0, \ + f"This command failed: {job.cmd()}" diff --git a/deepspeed/nvme/io_engine.py b/deepspeed/nvme/io_engine.py new file mode 100644 index 000000000000..33a7c035c7aa --- /dev/null +++ b/deepspeed/nvme/io_engine.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +from multiprocessing import Pool, Barrier + +from .ds_aio_constants import AIO_BASIC, TORCH_FAST_IO, TORCH_IO +from .test_ds_aio_utils import report_results, task_log, task_barrier +from .ds_aio_handle import AIOHandle_Engine +from .ds_aio_basic import AIOBasic_Engine +from .torch_io import TorchIO_Engine +from .torch_fastio_engine import Torch_FastIO_Engine + + +def prepare_operation(args, tid, read_op): + if args.engine == TORCH_IO: + io_engine = TorchIO_Engine(args, tid, read_op) + elif args.engine == AIO_BASIC: + io_engine = AIOBasic_Engine(args, tid, read_op) + elif args.engine == TORCH_FAST_IO: + io_engine = Torch_FastIO_Engine(args, tid, read_op) + else: + io_engine = AIOHandle_Engine(args, tid, read_op) + + return io_engine + + +def prepare_read(pool_params): + args, tid = pool_params + return prepare_operation(args, tid, True) + + +def prepare_write(pool_params): + args, tid = pool_params + return prepare_operation(args, tid, False) + + +def post_operation(pool_params): + _, _, io_engine = pool_params + io_engine.fini() + + +def read_operation(pool_params): + args, tid, loop_id, io_engine = pool_params + return io_engine.read(args, tid, loop_id) + + +def write_operation(pool_params): + args, tid, loop_id, io_engine = pool_params + return io_engine.write(args, tid, loop_id) + + +def get_schedule(args, read_op): + schedule = {} + if read_op: + schedule['pre'] = prepare_read + schedule['post'] = post_operation + schedule['main'] = read_operation + else: + schedule['pre'] = prepare_write + schedule['post'] = post_operation + schedule['main'] = write_operation + + return schedule + + +def io_engine_tasklet(pool_params): + args, tid, read_op = pool_params + num_processes = len(args.mapping_dict) + + # Create schedule + schedule = get_schedule(args, read_op) + task_log(tid, f'schedule = {schedule}') + task_barrier(aio_barrier, num_processes) + + # Run pre task + task_log(tid, 'running pre-task') + io_engine = schedule["pre"]((args, tid)) + task_barrier(aio_barrier, num_processes) + + # Run main tasks in a loop + io_engine.ctxt["main_task_sec"] = [] + for i in range(args.total_loops): + task_log(tid, f'running main task {i}') + start_time = time.time() + schedule["main"]((args, tid, i, io_engine)) + task_barrier(aio_barrier, num_processes) + stop_time = time.time() + io_engine.ctxt["main_task_sec"].append(stop_time - start_time) + + # Run post task + task_log(tid, 'running post-task') + schedule["post"]((args, tid, io_engine)) + task_barrier(aio_barrier, num_processes) + + ctxt = io_engine.ctxt + # return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + if args.include_warmup_time: + e2e_latency_sec = sum(ctxt["main_task_sec"]) + task_latency_sec = sum(ctxt["elapsed_sec"]) + actual_loops = args.total_loops + else: + e2e_latency_sec = sum(ctxt["main_task_sec"][args.warmup_loops:]) + task_latency_sec = sum(ctxt["elapsed_sec"][args.warmup_loops:]) + actual_loops = args.loops + + l = ctxt["elapsed_sec"] + task_log(tid, f'task_latency_sec = {l}') + return e2e_latency_sec, task_latency_sec, ctxt["num_bytes"] * actual_loops + + +def _init_takslet(b): + global aio_barrier + aio_barrier = b + + +def io_engine_multiprocessing(args, read_op): + num_processes = len(args.mapping_dict) + b = Barrier(num_processes) + pool_params = [(args, p, read_op) for p in range(num_processes)] + with Pool(processes=num_processes, initializer=_init_takslet, initargs=(b, )) as p: + pool_results = p.map(io_engine_tasklet, pool_params) + + report_results(args, read_op, pool_results) diff --git a/deepspeed/nvme/parse_nvme_stats.py b/deepspeed/nvme/parse_nvme_stats.py new file mode 100755 index 000000000000..44c955a857b8 --- /dev/null +++ b/deepspeed/nvme/parse_nvme_stats.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +import argparse + +READ_SPEED = 'read_speed' +WRITE_SPEED = 'write_speed' + +PERF_METRICS = [READ_SPEED, WRITE_SPEED] + +METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--log_dir', type=str, required=True, help='Folder of statistics logs') + + parser.add_argument('--metric', + type=str, + required=True, + help='Performance metric to report: [read_speed|write_speed]') + + args = parser.parse_args() + print(f'args = {args}') + + return args + + +def extract_value(key, file): + INVALID_PREFIXES = ["ds"] + for p in INVALID_PREFIXES: + if key.startswith(p): + return key + try: + if key[0] in ['t', 'd', 'p']: + return int(key[1:]) + if key.startswith("bs"): + if key.endswith('K'): + v = key[2:].split('K') + return int(v[0]) * 1024 + elif key.endswith('M'): + v = key[2:].split('M') + return int(v[0]) * 1024 * 1024 + else: + return int(key[2:]) + except Exception: + print(f"{file}: extract_value fails on {key}") + return None + + return key + + +def get_file_key(file): + f, _ = os.path.splitext(os.path.basename(file)) + fields = f.split('_') + values = [extract_value(k, file) for k in fields] + return tuple(values) + + +def get_thread_count(file): + f, _ = os.path.splitext(os.path.basename(file)) + fields = f.split('_') + for key in fields: + if key[0] == 't': + return int(key[1:]) + return 1 + + +""" +Extract performance metric from log file. +Sample file lines are: +Task Read Latency = 0.031647682189941406 sec +Task Read Speed = 12.342926020792527 GB/sec +E2E Read Latency = 0.031697988510131836 sec +E2E Read Speed = 12.323337169333062 GB/sec + +For the above sample, -metric = "read_speed" corresponds to "E2E Read Speed", and 12.32 will be returned +""" + + +def get_metric(file, metric): + thread_count = get_thread_count(file) + with open(file) as f: + for line in f.readlines(): + if line.startswith(METRIC_SEARCH[metric]): + if metric in [READ_SPEED, WRITE_SPEED]: + fields = line.split() + return float(fields[-2]) + else: + fields = line.split('=') + return float(fields[-1]) + + return None + + +def validate_args(args): + if args.metric not in PERF_METRICS: + print(f'{args.metric} is not a valid performance metrics') + return False + + if not os.path.isdir(args.log_dir): + print(f'{args.log_dir} folder is not existent') + return False + + return True + + +def get_results(log_files, metric): + results = {} + for f in log_files: + file_key = get_file_key(f) + value = get_metric(f, metric) + results[file_key] = value + + return results + + +def get_sorted_results(log_dir, metric): + log_files = [f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, f))] + + log_files_path = [os.path.join(log_dir, f) for f in log_files] + results = get_results(log_files_path, metric) + result_keys = list(results.keys()) + sorted_keys = sorted(result_keys) + return sorted_keys, results + + +def main(): + print("Parsing aio statistics") + args = parse_arguments() + + if not validate_args(args): + quit() + + sorted_keys, results = get_sorted_results(args.log_dir, args.metric) + for k in sorted_keys: + print(f'{k} = {results[k]}') + + +if __name__ == "__main__": + main() diff --git a/deepspeed/nvme/perf_generate_param.py b/deepspeed/nvme/perf_generate_param.py new file mode 100644 index 000000000000..99be75c2d919 --- /dev/null +++ b/deepspeed/nvme/perf_generate_param.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" +import os +import argparse +import json +from .parse_nvme_stats import READ_SPEED, WRITE_SPEED, get_sorted_results +from .perf_sweep_utils import BENCH_LOG_DIR, READ_LOG_DIR, WRITE_LOG_DIR + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--log_dir', + type=str, + default=BENCH_LOG_DIR, + help=f'Folder of performance sweep logs. Default is {os.path.join(".", BENCH_LOG_DIR)}') + parser.add_argument('--verbose', action='store_true', help='Print debugging information.') + + args = parser.parse_args() + if args.verbose: + print(f'args = {args}') + + return args + + +def validate_args(args): + for d in [READ_LOG_DIR, WRITE_LOG_DIR]: + log_dir = os.path.join(args.log_dir, d) + if not os.path.isdir(log_dir): + print(f'{log_dir} folder is not existent') + return False + + return True + + +def convert_to_param(key): + assert len(key) == 6 + return { + "single_submit": "true" if key[0] == "single" else "false", + "overlap_events": "true" if key[1] == "overlap" else "false", + "num_threads": int(key[5]), + "queue_depth": int(key[3]), + "block_size": int(key[4]) + } + + +def generate_aio_param(read_log_dir, write_log_dir): + _, read_results = get_sorted_results(read_log_dir, READ_SPEED) + _, write_results = get_sorted_results(write_log_dir, WRITE_SPEED) + + read_results_count = len(read_results.items()) + write_results_count = len(write_results.items()) + assert read_results_count == write_results_count, f"Mismatch in number of read & write results: {read_results_count=} != {write_results_count=}" + + combined_perf = {key[1:]: value for key, value in read_results.items()} + for key, value in write_results.items(): + new_key = key[1:] + if new_key in combined_perf: + combined_perf[new_key] += value + else: + combined_perf[new_key] = 0 + + optimal_key = None + optimal_perf = 0.0 + for key, value in combined_perf.items(): + if value > optimal_perf: + optimal_perf = value + optimal_key = key + + aio_param = {"aio": convert_to_param(optimal_key)} + + read_perf_keys = {key[1:]: key for key in read_results.keys()} + write_perf_keys = {key[1:]: key for key in write_results.keys()} + optimal_config_read = read_results.get(read_perf_keys[optimal_key], None) + optimal_config_write = write_results.get(write_perf_keys[optimal_key], None) + + print(f'Best performance (GB/sec): read = {optimal_config_read:5.2f}, write = {optimal_config_write:5.2f}') + print(json.dumps(aio_param, indent=3)) + + +def generate_main(log_dir): + read_log_dir = os.path.join(log_dir, READ_LOG_DIR) + write_log_dir = os.path.join(log_dir, WRITE_LOG_DIR) + generate_aio_param(read_log_dir, write_log_dir) + + +def main(): + args = parse_arguments() + if not validate_args(args): + quit() + print(f'Generate DeepNVMe configuration from {args.log_dir} logs') + generate_main(args.log_dir) + + +if __name__ == "__main__": + generate_main() diff --git a/deepspeed/nvme/perf_run_sweep.py b/deepspeed/nvme/perf_run_sweep.py new file mode 100644 index 000000000000..4560911bcb98 --- /dev/null +++ b/deepspeed/nvme/perf_run_sweep.py @@ -0,0 +1,319 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" +import os +import sys +import argparse +import json +import itertools +import shutil + +from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder +from .ds_aio_job import Job, run_job +from .perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \ + READ_LOG_DIR, WRITE_LOG_DIR + +OTHER_OPTIONS = '--engine aio_handle' +PERF_SCRIPT = 'ds_io' +DEFAULT_SWEEP_CONFIG = { + "block_size": ["1M", "8M"], + "queue_depth": [32, 128], + "sequential_requests": [True, False], + "single_submit": [False, True], + "io_parallel": [1, 2, 4, 8], +} + + +class SweepConfig(object): + + def __init__(self, args): + self.folder_to_device_mapping = get_ftd_map(args.nvme_dir) + self.search_space = get_sweep_config_dict(args.sweep_config) + self.search_space.update(self.folder_to_device_mapping) + self.read = not args.no_read + self.write = not args.no_write + self.flush_cache = args.flush_page_cache + self.log_dir = args.log_dir + self.verbose = args.verbose + self.other_options = f'{OTHER_OPTIONS} --loops {args.loops} --io_size {args.io_size}' + if args.gpu: + self.other_options += ' --gpu' + if args.gds: + self.other_options += ' --use_gds' + + +def validate_arguments(args): + if not async_io_setup(): + error_msg = """ + Failing because environment is not properly configured for deepspeed async i/o module. + Possible fix: apt install libaio-dev. + """ + print(error_msg) + quit() + + if args.gds and not gds_io_setup(): + error_msg = """ + Failing because environment is not properly configured for deepspeed GDS I/O operator. + """ + print(error_msg) + quit() + + +def parse_sweep_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--nvme_dir', + nargs='+', + required=True, + help='Directory in which to perform I/O tests. A writeable directory on a NVMe device.') + + parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.') + + parser.add_argument('--no_read', action='store_true', help='Disable read performance measurements.') + + parser.add_argument('--no_write', action='store_true', help='Disable write performance measurements.') + + parser.add_argument('--io_size', + type=str, + default="400M", + help='Number of I/O bytes to read/write for performance measurements.') + + parser.add_argument('--gpu', action='store_true', help='Test tensor transfers between GPU device and NVME device.') + + parser.add_argument('--gds', action='store_true', help='Run the sweep over NVIDIA GPUDirectStorage operator') + + parser.add_argument( + '--flush_page_cache', + action='store_true', + help= + 'Page cache will not be flushed and reported read speeds may be higher than actual ***Requires sudo access***.' + ) + + parser.add_argument( + '--log_dir', + type=str, + default=BENCH_LOG_DIR, + help=f'Output directory for performance log files. Default is {os.path.join(".", BENCH_LOG_DIR)}') + + parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions') + + parser.add_argument('--verbose', action='store_true', help='Print debugging information.') + + args = parser.parse_args() + if args.verbose: + print(f'args = {args}') + validate_arguments(args) + + return args + + +def dump_cmd_lines(cmd_lines): + print(f'cmd line count = {len(cmd_lines)}') + for i, cmd in enumerate(cmd_lines): + print(f'{i}: {cmd}') + + +def get_ftd_map(nvme_dir_list): + ftd_list = [f'{dir}:{dev}' for dev, dir in enumerate(nvme_dir_list)] + ftd_arg = [' '.join(ftd for ftd in ftd_list)] + return {'folder_to_device_mapping': ftd_arg} + + +def get_sweep_config_dict(sweep_config_json): + if sweep_config_json is None: + return DEFAULT_SWEEP_CONFIG + + with open(sweep_config_json) as fp: + sweep_config = json.load(fp) + return sweep_config + + +def get_sweep_cmd_lines(sweep_config_dict): + + def flatten_options(key, value_list): + flat_list = [] + for v in value_list: + if not type(v) is bool: + flat_list.append(f'--{key} {v}') + elif v: + flat_list.append(f'--{key}') + else: + flat_list.append(' ') + + return flat_list + + flat_list = [flatten_options(key, value) for key, value in sweep_config_dict.items()] + cmd_list = list(itertools.product(*flat_list)) + cmd_list = [list(cmd) for cmd in cmd_list] + #dump_cmd_lines(cmd_list) + return cmd_list + + +def launch_sweep(sweep_jobs, sync_job, flush_cache_job, verbose): + for perf_job in sweep_jobs: + if flush_cache_job is not None: + run_job(sync_job, verbose) + run_job(flush_cache_job, verbose) + + run_job(perf_job, verbose) + + run_job(sync_job, verbose) + + +def create_cmd_tags(cmd_line): + tags = {} + for param_value in cmd_line: + fields = param_value.split() + if len(fields) == 1: + tags[fields[0]] = None + elif len(fields) == 2: + if fields[0] == '--folder_to_device_mapping': + tags[fields[0]] = len(fields[1:]) + else: + tags[fields[0]] = fields[1] + elif len(fields) > 2: + tags[fields[0]] = len(fields[1:]) + return tags + + +def get_log_file(io_op_desc, cmd_line): + QUEUE_DEPTH = "--queue_depth" + BLOCK_SIZE = "--block_size" + SINGLE_SUBMIT = "--single_submit" + SEQUENTIAL_REQUESTS = "--sequential_requests" + FTD_MAP = "--folder_to_device_mapping" + IO_PARALLEL = "--io_parallel" + + tag_map = { + QUEUE_DEPTH: "d", + BLOCK_SIZE: "bs", + SINGLE_SUBMIT: "single", + SEQUENTIAL_REQUESTS: "sequential", + FTD_MAP: "ftd", + IO_PARALLEL: "p" + } + + tag_default = { + QUEUE_DEPTH: 1, + BLOCK_SIZE: "1M", + SINGLE_SUBMIT: "block", + SEQUENTIAL_REQUESTS: "overlap", + FTD_MAP: 1, + IO_PARALLEL: 1 + } + + def get_default_value(tag): + value = tag_default[tag] + if tag in [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS]: + return value + return f'{tag_map[tag]}{value}' + + def get_config_value(tag, value): + tag_key = tag_map[tag] + if value is None: + return tag_key + return f'{tag_key}{value}' + + tag_list = [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS, FTD_MAP, QUEUE_DEPTH, BLOCK_SIZE, IO_PARALLEL] + log_tags = [io_op_desc] + cmd_tags = create_cmd_tags(cmd_line) + for tag in tag_list: + if tag in cmd_tags: + log_tags.append(get_config_value(tag, cmd_tags[tag])) + else: + log_tags.append(get_default_value(tag)) + + log_file = '_'.join(log_tags) + log_file += '.txt' + return log_file + + +def create_perf_jobs(io_op_desc, log_dir, cmd_lines): + py_cmd = [os.path.join(script_path(), PERF_SCRIPT)] + + perf_jobs = [] + for cmd in cmd_lines: + log_file = os.path.join(log_dir, get_log_file(io_op_desc, cmd)) + job = Job(cmd_line=py_cmd + cmd, output_file=log_file) + perf_jobs.append(job) + + return perf_jobs + + +def script_path(): + return os.path.dirname(os.path.realpath(sys.argv[0])) + + +def async_io_setup(): + return AsyncIOBuilder().is_compatible() + + +def gds_io_setup(): + return GDSBuilder().is_compatible() + + +def remove_folder(folder): + assert os.path.isdir(folder), f"Error: cannot remove {folder} - folder not found" + shutil.rmtree(folder) + + +def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): + read_cmd_lines = [[f'--read {sweep_config.other_options}'] + cmd for cmd in cmd_lines] + # dump_cmd_lines(cmd_lines) + + log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}') + os.makedirs(log_folder, exist_ok=True) + + perf_jobs = create_perf_jobs(io_op_desc=READ_OP_DESC, log_dir=log_folder, cmd_lines=read_cmd_lines) + + launch_sweep(sweep_jobs=perf_jobs, + sync_job=sync_job, + flush_cache_job=flush_cache_job, + verbose=sweep_config.verbose) + + +def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines): + write_cmd_lines = [[f'{sweep_config.other_options}'] + cmd for cmd in cmd_lines] + # dump_cmd_lines(write_cmd_lines) + + log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}') + os.makedirs(log_folder, exist_ok=True) + + perf_jobs = create_perf_jobs(io_op_desc=WRITE_OP_DESC, log_dir=log_folder, cmd_lines=write_cmd_lines) + + launch_sweep(sweep_jobs=perf_jobs, + sync_job=sync_job, + flush_cache_job=flush_cache_job, + verbose=sweep_config.verbose) + + +def sweep_main(args): + sweep_config = SweepConfig(args) + cmd_lines = get_sweep_cmd_lines(sweep_config.search_space) + + if sweep_config.flush_cache: + flush_cache_job = Job(cmd_line=['sudo', 'bash -c', "'echo 1 > /proc/sys/vm/drop_caches'"]) + else: + flush_cache_job = None + + sync_job = Job(cmd_line=['sync']) + + if sweep_config.read: + run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines) + + if sweep_config.write: + run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines) + + +def main(): + args = parse_sweep_arguments() + print(f"Running DeepNVMe performance sweep on {args.nvme_dir}") + sweep_main(args) + + +if __name__ == "__main__": + sweep_main() diff --git a/deepspeed/nvme/perf_sweep_utils.py b/deepspeed/nvme/perf_sweep_utils.py new file mode 100644 index 000000000000..e6832c1baa49 --- /dev/null +++ b/deepspeed/nvme/perf_sweep_utils.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +SCRIPT_PREFIX = '_aio_bench' +WRITE_OP_DESC = 'write' +READ_OP_DESC = 'read' +READ_IO_DIR = f'{SCRIPT_PREFIX}_{READ_OP_DESC}_io' +WRITE_IO_DIR = f'{SCRIPT_PREFIX}_{WRITE_OP_DESC}_io' +BENCH_LOG_DIR = f'{SCRIPT_PREFIX}_logs' +READ_LOG_DIR = f'{SCRIPT_PREFIX}_{READ_OP_DESC}_logs' +WRITE_LOG_DIR = f'{SCRIPT_PREFIX}_{WRITE_OP_DESC}_logs' diff --git a/deepspeed/nvme/test_ds_aio.py b/deepspeed/nvme/test_ds_aio.py new file mode 100755 index 000000000000..3347da182c50 --- /dev/null +++ b/deepspeed/nvme/test_ds_aio.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import multiprocessing as mp +from .ds_aio_args import get_validated_args +from .io_engine import io_engine_multiprocessing + + +def ds_io_main(): + print('Testing DeepNVMe python frontend') + + args = get_validated_args() + mp.set_start_method('spawn', force=True) + multiprocess_function = io_engine_multiprocessing + multiprocess_function(args, args.read) + + +if __name__ == "__main__": + ds_io_main() diff --git a/deepspeed/nvme/test_ds_aio_utils.py b/deepspeed/nvme/test_ds_aio_utils.py new file mode 100755 index 000000000000..90b994b0b532 --- /dev/null +++ b/deepspeed/nvme/test_ds_aio_utils.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +from .ds_aio_job import Job, run_job +import torch +from deepspeed.accelerator import get_accelerator + +BYTES_PER_GB = 1024**3 +BYTES_PER_MB = 1024**2 +BYTES_PER_KB = 1024 +LOG_TIDS = [0] + + +def task_log(tid, msg, force=False): + if force or tid in LOG_TIDS: + print(f'tid {tid}: {msg}') + + +def task_barrier(barrier, num_parties): + assert barrier.parties == num_parties + barrier.wait() + assert barrier.broken == False + + +def report_results(args, read_op, pool_results): + #print(f'pool_results = {pool_results}') + io_string = 'Read' if read_op else 'Write' + if None in pool_results: + print(f'Failure in one of {args.threads} {io_string} processes') + return + + total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) + + task_latency_sec = max([sec for _, sec, _ in pool_results]) + task_speed_GB = 0 if task_latency_sec == 0 else total_bytes / task_latency_sec / BYTES_PER_GB + print(f'Task {io_string} Latency = {task_latency_sec} sec') + print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') + + e2e_latency_sec = max([sec for sec, _, _ in pool_results]) + e2e_speed_GB = 0 if e2e_latency_sec == 0 else total_bytes / e2e_latency_sec / BYTES_PER_GB + print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') + print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') + + +def get_block_size_and_count(io_bytes): + if io_bytes > BYTES_PER_MB and io_bytes % BYTES_PER_MB == 0: + block_size = BYTES_PER_MB + block_size_string = '1M' + else: + assert io_bytes % BYTES_PER_KB == 0 + block_size = BYTES_PER_KB + block_size_string = '1K' + block_count = io_bytes / block_size + + return block_size_string, int(block_count) + + +def refine_integer_value(value): + unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} + + if value[-1] in list(unit_dict.keys()): + int_value = int(value[:-1]) * unit_dict[value[-1]] + return int_value + return int(value) + + +def create_filename(folder, read_op, size, tid): + io_string = "read" if read_op else "write" + return os.path.join(folder, f'_aio_{io_string}_{size}.pt.{tid}') + + +def create_file(filename, num_bytes): + block_size, block_count = get_block_size_and_count(num_bytes) + dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={filename} bs={block_size} count={block_count}']) + print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....') + run_job(dd_job) + print(f'[Done] Create read file of {num_bytes} bytes by running {dd_job.cmd()} ....') + + +def create_page_locked_tensor(num_elem, use_accelerator, aio_handle=None): + if use_accelerator: + return get_accelerator().pin_memory(torch.randint(high=128, size=(num_elem, ), dtype=torch.uint8, + device='cpu')) + else: + return aio_handle.new_cpu_locked_tensor(num_elem, torch.empty(0, dtype=torch.uint8)) diff --git a/deepspeed/nvme/torch_fastio_engine.py b/deepspeed/nvme/torch_fastio_engine.py new file mode 100644 index 000000000000..8929e175865e --- /dev/null +++ b/deepspeed/nvme/torch_fastio_engine.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import os +import time +from deepspeed.ops.aio import AsyncIOBuilder +from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from .ds_aio_constants import * +from deepspeed.io import FastFileWriter + + +class Torch_FastIO_Engine(object): + + def __init__(self, args, tid, read_op): + assert read_op is False, 'Read operation is not currently supported' + self.ctxt = self._create_context(args, tid, read_op) + self.zipfile_serialization = not args.torch_legacy_save + + def fini(self): + if self.ctxt[USE_CPU_LOCKED_TENSOR]: + for buf in [BUFFER, FAST_IO_BUFFER]: + self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf]) + + self.ctxt[BUFFER].detach() + self.ctxt[BUFFER] = None + + def read(self, args, tid): + start_time = time.time() + torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def write(self, args, tid): + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(self.ctxt[FILE]): + os.remove(self.ctxt[FILE]) + + ds_file_writer = FastFileWriter(file_path=self.ctxt[FILE], + aio_handle=self.ctxt[HANDLE], + pinned_tensor=self.ctxt[FAST_IO_BUFFER]) + + start_time = time.time() + torch.save(obj=self.ctxt[BUFFER], f=ds_file_writer, _use_new_zipfile_serialization=self.zipfile_serialization) + ds_file_writer.close() # Force flush to storage + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + ds_file_writer._dump_state() + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + + io_parallel = args.io_parallel if args.io_parallel else 1 + aio_handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit, + not args.sequential_requests, io_parallel) + + if args.gpu: + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}') + else: + buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, aio_handle) + + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + fast_io_buffer = create_page_locked_tensor(args.fast_io_size, args.use_accelerator_pin_memory, aio_handle) + + task_log(tid, 'created torch_fastio engine') + + ctxt = {} + ctxt[FILE] = filename + ctxt[NUM_BYTES] = args.io_size + ctxt[BUFFER] = buffer + ctxt[HANDLE] = aio_handle + ctxt[FAST_IO_BUFFER] = fast_io_buffer + ctxt[ELAPSED_SEC] = 0 + ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory + + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) + + return ctxt diff --git a/deepspeed/nvme/torch_io.py b/deepspeed/nvme/torch_io.py new file mode 100644 index 000000000000..04d653544ccd --- /dev/null +++ b/deepspeed/nvme/torch_io.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import os +import time +from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor +from .ds_aio_constants import * + + +class TorchIO_Engine(object): + + def __init__(self, args, tid, read_op): + self.ctxt = self._create_context(args, tid, read_op) + self.zipfile_serialization = not args.torch_legacy_save + + def fini(self): + self.ctxt[BUFFER].detach() + self.ctxt[BUFFER] = None + + def read(self, args, tid): + start_time = time.time() + torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def write(self, args, tid): + # Avoid overwriting existing files as it could be artificially faster + if os.path.isfile(self.ctxt[FILE]): + os.remove(self.ctxt[FILE]) + + start_time = time.time() + torch.save(obj=self.ctxt[BUFFER], f=self.ctxt[FILE], _use_new_zipfile_serialization=self.zipfile_serialization) + end_time = time.time() + self.ctxt[ELAPSED_SEC] += end_time - start_time + + def _create_context(self, args, tid, read_op): + io_string = "Read" if read_op else "Write" + device_id, folder = args.mapping_list[tid] + filename = create_filename(folder, args.read, args.io_size, tid) + if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size): + create_file(filename, args.io_size) + + task_log(tid, f'Allocate tensor of size {args.io_size} bytes') + + if args.gpu: + buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}') + else: + buffer = create_page_locked_tensor(args.io_size, True) + + task_log(tid, + f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}', + force=True) + + task_log(tid, 'created torch_io engine') + + ctxt = {} + ctxt[FILE] = filename + ctxt[NUM_BYTES] = args.io_size + ctxt[BUFFER] = buffer + ctxt[ELAPSED_SEC] = 0 + return ctxt diff --git a/deepspeed/nvme/validate_async_io.py b/deepspeed/nvme/validate_async_io.py new file mode 100644 index 000000000000..10fb638347bc --- /dev/null +++ b/deepspeed/nvme/validate_async_io.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" +from deepspeed.ops.op_builder import AsyncIOBuilder +assert AsyncIOBuilder().is_compatible() +assert AsyncIOBuilder().load() diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py index b5a03c458a46..15179984173c 100755 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -6,11 +6,10 @@ from . import adam from . import adagrad from . import lamb -#from ..git_version_info_installed import installed_ops as __installed_ops__ -#if __installed_ops__['sparse_attn']: +from . import lion from . import sparse_attention from . import transformer - +from . import fp_quantizer from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from ..git_version_info import compatible_ops as __compatible_ops__ diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index c356a52777f2..dbde6d95f652 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -34,7 +34,7 @@ def __setstate__(self, state): group.setdefault('amsgrad', False) @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): + def step(self, closure=None): """Update the model parameters. .. note:: @@ -46,8 +46,6 @@ def step(self, closure=None, fp16_param_groups=None): Args: closure (callable, optional): closure to compute the loss. Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. Returns: loss: if ``closure`` is provided. Otherwise ``None``. @@ -94,16 +92,7 @@ def step(self, closure=None, fp16_param_groups=None): sparse_exp_avg_sq.values()) p[sparse_param.indices()] = sparse_param.values() state['exp_avg_sq'][sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values() - if fp16_param_groups is not None: - fp16_param_groups[group_id][param_id][sparse_param.indices()] = sparse_param.values() else: - if fp16_param_groups is not None: - self.ds_opt_adagrad.adagrad_update_copy(self.opt_id, state['step'], group['lr'], group['eps'], - group['weight_decay'], p.data, p.grad.data, - state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'], - group['weight_decay'], p.data, p.grad.data, - state['exp_avg_sq']) + self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'], + group['weight_decay'], p.data, p.grad.data, state['exp_avg_sq']) return loss diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index a29bb9447d01..4f021f05136a 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -1,7 +1,9 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) DeepSpeed Team. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team from .cpu_adam import DeepSpeedCPUAdam from .fused_adam import FusedAdam +from .zenflow_cpu_adam import ZenFlowCPUAdam +from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3 diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 9fdf7311a764..02f55f609a65 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -38,11 +38,6 @@ def __init__(self, the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology. - For calling step function, there are two options available: (1) update optimizer's states and (2) update - optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second - option can bring 30% higher throughput than the doing the copy separately using option one. - - .. note:: We recommend using our `config `_ @@ -63,8 +58,10 @@ def __init__(self, algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! adamw_mode: select between Adam and AdamW implementations (default: AdamW) - full_precision_optimizer_states: creates momementum and variance in full precision regardless of - the precision of the parameters (default: True) + fp32_optimizer_states: creates momentum and variance in full precision regardless of + the precision of the parameters. Set to False to keep optimizer states + in the parameter dtype (e.g. bf16), which reduces the optimizer-state + memory footprint at the cost of lower state precision. (default: True) """ default_args = dict(lr=lr, @@ -107,7 +104,7 @@ def __setstate__(self, state): group.setdefault('amsgrad', False) @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): + def step(self, closure=None): """Update the model parameters. .. note:: @@ -119,8 +116,6 @@ def step(self, closure=None, fp16_param_groups=None): Args: closure (callable, optional): closure to compute the loss. Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. Returns: loss: if ``closure`` is provided. Otherwise ``None``. @@ -134,13 +129,6 @@ def step(self, closure=None, fp16_param_groups=None): # intended device for step device = torch.device('cpu') - # converting the fp16 params to a group of parameter - if type(fp16_param_groups) is list: - if type(fp16_param_groups[0]) is not list: - fp16_param_groups = [fp16_param_groups] - elif fp16_param_groups is not None: - fp16_param_groups = [[fp16_param_groups]] - for group_id, group in enumerate(self.param_groups): for param_id, p in enumerate(group['params']): @@ -169,13 +157,88 @@ def step(self, closure=None, fp16_param_groups=None): state['step'] += 1 beta1, beta2 = group['betas'] - if fp16_param_groups is not None: - self.ds_opt_adam.adam_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2, - group['eps'], group['weight_decay'], group['bias_correction'], - p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], - group['weight_decay'], group['bias_correction'], p.data, p.grad.data, - state['exp_avg'], state['exp_avg_sq']) + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq']) + return loss + + @torch.no_grad() + def step_subgroup(self, subgroup_id: int, closure=None): + """Update the model parameters in a single subgroup (by index).""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Intended device for step + device = torch.device('cpu') + + for group in self.param_groups: + for p in group['params']: + + if p.grad is None: + continue + + assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[subgroup_id] + + if len(state) == 0: + state['step'] = 0 + + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + + state['step'] += 1 + beta1, beta2 = group['betas'] + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq']) + return loss + + @torch.no_grad() + def rollback_subgroup(self, sub_group_id: int, closure=None): + """ + Rollback the optimizer state for a specific subgroup. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Intended device for step + device = torch.device('cpu') + + # Validate subgroup state exists and is initialized + if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0: + raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} " + f"as it has not been initialized.") + + subgroup_state = self.state[sub_group_id] + + # Check if we can rollback (step count must be > 0) + if subgroup_state.get('step', 0) <= 0: + raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: " + f"step count is {subgroup_state.get('step', 0)}") + + for _, group in enumerate(self.param_groups): + for _, param in enumerate(group['params']): + if param.grad is None: + continue + + assert param.device == device, ( + f"CPUAdam param is on {param.device} and must be 'cpu', " + f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.") + + beta1, beta2 = group['betas'] + + self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2, + group['eps'], group['weight_decay'], group['bias_correction'], + param.data, param.grad.data, subgroup_state['exp_avg'], + subgroup_state['exp_avg_sq']) + + subgroup_state['step'] -= 1 return loss diff --git a/deepspeed/ops/adam/fused_adam.py b/deepspeed/ops/adam/fused_adam.py index ae7a6f0a87ce..53f859e9cc87 100644 --- a/deepspeed/ops/adam/fused_adam.py +++ b/deepspeed/ops/adam/fused_adam.py @@ -4,7 +4,7 @@ # DeepSpeed Team """ Copyright NVIDIA/apex -This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +This file is adapted from fused adam in NVIDIA/apex, commit 6bd01c4 """ import torch @@ -18,13 +18,36 @@ class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. - Currently GPU-only. + Currently GPU-only. Requires Apex to be installed via + ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. This version of fused Adam implements 2 fusions. * Fusion of the Adam update's elementwise operations * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adam_w_mode=False``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + ... + opt.step() + + :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp, + you may choose any ``opt_level``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") + ... + opt.step() + + In general, ``opt_level="O1"`` is recommended. + + + .. warning:: + A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``. These additional arguments + are now deprecated and unnecessary. + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. Arguments: @@ -81,7 +104,7 @@ def zero_grad(self): else: super(FusedAdam, self).zero_grad() - def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): """Performs a single optimization step. Arguments: @@ -99,14 +122,19 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no loss = closure() for group in self.param_groups: + if len(group['params']) == 0: + continue bias_correction = 1 if group['bias_correction'] else 0 beta1, beta2 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel if 'step' not in group: group['step'] = 0 # create lists for multi-tensor apply g_16, p_16, m_16, v_16 = [], [], [], [] + g_bf, p_bf, m_bf, v_bf = [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] for p in group['params']: @@ -120,7 +148,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no # State initialization if len(state) == 0: # DeepSpeed ZeRO 3 processes each subgroup a time, so we need to keep tracking step count for each tensor separately. - # While this is not an issue for ZeRO 1 & 2, since they apply a single optimizatin step to the whole param group at the same time. + # While this is not an issue for ZeRO 1 & 2, since they apply a single optimization step to the whole param group at the same time. # In order to keep backward compatibility for the existing checkpoints, we use group['state'] to initialize state['step'] if it exists. state['step'] = group.get('step', 0) # Exponential moving average of gradient values @@ -133,20 +161,32 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no p_16.append(p.data) m_16.append(state['exp_avg']) v_16.append(state['exp_avg_sq']) + elif p.dtype == torch.bfloat16: + g_bf.append(p.grad) + p_bf.append(p) + m_bf.append(state['exp_avg']) + v_bf.append(state['exp_avg_sq']) elif p.dtype == torch.float32: g_32.append(p.grad.data) p_32.append(p.data) m_32.append(state['exp_avg']) v_32.append(state['exp_avg_sq']) else: - raise RuntimeError('FusedAdam only support fp16 and fp32.') + raise RuntimeError('FusedAdam only support fp16, bf16 and fp32.') - if (len(g_16) > 0): + if len(g_16) > 0: state['step'] += 1 multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], group['lr'], beta1, beta2, group['eps'], state['step'], self.adam_w_mode, bias_correction, group['weight_decay']) - if (len(g_32) > 0): + + if len(g_bf) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_bf, p_bf, m_bf, v_bf], + group['lr'], beta1, beta2, group['eps'], state['step'], self.adam_w_mode, + bias_correction, group['weight_decay']) + + if len(g_32) > 0: state['step'] += 1 multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], group['lr'], beta1, beta2, group['eps'], state['step'], self.adam_w_mode, diff --git a/deepspeed/ops/adam/zenflow_cpu_adam.py b/deepspeed/ops/adam/zenflow_cpu_adam.py new file mode 100644 index 000000000000..0809d7a0f7e0 --- /dev/null +++ b/deepspeed/ops/adam/zenflow_cpu_adam.py @@ -0,0 +1,138 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.ops.adam import DeepSpeedCPUAdam +import torch + + +class ZenFlowCPUAdam(DeepSpeedCPUAdam): + + def __init__(self, *args, overlap_step=False, **kwargs): + super(ZenFlowCPUAdam, self).__init__(*args, **kwargs) + self.overlap_step = overlap_step + if not self.overlap_step: + print("ZenFlowCPUAdam initialized with normal step.") + self.step = self._sequential_step + else: + print("ZenFlowCPUAdam initialized with overlap step.") + self.step = self._parallel_step + + @torch.no_grad() + def _sequential_step(self, step_id, closure=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # intended device for step + device = torch.device('cpu') + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + + state['step'] = step_id + beta1, beta2 = group['betas'] + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq']) + return loss + + @torch.no_grad() + def _parallel_step(self, step_id, now_state, group_info, closure=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # intended device for step + device = torch.device('cpu') + + stale_param = None + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + assert p.data.is_shared(), "param.data must be in shared memory" + if not hasattr(p, 'overlap_grad'): + continue + + assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + # print("creating", flush=True) + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device) + exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device) + state['exp_avg'] = [exp_avg, exp_avg.clone()] + state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()] + + state['step'] = step_id + beta1, beta2 = group_info['betas'] + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2, + group_info['eps'], group_info['weight_decay'], + group_info['bias_correction'], p.data, p.overlap_grad[now_state].data, + state['exp_avg'][now_state], state['exp_avg_sq'][now_state]) + p.stale_param.data.copy_(p.data.clone()) + return loss diff --git a/deepspeed/ops/adam/zenflow_torch_adam.py b/deepspeed/ops/adam/zenflow_torch_adam.py new file mode 100644 index 000000000000..1d55210d6edc --- /dev/null +++ b/deepspeed/ops/adam/zenflow_torch_adam.py @@ -0,0 +1,987 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from typing import cast, List, Optional, Tuple, Union +from torch import Tensor + +from deepspeed.utils.torch import required_torch_version + +# Check if we have PyTorch >= 2.0 for ZenFlow features +_ZENFLOW_AVAILABLE = required_torch_version(min_version=2.1) + +if _ZENFLOW_AVAILABLE: + try: + from torch.optim.optimizer import ( + _default_to_fused_or_foreach, + _disable_dynamo_if_unsupported, + _get_capturable_supported_devices, + _get_value, + _stack_if_compiling, + _view_as_real, + DeviceDict, + Optimizer, + ) + except ImportError as e: + # print(f"[WARNING] ZenFlow disabled: torch internal optimizer symbols could not be imported: {e}") + _ZENFLOW_AVAILABLE = False + +if not _ZENFLOW_AVAILABLE: + # safe disable dynamo if unsupported + def _disable_dynamo_if_unsupported(**kwargs): # noqa + + def wrapper(fn): + return fn + + return wrapper + + _ZENFLOW_AVAILABLE = False + + +class ZenFlowSelectiveAdamW(torch.optim.AdamW): + + def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): + if not _ZENFLOW_AVAILABLE: + raise RuntimeError("ZenFlow features are not available with PyTorch < 2.0. " + "Please upgrade to PyTorch 2.0+ to use ZenFlow, or omit 'zenflow' " + "from your DeepSpeed configuration to use the default ZeRO-Offload optimizer.") + super(ZenFlowSelectiveAdamW, self).__init__(*args, **kwargs) + + self.offload = offload + + if offload: + self.step = self._step_with_offload + self.bucket_size = bucket_size + else: + self.step = self._step_without_offload + + def temp_copy_param(self, group_to_paramlist): + for group_id, params in group_to_paramlist.items(): + for param in params: + if hasattr(param, "selected_grad"): + temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len( + param.shape) != 1 else param.data.clone().detach() + if self.offload: + param.temp_selected_param = temp_selected_param.cpu() + else: + param.temp_selected_param = temp_selected_param + + def copy_mv_from_cpu(self, params): + for param in params: + param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True) + param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True) + + def copy_mv_to_cpu(self, params): + for param in params: + param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True) + param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True) + param.exp_avg = None + param.exp_avg_sq = None + + def clear_selected_mv(self): + print("Zenflow: clearing selective optimizer states...") + for group in self.param_groups: + for param in group['params']: + state = self.state.setdefault(param, {}) + if len(state) == 0: + continue + if self.offload: + param.exp_avg_cpu_data.zero_() + param.exp_avg_sq_cpu_data.zero_() + else: + state["exp_avg"].zero_() + state["exp_avg_sq"].zero_() + + @torch.no_grad() + def _step_without_offload(self): + for group in self.param_groups: + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in group["params"]: + if hasattr(param, "selected_grad"): + selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None: + selected_param.copy_(param.temp_selected_param) + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for i, param in enumerate(group["params"]): + if hasattr(param, "selected_grad"): + if len(param.shape) != 1: + param.data[:, param.selected_indices] = params_with_grad[i] + + for param in group["params"]: + if hasattr(param, "temp_selected_param"): + param.temp_selected_param = None + param.selected_grad = None + + @torch.no_grad() + def _step_with_offload(self): + """ + Performs parameter updates in offload mode. + + In this mode, group_step() calls adamw() on each pre-partitioned param bucket, + so memory can be released after each bucket update to reduce GPU overhead. + Without offload, adamw() is called directly for speed. + """ + for group_id, group in enumerate(self.param_groups): + params = group["params"] + + bucket = [] + bucket_numel = 0 + + def flush_bucket(): + if not bucket: + return + for param in bucket: + if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None: + selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data + temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True) + selected_param.copy_(temp_selected_param) + param.temp_selected_param = None + + self.group_step({group_id: bucket}) + bucket.clear() + + for param in params: + if hasattr(param, "selected_grad"): + bucket.append(param) + bucket_numel += param.numel() + if bucket_numel >= self.bucket_size: + flush_bucket() + bucket_numel = 0 + + flush_bucket() + + @torch.no_grad() + def group_step(self, group_to_paramlist): + for group_id, params in group_to_paramlist.items(): + group = self.param_groups[group_id] + + if self.offload: + self.copy_mv_from_cpu(params) + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in params: + if hasattr(param, "selected_grad"): + is_2d = (len(param.shape) != 1) + selected_param = param.data[:, param.selected_indices] if is_2d else param.data + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + if not self.offload: + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + + if self.offload: + exp_avg_t = param.exp_avg.view_as(selected_param) + exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param) + else: + exp_avg_t = state["exp_avg"] + exp_avg_sq_t = state["exp_avg_sq"] + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(exp_avg_t) + exp_avg_sqs.append(exp_avg_sq_t) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for i, param in enumerate(params): + if hasattr(param, "selected_grad") and len(param.shape) != 1: + param.data[:, param.selected_indices] = params_with_grad[i] + + if self.offload: + self.copy_mv_to_cpu(params) + + for param in params: + param.selected_grad = None + + +class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW): + + def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs): + super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs) + self.offload = offload + + if offload: + self.step = self._step_with_offload + self.bucket_size = bucket_size + else: + self.step = self._step_without_offload + + @torch.no_grad() + def temp_copy_param(self, paramlist): + for param in paramlist: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view( + param.complete_numel // num_row, num_row) + temp_selected_param = param_2d[param.selected_indices, :].clone().detach() + else: + temp_selected_param = param.ds_tensor.data.clone().detach() + + if self.offload: + param.temp_selected_param = temp_selected_param.cpu() + else: + param.temp_selected_param = temp_selected_param + + def clear_selected_mv(self): + print("Zenflow: clearing selective optimizer states...") + for group in self.param_groups: + for param in group['params']: + state = self.state.setdefault(param, {}) + if len(state) == 0: + continue + if self.offload: + param.exp_avg_cpu_data.zero_() + param.exp_avg_sq_cpu_data.zero_() + else: + state["exp_avg"].zero_() + state["exp_avg_sq"].zero_() + + @torch.no_grad() + def _step_without_offload(self): + for group in self.param_groups: + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + for param in group["params"]: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + selected_param = param_2d[param.selected_indices, :] + else: + selected_param = param.ds_tensor.data + if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None: + selected_param.copy_(param.temp_selected_param) + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + for i, param in enumerate(group["params"]): + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + param_2d[param.selected_indices, :] = params_with_grad[i] + + for param in group["params"]: + if hasattr(param, "temp_selected_param"): + param.temp_selected_param = None + param.selected_grad = None + + def copy_mv_from_cpu(self, params): + for param in params: + param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True) + param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True) + + def copy_mv_to_cpu(self, params): + for param in params: + param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True) + param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True) + param.exp_avg = None + param.exp_avg_sq = None + + @torch.no_grad() + def group_step(self, paramlist): + + group_to_paramlist = {} + for param in paramlist: + group_id = param.group_id + if group_id not in group_to_paramlist: + group_to_paramlist[group_id] = [] + group_to_paramlist[group_id].append(param) + + for group_id in sorted(group_to_paramlist.keys()): + params = group_to_paramlist[group_id] + group = self.param_groups[group_id] + + if self.offload: + self.copy_mv_from_cpu(params) + + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + for param in params: + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + selected_param = param_2d[param.selected_indices, :] + else: + selected_param = param.ds_tensor.data + + state = self.state.setdefault(param, {}) + if len(state) == 0: + state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device) + if amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like(selected_param) + if not self.offload: + state["exp_avg"] = torch.zeros_like(selected_param) + state["exp_avg_sq"] = torch.zeros_like(selected_param) + + if self.offload: + exp_avg_t = param.exp_avg.view_as(selected_param) + exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param) + else: + exp_avg_t = state["exp_avg"] + exp_avg_sq_t = state["exp_avg_sq"] + + params_with_grad.append(selected_param) + grads.append(param.selected_grad) + exp_avgs.append(exp_avg_t) + exp_avg_sqs.append(exp_avg_sq_t) + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + state_steps.append(state["step"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + for i, param in enumerate(params): + if hasattr(param, "selected_grad"): + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + param_2d[param.selected_indices, :] = params_with_grad[i] + + if self.offload: + self.copy_mv_to_cpu(params) + + for param in params: + param.selected_grad = None + + @torch.no_grad() + def _step_with_offload(self): + """ + Performs parameter updates in offload mode. + + In this mode, group_step() calls adamw() on each pre-partitioned param bucket, + so memory can be released after each bucket update to reduce GPU overhead. + Without offload, adamw() is called directly for speed. + """ + + for group_id, group in enumerate(self.param_groups): + params = group["params"] + + bucket = [] + bucket_numel = 0 + + def flush_bucket(): + if not bucket: + return + for param in bucket: + if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None: + temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True) + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row != 1: + param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, + param.complete_numel).view( + param.complete_numel // num_row, num_row) + param_2d[param.selected_indices, :] = temp_selected_param + else: + param.ds_tensor.data.copy_(temp_selected_param) + param.temp_selected_param = None + + self.group_step(bucket) + bucket.clear() + + for param in params: + if hasattr(param, "selected_grad"): + bucket.append(param) + bucket_numel += param.numel() + if bucket_numel >= self.bucket_size: + flush_bucket() + bucket_numel = 0 + + flush_bucket() + + +def _single_tensor_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + if amsgrad: + max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + if capturable or differentiable: + step = step_t + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + if differentiable: + max_exp_avg_sq = max_exp_avg_sqs[i].clone() + else: + max_exp_avg_sq = max_exp_avg_sqs[i] + + max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) + + # Uses the max. for normalizing running avg. of gradient + # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write + # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) + denom = (max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) + else: + denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) + + param.addcdiv_(exp_avg, denom) + else: + step = _get_value(step_t) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = bias_correction2**0.5 + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + + # Lastly, switch back to complex view + if amsgrad and torch.is_complex(params[i]): + max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) + + +def _multi_tensor_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True") + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices(supports_xla=False) + assert all( + p.device.type == step.device.type and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + assert not differentiable, "_foreach ops don't support autograd" + + assert grad_scale is None and found_inf is None + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if has_complex: + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + _view_as_real( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + ) + else: + _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0) + else: + torch._foreach_add_(device_state_steps, 1) + + # Perform stepweight decay + if weight_decay != 0: + torch._foreach_mul_(device_params, 1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) + + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del device_grads + + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + + if capturable: + bias_correction1 = torch._foreach_pow(beta1, device_state_steps) + bias_correction2 = torch._foreach_pow(beta2, device_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + # we do not negate bias_correction1 as it'll need to be negated later anyway + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + torch._foreach_div_(bias_correction1, lr) + torch._foreach_reciprocal_(bias_correction1) + + torch._foreach_sqrt_(bias_correction2) + + # Re-assign for clarity as we maintain minimal intermediates: we'll have + # step_size = - lr / (1 - beta1 ^ t) where t = num_steps + # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) + step_size = bias_correction1 + bias_correction2_sqrt = bias_correction2 + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_div_(exp_avg_sq_sqrt, step_size) + + # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr + torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) + else: + bias_correction1 = [1 - beta1**_get_value(step) for step in device_state_steps] + bias_correction2 = [1 - beta2**_get_value(step) for step in device_state_steps] + + step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) + + bias_correction2_sqrt = [ + bc**0.5 for bc in bias_correction2 # type: ignore[arg-type] + ] + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_addcdiv_( + device_params, + device_exp_avgs, + exp_avg_sq_sqrt, + step_size, # type: ignore[arg-type] + ) + + +def _fused_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, # Needed for consistency. + differentiable: bool, + has_complex: bool, # Needed for consistency. +) -> None: + if not params: + return + if differentiable: + raise RuntimeError("Adam with fused=True does not support differentiable=True") + + grad_scale_dict: DeviceDict = ({grad_scale.device: grad_scale} if grad_scale is not None else {}) + found_inf_dict: DeviceDict = ({found_inf.device: found_inf} if found_inf is not None else {}) + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: Optional[DeviceDict] = ({lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault(device, grad_scale.to(device, non_blocking=True)) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device, non_blocking=True)) + if lr_dict is not None and device not in lr_dict: + lr = lr_dict.setdefault( + device, + lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + ) + torch._foreach_add_(device_state_steps, 1) + torch._fused_adamw_( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps)) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw) +def adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs AdamW algorithm computation. + + See :class:`~torch.optim.AdamW` for details. + """ + if not _ZENFLOW_AVAILABLE: + raise RuntimeError("ZenFlow adamw function is not available with PyTorch < 2.0. " + "Please upgrade to PyTorch 2.0+ to use ZenFlow, or omit 'zenflow' " + "from your DeepSpeed configuration to use the default ZeRO-Offload optimizer.") + + if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. + if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adamw + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adamw + else: + func = _single_tensor_adamw + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, + ) diff --git a/deepspeed/ops/compile/__init__.py b/deepspeed/ops/compile/__init__.py new file mode 100755 index 000000000000..e38d56359fea --- /dev/null +++ b/deepspeed/ops/compile/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..op_builder import DeepCompileBuilder diff --git a/deepspeed/ops/deepspeed4science/__init__.py b/deepspeed/ops/deepspeed4science/__init__.py new file mode 100644 index 000000000000..1c5fd280fc32 --- /dev/null +++ b/deepspeed/ops/deepspeed4science/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .evoformer_attn import DS4Sci_EvoformerAttention, EvoformerFusedAttention diff --git a/deepspeed/ops/deepspeed4science/evoformer_attn.py b/deepspeed/ops/deepspeed4science/evoformer_attn.py new file mode 100644 index 000000000000..da5843d6de31 --- /dev/null +++ b/deepspeed/ops/deepspeed4science/evoformer_attn.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +from deepspeed.ops.op_builder import EvoformerAttnBuilder +from deepspeed.accelerator import get_accelerator + +kernel_ = None + + +def _attention(Q, K, V, bias1, bias2): + assert Q.shape[-3] > 16, "seq_len must be greater than 16" + O = torch.empty_like(Q, dtype=Q.dtype) + assert get_accelerator().on_accelerator(Q), "Q must be on cuda" + assert get_accelerator().on_accelerator(K), "K must be on cuda" + assert get_accelerator().on_accelerator(V), "V must be on cuda" + assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda" + assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda" + global kernel_ + if kernel_ is None: + kernel_ = EvoformerAttnBuilder().load() + nheads = Q.shape[-2] + nq = (Q.shape[-3] + 31) // 32 * 32 + nb = np.prod(Q.shape[:-3]) + lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device) + kernel_.attention(Q, K, V, bias1, bias2, O, lse) + return O, lse + + +def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2, bias1_grad, bias2_grad): + assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value" + dQ = torch.empty_like(Q, dtype=Q.dtype) + dK = torch.empty_like(K, dtype=K.dtype) + dV = torch.empty_like(V, dtype=V.dtype) + assert get_accelerator().on_accelerator(dO), "dO must be on cuda" + assert get_accelerator().on_accelerator(Q), "Q must be on cuda" + assert get_accelerator().on_accelerator(K), "K must be on cuda" + assert get_accelerator().on_accelerator(V), "V must be on cuda" + assert get_accelerator().on_accelerator(O), "O must be on cuda" + global kernel_ + if kernel_ is None: + kernel_ = EvoformerAttnBuilder().load() + delta = torch.empty_like(lse) + if bias1_grad: + dB1 = torch.zeros_like(bias1, dtype=torch.float32) + else: + dB1 = torch.tensor([], dtype=torch.float32, device=bias1.device) + if bias2_grad: + dB2 = torch.zeros_like(bias2, dtype=torch.float32) + else: + dB2 = torch.tensor([], dtype=torch.float32, device=bias2.device) + kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2) + return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype) + + +class EvoformerFusedAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, bias1=None, bias2=None): + """ + q, k, v: are in shape [*, L, H, D] + """ + bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device) + bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + o, lse = _attention(q, k, v, bias1_, bias2_) + ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_) + return o + + @staticmethod + def backward(ctx, grad_output): + q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors + is_b1_grad = bias1.numel() != 0 and ctx.needs_input_grad[3] + is_b2_grad = bias2.numel() != 0 and ctx.needs_input_grad[4] + dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2, is_b1_grad, is_b2_grad) + if not is_b1_grad: + dB1 = None + if not is_b2_grad: + dB2 = None + return dQ, dK, dV, dB1, dB2 + + +def DS4Sci_EvoformerAttention(Q, K, V, biases): + assert len(biases) <= 2 + + if (len(biases) == 0): + biases.append(None) + + if (len(biases) == 1): + biases.append(None) + + bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2]) + bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2]) + + if biases[0] is not None: + assert biases[0].shape == bias_1_shape(Q), "bias1 shape is incorrect" + + if biases[1] is not None: + assert biases[1].shape == bias_2_shape(Q), "bias2 shape is incorrect" + + return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1]) diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py new file mode 100644 index 000000000000..f9cf23373c26 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .quantize import FP_Quantize, Quantizer +from .fp8_gemm import matmul_fp8 diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py new file mode 100644 index 000000000000..db4fa5ae2c92 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. +################################### + +import torch + + +def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer): + from deepspeed import get_accelerator + + if not get_accelerator().is_triton_supported(): + return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer) + else: + # Import dynamically to prevent failures on systems without triton. + from .fp8_gemm_triton import matmul_fp8_triton + return matmul_fp8_triton(inp, weight, scale, quantization_group_size) + + +def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer): + return torch.matmul(inp, quantizer.dequantize(weight, scale=scale)) diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py new file mode 100644 index 000000000000..086525cc6442 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. +################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset + ((k * BLOCK_SIZE_K * stride_bk) // quantization_group_size)) + # Dequantize weight (fp8 -> bf16) + w = (weight & 0x80).to(tl.uint16) << 8 + w = w | ((weight & 0x7f).to(tl.uint16) << 4) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True).to(tl.float32) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8_triton(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py new file mode 100644 index 000000000000..71fe96267f85 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import abc +from abc import ABC + +import gc +from deepspeed.ops.op_builder import FPQuantizerBuilder +from deepspeed.accelerator import get_accelerator + +fp_quant_module = None + + +class Quantizer(ABC): + """ + Abstract Quantizer class that implements quantize/dequantize methods. + + Arguments: + group_size (int, optional): number of values or elements that are grouped + together for the quantization process. + """ + + def __init__(self, group_size=512) -> None: + self.group_size = group_size + + @abc.abstractmethod + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + ... + + @abc.abstractmethod + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + ... + + +class FP_Quantize(Quantizer): + + def __init__(self, quantization_config) -> None: + global fp_quant_module + super().__init__(group_size=quantization_config.group_size) + if fp_quant_module is None: + fp_quant_module = FPQuantizerBuilder().load() + self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True) + self.q_config = quantization_config + + self.orig_dtype = None + self.num_groups = None + self.input_q = None + self.scale = None + + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + assert input.dtype == torch.bfloat16, "only support bf16 for now" + if return_meta_tensor: + assert q_bits == 8, "meta tensor is only supported with q_bit=8" + + self.orig_dtype = input.dtype + self.orig_shape = input.shape + + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" + self.num_groups = input.numel() // self.group_size + self.input_q = torch.ones(self.num_groups, + int(self.group_size * q_bits) // 8 + 4, + dtype=torch.uint8, + device=input.device) + out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) + if return_meta_tensor: + data, self.scale = out.split(self.group_size, dim=-1) + data = data.contiguous().reshape(input.shape) + self.scale = self.scale.contiguous() + del self.input_q + del out + gc.collect() + get_accelerator().empty_cache() + return data, self.scale + + return out + + def to(self, *args, **kwargs): + # Intermediate tensors may need to be moved to different devices + if hasattr(self, 'input_q'): + self.input_q = self.input_q.to(*args, **kwargs) + if hasattr(self, 'scale'): + self.scale = self.scale.to(*args, **kwargs) + + def get_scales(self): + return fp_quant_module.get_scales(self.scale, self.num_groups) + + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype, + device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) + return fp_out + + def selective_dequantize(self, + input_q, + indexes, + fp_out=None, + q_bits=8, + q_mantisa_bits=3, + scale=None) -> torch.Tensor: + assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \ + "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function." + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty( + (indexes.shape[0], + *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + + fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits, + q_bits - q_mantisa_bits - 1) + return fp_out diff --git a/deepspeed/ops/gds/__init__.py b/deepspeed/ops/gds/__init__.py new file mode 100755 index 000000000000..3c0762c81076 --- /dev/null +++ b/deepspeed/ops/gds/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..op_builder import GDSBuilder diff --git a/deepspeed/ops/lion/__init__.py b/deepspeed/ops/lion/__init__.py new file mode 100755 index 000000000000..2f90e5ec2e80 --- /dev/null +++ b/deepspeed/ops/lion/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cpu_lion import DeepSpeedCPULion +from .fused_lion import FusedLion diff --git a/deepspeed/ops/lion/cpu_lion.py b/deepspeed/ops/lion/cpu_lion.py new file mode 100755 index 000000000000..03342a3fcd34 --- /dev/null +++ b/deepspeed/ops/lion/cpu_lion.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from cpuinfo import get_cpu_info +from deepspeed.utils import logger +from deepspeed.utils.logging import should_log_le +from deepspeed.ops.op_builder import CPULionBuilder + + +class DeepSpeedCPULion(torch.optim.Optimizer): + optimizer_id = 0 + + def __init__(self, model_params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0, fp32_optimizer_states=True): + """Fast vectorized implementation of Lion optimizer on CPU: + + See Symbolic Discovery of Optimization Algorithms (https://doi.org/10.48550/arXiv.2302.06675). + + .. note:: + We recommend using our `config + `_ + to allow :meth:`deepspeed.initialize` to build this optimizer + for you. + + + Arguments: + model_params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + full_precision_optimizer_states: creates momentum and variance in full precision regardless of + the precision of the parameters (default: True) + """ + + default_args = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super(DeepSpeedCPULion, self).__init__(model_params, default_args) + + cpu_info = get_cpu_info() + self.cpu_vendor = cpu_info["vendor_id_raw"].lower() if "vendor_id_raw" in cpu_info else "unknown" + if "amd" in self.cpu_vendor: + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + if p.dtype == torch.half: + logger.warning("FP16 params for CPULion may not work on AMD CPUs") + break + else: + continue + break + + self.opt_id = DeepSpeedCPULion.optimizer_id + DeepSpeedCPULion.optimizer_id = DeepSpeedCPULion.optimizer_id + 1 + self.fp32_optimizer_states = fp32_optimizer_states + self.ds_opt_lion = CPULionBuilder().load() + + self.ds_opt_lion.create_lion(self.opt_id, lr, betas[0], betas[1], weight_decay, should_log_le("info")) + + def __del__(self): + # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize + # is used multiple times in the same process (notebook or pytest worker) + self.ds_opt_lion.destroy_lion(self.opt_id) + + def __setstate__(self, state): + super(DeepSpeedCPULion, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # intended device for step + device = torch.device('cpu') + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + assert p.device == device, f"CPULion param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + + state['step'] += 1 + beta1, beta2 = group['betas'] + + self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2, + group['weight_decay'], p.data, p.grad.data, state['exp_avg']) + return loss diff --git a/deepspeed/ops/lion/fused_lion.py b/deepspeed/ops/lion/fused_lion.py new file mode 100644 index 000000000000..7332a7f96361 --- /dev/null +++ b/deepspeed/ops/lion/fused_lion.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +This file is modified from fused_adam.py +""" + +import torch +from .multi_tensor_apply import MultiTensorApply + +multi_tensor_applier = MultiTensorApply(2048 * 32) +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import FusedLionBuilder + + +class FusedLion(torch.optim.Optimizer): + """Implements Lion algorithm. + + Currently GPU-only. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + + .. _Symbolic Discovery of Optimization Algorithms: + https://doi.org/10.48550/arXiv.2302.06675 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0., set_grad_none=True): + + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super(FusedLion, self).__init__(params, defaults) + self.set_grad_none = set_grad_none + + fused_lion_cuda = FusedLionBuilder().load() + # Skip buffer + self._dummy_overflow_buf = get_accelerator().IntTensor([0]) + self.multi_tensor_lion = fused_lion_cuda.multi_tensor_lion + + def zero_grad(self): + if self.set_grad_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedLion, self).zero_grad() + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError('FusedLion has been updated.') + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + if len(group['params']) == 0: + continue + beta1, beta2 = group['betas'] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' not in group: + group['step'] = 0 + + # create lists for multi-tensor apply + g_16, p_16, m_16 = [], [], [] + g_bf, p_bf, m_bf = [], [], [] + g_32, p_32, m_32 = [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise NotImplementedError('FusedLion does not support sparse gradients') + + state = self.state[p] + # State initialization + if len(state) == 0: + # DeepSpeed ZeRO 3 processes each subgroup a time, so we need to keep tracking step count for each tensor separately. + # While this is not an issue for ZeRO 1 & 2, since they apply a single optimization step to the whole param group at the same time. + # In order to keep backward compatibility for the existing checkpoints, we use group['state'] to initialize state['step'] if it exists. + state['step'] = group.get('step', 0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + + if p.dtype == torch.float16: + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state['exp_avg']) + elif p.dtype == torch.bfloat16: + g_bf.append(p.grad) + p_bf.append(p) + m_bf.append(state['exp_avg']) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state['exp_avg']) + else: + raise RuntimeError('FusedLion only support fp16, bf16 and fp32.') + + if len(g_16) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_16, p_16, m_16], group['lr'], + beta1, beta2, state['step'], group['weight_decay']) + + if len(g_bf) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_bf, p_bf, m_bf], group['lr'], + beta1, beta2, state['step'], group['weight_decay']) + + if len(g_32) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_32, p_32, m_32], group['lr'], + beta1, beta2, state['step'], group['weight_decay']) + + return loss diff --git a/deepspeed/ops/lion/multi_tensor_apply.py b/deepspeed/ops/lion/multi_tensor_apply.py new file mode 100644 index 000000000000..0ba228505cef --- /dev/null +++ b/deepspeed/ops/lion/multi_tensor_apply.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Copyright NVIDIA/apex +This file is adapted from NVIDIA/apex, commit a109f85 +""" + + +class MultiTensorApply(object): + + def __init__(self, chunk_size): + self.chunk_size = chunk_size + + def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/deepspeed/ops/random_ltd/dropping_utils.py b/deepspeed/ops/random_ltd/dropping_utils.py index bc491716b7a8..dd36c94537f8 100644 --- a/deepspeed/ops/random_ltd/dropping_utils.py +++ b/deepspeed/ops/random_ltd/dropping_utils.py @@ -32,7 +32,7 @@ def gpt_sample_tokens(reserved_length: int, sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length) # Not certain the optimized kernel is actually better here, cause it kind of screws - # with alignment right if the sequence length is not divisble by like 16 + # with alignment right if the sequence length is not divisible by like 16 # new_mask = random_ltd_module.mask_gather_gpt(attn_mask, reserved_length) if attn_mask is not None: new_mask = attn_mask[:, :, :reserved_length, :reserved_length] diff --git a/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py b/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py index e25621bd0977..37f065e48631 100755 --- a/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py +++ b/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py @@ -8,7 +8,7 @@ class BertSparseSelfAttention(nn.Module): - """Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373 + """Implements Sparse Self Attention layer of Bert model based on https://github.com/deepspeedai/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373 For more information please see, TODO DeepSpeed Sparse Transformer. diff --git a/deepspeed/ops/sparse_attention/matmul.py b/deepspeed/ops/sparse_attention/matmul.py index b30028fffbaa..1c67ff9651df 100755 --- a/deepspeed/ops/sparse_attention/matmul.py +++ b/deepspeed/ops/sparse_attention/matmul.py @@ -128,18 +128,18 @@ def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stri inc_b = TK * stride_kb else: pinc += 2 - if meta['DSD']: - inc_b = tl.load(pinc) - inc_a = tl.load(pinc + 1) - inc_b = tl.multiple_of(inc_b, 8) - inc_a = tl.multiple_of(inc_a, 8) - inc_b = inc_b * stride_kb - if meta['DDS']: - inc_a = tl.load(pinc) - inc_b = tl.load(pinc + 1) - inc_a = tl.multiple_of(inc_a, 8) - inc_b = tl.multiple_of(inc_b, 8) - inc_a = inc_a * stride_ka + if meta['DSD']: + inc_b = tl.load(pinc) + inc_a = tl.load(pinc + 1) + inc_b = tl.multiple_of(inc_b, 8) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = inc_b * stride_kb + if meta['DDS']: + inc_a = tl.load(pinc) + inc_b = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.multiple_of(inc_b, 8) + inc_a = inc_a * stride_ka pa += inc_a pb += inc_b # pre-fetch diff --git a/deepspeed/ops/sparse_attention/sparsity_config.py b/deepspeed/ops/sparse_attention/sparsity_config.py index 7b7d5f50133c..b5d9be073bae 100644 --- a/deepspeed/ops/sparse_attention/sparsity_config.py +++ b/deepspeed/ops/sparse_attention/sparsity_config.py @@ -78,7 +78,7 @@ def __init__(self, num_heads, block=16, different_layout_per_head=False): super().__init__(num_heads, block, different_layout_per_head) def make_layout(self, seq_len): - """Set 1 to all blocks of the layout meanins the pattern is dense; not sparse. + """Set 1 to all blocks of the layout meaning the pattern is dense; not sparse. Arguments: seq_len: required: an integer determining the underling sequence length; must be <= max sequence length @@ -142,7 +142,7 @@ def __init__(self, if (num_different_global_patterns > 1 and not different_layout_per_head): raise ValueError( - f'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.' + 'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.' ) if (num_different_global_patterns > (num_local_blocks // num_global_blocks)): raise ValueError( @@ -702,7 +702,7 @@ def set_sliding_window_layout(self, h, layout): num_blocks = layout.shape[1] if (num_blocks < self.num_sliding_window_blocks): raise ValueError( - f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overal number of blocks in a row, {num_blocks}!' + f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than overall number of blocks in a row, {num_blocks}!' ) w = self.num_sliding_window_blocks // 2 diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py index 549a03a70f19..c0dd29f4f962 100644 --- a/deepspeed/ops/transformer/inference/config.py +++ b/deepspeed/ops/transformer/inference/config.py @@ -4,7 +4,8 @@ # DeepSpeed Team import json -from deepspeed.utils.types import ActivationFuncType +import torch +from deepspeed.utils.types import ActivationFuncType, NormType class TransformerConfig(): @@ -31,7 +32,6 @@ class DeepSpeedInferenceConfig(TransformerConfig): mp_size (optional): This argument is mainly used to create the parameters on the kernel side using model-parallel architecture. If the client model already takes care of this, there is no need to pass this argument. - fp16: Enable half-precision computation pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture stochastic_mode: Enable for high performance, please note that this flag has some level of non-determinism and can produce different results on different runs. However, we have seen @@ -42,6 +42,8 @@ class DeepSpeedInferenceConfig(TransformerConfig): scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation. return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture. + use_triton: This flag is to enable triton kernels in inference or not. + invert_mask: If True, the attention mask is inverted when passed to attention block. """ def __init__(self, @@ -52,9 +54,9 @@ def __init__(self, layer_norm_eps=1e-12, local_rank=-1, mp_size=1, - fp16=False, - q_int8=False, + dtype=torch.float16, pre_layer_norm=True, + norm_type=NormType.LayerNorm, stochastic_mode=False, scale_attention=True, triangular_masking=True, @@ -75,17 +77,22 @@ def __init__(self, scale_attn_by_inverse_layer_idx=False, return_single_tuple=False, set_empty_params=False, - transposed_mode=False): + transposed_mode=False, + use_triton=False, + triton_autotune=False, + num_kv=-1, + rope_theta=10000, + invert_mask=True): super(DeepSpeedInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, num_hidden_layers) - self.fp16 = fp16 + self.dtype = dtype self.pre_layer_norm = pre_layer_norm + self.norm_type = norm_type self.local_rank = local_rank self.stochastic_mode = stochastic_mode self.epsilon = layer_norm_eps self.mp_size = mp_size - self.q_int8 = q_int8 self.scale_attention = scale_attention self.triangular_masking = triangular_masking self.local_attention = local_attention @@ -96,7 +103,6 @@ def __init__(self, self.return_tuple = return_tuple self.mlp_after_attn = mlp_after_attn self.mlp_act_func_type = mlp_act_func_type - self.specialized_mode = False self.training_mp_size = training_mp_size self.bigscience_bloom = bigscience_bloom self.max_out_tokens = max_out_tokens @@ -107,6 +113,11 @@ def __init__(self, self.return_single_tuple = return_single_tuple self.set_empty_params = set_empty_params self.transposed_mode = transposed_mode + self.use_triton = use_triton + self.triton_autotune = triton_autotune + self.num_kv = num_kv + self.rope_theta = rope_theta + self.invert_mask = invert_mask @classmethod def from_dict(cls, json_object): diff --git a/deepspeed/ops/transformer/inference/diffusers_attention.py b/deepspeed/ops/transformer/inference/diffusers_attention.py index 63325e058e02..3c2340ccfc6f 100644 --- a/deepspeed/ops/transformer/inference/diffusers_attention.py +++ b/deepspeed/ops/transformer/inference/diffusers_attention.py @@ -10,10 +10,11 @@ from packaging import version as pkg_version from deepspeed.utils.logging import log_dist from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp +from deepspeed.ops.transformer.inference.op_binding.softmax_context import SoftmaxContextOp +from deepspeed.ops.transformer.inference.op_binding import LinearOp +from deepspeed.ops.transformer.inference.op_binding.pad_transform import PadTransformOp -# Cuda modules will be imported if needed -inference_cuda_module = None minus_inf = -10000.0 triton_flash_attn = None @@ -36,7 +37,8 @@ class DeepSpeedDiffusersAttentionFunction(Function): @staticmethod def forward(ctx, input, context, input_mask, config, attn_qkvw, attn_qw, attn_kw, attn_vw, attn_qkvb, num_attention_heads_per_partition, norm_factor, hidden_size_per_partition, attn_ow, attn_ob, - do_out_bias, score_context_func, linear_func, triton_flash_attn_kernel): + do_out_bias, score_context_func, linear_func, pad_transform_func, triton_flash_attn_kernel, + rope_theta): def _transpose_for_context(x): x = x.permute(0, 2, 1, 3) @@ -52,14 +54,14 @@ def _transpose_for_scores(x): return x.contiguous() def selfAttention_fp(input, context, input_mask): - if config.fp16 and input.dtype == torch.float32: + if config.dtype in [torch.half, torch.float16] and input.dtype == torch.float32: input = input.half() head_size = input.shape[-1] // config.heads do_flash_attn = (head_size <= 128) scale = (1 / norm_factor) * (1 / norm_factor) - if do_flash_attn and context == None: + if do_flash_attn and context is None: qkv_out = linear_func(input, attn_qkvw, attn_qkvb if attn_qkvb is not None else attn_qkvw, attn_qkvb - is not None, do_flash_attn, config.heads) + is not None, do_flash_attn, config.heads, False, rope_theta) context_layer = triton_flash_attn_kernel(qkv_out[0], qkv_out[1], qkv_out[2], scale, input.shape[-2] % 128 == 0) @@ -77,12 +79,11 @@ def selfAttention_fp(input, context, input_mask): query = query.contiguous() key = key.contiguous() value = value.contiguous() - query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads, - do_flash_attn) + query, key, value = pad_transform_func(query, key, value, config.heads, do_flash_attn) attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1) context_layer = _transpose_for_context(torch.matmul(attention_scores, value)) - output = linear_func(context_layer, attn_ow, attn_ob, do_out_bias, False, config.heads) + output = linear_func(context_layer, attn_ow, attn_ob, do_out_bias, False, config.heads, False, rope_theta) return output output = selfAttention_fp(input, context, input_mask) @@ -116,12 +117,8 @@ def __init__( device = get_accelerator().current_device_name() if config.bigscience_bloom else 'cpu' qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 - data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float - data_type_fp = torch.half if config.fp16 else torch.float - global inference_cuda_module - if inference_cuda_module is None: - builder = InferenceBuilder() - inference_cuda_module = builder.load() + data_type = self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype if DeepSpeedDiffusersAttention.layer_id == 1: log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0]) @@ -172,24 +169,24 @@ def __init__( self.norm_factor *= math.sqrt(self.config.layer_id + 1) # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191 - self.score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \ - inference_cuda_module.softmax_context_fp16 - self.linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ - inference_cuda_module.linear_layer_fp32 - self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if not (config.fp16) else \ - inference_cuda_module.allocate_workspace_fp16 + self.workspace = WorkspaceOp(self.config) + self.score_context_func = SoftmaxContextOp(self.config) + self.linear_func = LinearOp(self.config) + self.pad_transform_func = PadTransformOp(self.config) - def forward(self, input, context=None, input_mask=None): + def allocate_workspace(self, size): + # Allocate memory only on first layer forward if self.config.layer_id == 0: - self.allocate_workspace(self.config.hidden_size, self.config.heads, - input.size()[1], - input.size()[0], DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, - 0, self.config.max_out_tokens) - output = DeepSpeedDiffusersAttentionFunction.apply(input, context, input_mask, self.config, self.attn_qkvw, - self.attn_qw, self.attn_kw, self.attn_vw, self.attn_qkvb, - self.num_attention_heads_per_partition, self.norm_factor, - self.hidden_size_per_partition, self.attn_ow, self.attn_ob, - self.do_out_bias, self.score_context_func, self.linear_func, - self.triton_flash_attn_kernel) + self.workspace.allocate_workspace(self.config.hidden_size, self.config.heads, size[1], size[0], + DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, 0, + self.config.max_out_tokens, self.config.min_out_tokens) + + def forward(self, input, context=None, input_mask=None): + self.allocate_workspace(input.size()) + output = DeepSpeedDiffusersAttentionFunction.apply( + input, context, input_mask, self.config, self.attn_qkvw, self.attn_qw, self.attn_kw, self.attn_vw, + self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, self.hidden_size_per_partition, + self.attn_ow, self.attn_ob, self.do_out_bias, self.score_context_func, self.linear_func, + self.pad_transform_func, self.triton_flash_attn_kernel, self.config.rope_theta) return output diff --git a/deepspeed/ops/transformer/inference/diffusers_transformer_block.py b/deepspeed/ops/transformer/inference/diffusers_transformer_block.py index 3d45714e543c..59fa4f609dcf 100644 --- a/deepspeed/ops/transformer/inference/diffusers_transformer_block.py +++ b/deepspeed/ops/transformer/inference/diffusers_transformer_block.py @@ -10,25 +10,9 @@ from .diffusers_attention import DeepSpeedDiffusersAttention from .bias_add import nhwc_bias_add from .diffusers_2d_transformer import Diffusers2DTransformerConfig -from deepspeed.ops.op_builder import InferenceBuilder, SpatialInferenceBuilder - -# Ops will be loaded on demand -transformer_cuda_module = None -spatial_cuda_module = None - - -def load_transformer_module(): - global transformer_cuda_module - if transformer_cuda_module is None: - transformer_cuda_module = InferenceBuilder().load() - return transformer_cuda_module - - -def load_spatial_module(): - global spatial_cuda_module - if spatial_cuda_module is None: - spatial_cuda_module = SpatialInferenceBuilder().load() - return spatial_cuda_module +from deepspeed.utils.types import ActivationFuncType +from .op_binding.gated_activation import GatedActivationOp +from .op_binding.layer_norm import LayerNormOp class DeepSpeedDiffusersTransformerBlock(nn.Module): @@ -73,10 +57,10 @@ def __init__(self, equivalent_module: nn.Module, config: Diffusers2DTransformerC self.attn_2.do_out_bias = False self.attn_2_bias = self.attn_2.attn_ob else: - self.attn_2_bias = nn.Paramaeter(torch.zeros_like(self.norm3_g), requires_grad=False) + self.attn_2_bias = nn.Parameter(torch.zeros_like(self.norm3_g), requires_grad=False) - self.transformer_cuda_module = load_transformer_module() - load_spatial_module() + self.gated_activation = GatedActivationOp() + self.layer_norm = LayerNormOp() def forward(self, hidden_states, context=None, timestep=None, **kwargs): # In v0.12.0 of diffuser, several new kwargs were added. Capturing @@ -84,20 +68,20 @@ def forward(self, hidden_states, context=None, timestep=None, **kwargs): # In v0.11.0 of diffusers, the kwarg was changed from 'context' to 'encoder_hidden_states' # This is so we can support older and newer versions of diffusers - if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] != None: + if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] is not None: context = kwargs["encoder_hidden_states"] - out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps) + out_norm_1 = self.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps) out_attn_1 = self.attn_1(out_norm_1) - out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res( + out_norm_2, out_attn_1 = self.layer_norm.layer_norm_residual_store_pre_ln_res( out_attn_1, self.attn_1_bias, hidden_states, self.norm2_g, self.norm2_b, self.norm2_eps) out_attn_2 = self.attn_2(out_norm_2, context=context) - out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res( + out_norm_3, out_attn_2 = self.layer_norm.layer_norm_residual_store_pre_ln_res( out_attn_2, self.attn_2_bias, out_attn_1, self.norm3_g, self.norm3_b, self.norm3_eps) out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w) - out_geglu = self.transformer_cuda_module.bias_geglu(out_ff1, self.ff1_b) + out_geglu = self.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU) out_ff2 = nn.functional.linear(out_geglu, self.ff2_w) return nhwc_bias_add(out_ff2, self.ff2_b, other=out_attn_2) diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py index 46c36d337428..24f710d22494 100644 --- a/deepspeed/ops/transformer/inference/ds_attention.py +++ b/deepspeed/ops/transformer/inference/ds_attention.py @@ -20,8 +20,8 @@ class DeepSpeedSelfAttention(nn.Module): def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1): super(DeepSpeedSelfAttention, self).__init__() self.config = config - data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float - data_type_fp = torch.half if config.fp16 else torch.float + data_type = self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu' @@ -37,7 +37,8 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count self.attn_ow = None self.attn_ob = None else: - qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 if config.num_kv < 0 else \ + ((self.config.heads + self.config.num_kv * 2) // self.config.mp_size) * (self.config.hidden_size // self.config.heads) self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, qkv_size_per_partition, dtype=data_type, @@ -56,6 +57,7 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count requires_grad=False) self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size + self.num_kv_partition = self.config.num_kv // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads @@ -87,11 +89,11 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device) ] - def compute_attention(self, qkv_out, input_mask, layer_past, alibi): - if isinstance(qkv_out, list): + def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids): + if isinstance(qkv_out, list) or isinstance(qkv_out, tuple): qkv_out = qkv_out[0] - no_masking = input_mask is None + no_masking = input_mask is None or input_mask is False if no_masking: input_mask = torch.empty(1) @@ -101,25 +103,29 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi): attn_mask=((1 - input_mask).to(qkv_out.dtype) * minus_inf) if input_mask.dtype == torch.int64 else input_mask, heads=self.num_attention_heads_per_partition, + num_kv=self.num_kv_partition, norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0), no_masking=no_masking, layer_id=self.config.layer_id, num_layers=DeepSpeedSelfAttention.num_layers, - alibi=alibi) + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) context_layer, key_layer, value_layer = attn_key_value return context_layer, key_layer, value_layer def _merge_qkv(self): qvkw = DeepSpeedSelfAttention._qkv_buffers[0] - qvkw[:self.hidden_size_per_partition, :] = self.attn_qw - qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw - qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw + qvkw[:self.hidden_size_per_partition, :] = self.attn_qw # type: ignore + qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw # type: ignore + qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw # type: ignore if self.attn_qb is not None: qvkb = DeepSpeedSelfAttention._qkv_buffers[1] qvkb[:self.hidden_size_per_partition] = self.attn_qb - qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb - qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb + qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb # type: ignore + qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb # type: ignore return DeepSpeedSelfAttention._qkv_buffers def forward(self, @@ -133,13 +139,13 @@ def forward(self, output_attentions=False, norm_w=None, norm_b=None, - alibi=None): + alibi=None, + **kwargs): if self.attn_qkvw is None: self._attn_qkvw, self._attn_qkvb = self._merge_qkv() else: self._attn_qkvw = self.attn_qkvw self._attn_qkvb = self.attn_qkvb - if not self.config.pre_layer_norm: qkv_out = self.linear_func(input=input, weight=self._attn_qkvw, @@ -151,23 +157,27 @@ def forward(self, else: qkv_out = self.qkv_func(input=input, weight=self._attn_qkvw, - bias=(self._attn_qkvb if self._attn_qkvb is not None else norm_b), + bias=self._attn_qkvb, gamma=norm_w, - beta=norm_b, - add_bias=(self.attn_qkvb is not None), - num_layers=DeepSpeedSelfAttention.num_layers, - num_heads=self.num_attention_heads_per_partition) + beta=norm_b) + + is_prompt = kwargs.get("first_token", qkv_out[0].shape[1] > 1) + token_idx = kwargs.get("token_idx", None) + position_ids = kwargs.get("position_ids", None) + context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out, input_mask=input_mask, layer_past=layer_past, - alibi=alibi) - output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) + output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) inp_norm = qkv_out[-1] if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: dist.all_reduce(output, group=self.mp_group) - return (output, key_layer, value_layer, context_layer, inp_norm) @@ -211,8 +221,8 @@ def _split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_ return tensor_list - def compute_attention(self, qkv_out, input_mask, layer_past, alibi): - if isinstance(qkv_out, list): + def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids): + if isinstance(qkv_out, list) or isinstance(qkv_out, tuple): qkv_out = qkv_out[0] no_masking = input_mask is None @@ -249,8 +259,18 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi): attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], -1) offset = dist.get_rank() * self.num_attention_heads_per_partition if dist.is_initialized() else 0 + target_dtype = torch.float16 if self.config.dtype == torch.int8 else self.config.dtype + + # When using the hybrid engine with BLOOM, input_mask needs to be converted from torch.bool -> torch.int64 + if input_mask.dtype == torch.bool: + input_mask = input_mask.long() + + # Invert input_mask per transformer implementation (eg, in BLOOM, it's already inverted) + if self.config.invert_mask: + input_mask = 1 - input_mask + attention_probs = self.softmax_func(attn_scores=attention_scores, - attn_mask=((1 - input_mask).half() * minus_inf), + attn_mask=input_mask.to(target_dtype) * minus_inf, alibi=alibi, triangular=(self.config.triangular_masking and (attention_scores.shape[-2] > 1)), diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py index a4375178347a..36de06db920f 100644 --- a/deepspeed/ops/transformer/inference/ds_mlp.py +++ b/deepspeed/ops/transformer/inference/ds_mlp.py @@ -7,24 +7,37 @@ import torch import torch.nn as nn from deepspeed import comm as dist +from deepspeed.utils.types import GATED_ACTIVATION_TYPES from deepspeed.accelerator import get_accelerator from .op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp class DeepSpeedMLP(nn.Module): + _inter_w_buffers = [] def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False): super(DeepSpeedMLP, self).__init__() self.config = config - data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float - data_type_fp = torch.half if config.fp16 else torch.float + + data_type = torch.int8 if self.config.dtype == torch.int8 else self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype device = get_accelerator().current_device_name() + + proj_factor = 2 if self.config.mlp_act_func_type in GATED_ACTIVATION_TYPES else 1 + self.config.intermediate_size = self.config.intermediate_size if self.config.intermediate_size > 0 else 4 * self.config.hidden_size + self.intm_w_sz_per_partition = self.config.intermediate_size * proj_factor // self.config.mp_size + self.intm_o_sz_per_partition = self.config.intermediate_size // self.config.mp_size + if self.config.set_empty_params: self.attn_nw = None self.attn_nb = None self.inter_w = None self.inter_b = None + self.inter_up_w = None + self.inter_up_b = None + self.inter_gate_w = None + self.inter_gate_b = None self.output_w = None self.output_b = None else: @@ -32,15 +45,15 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count requires_grad=False) self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), requires_grad=False) - intm_size_per_partition = self.config.intermediate_size // self.config.mp_size + self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, - intm_size_per_partition, + self.intm_w_sz_per_partition, dtype=data_type, device=device), requires_grad=False) - self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), + self.inter_b = nn.Parameter(torch.empty(self.intm_w_sz_per_partition, dtype=data_type_fp, device=device), requires_grad=False) - self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, + self.output_w = nn.Parameter(torch.empty(self.intm_o_sz_per_partition, self.config.hidden_size, dtype=data_type, device=device), @@ -59,29 +72,53 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count self.fused_gemm_gelu = GELUGemmOp(config) self.residual_add_func = ResidualAddOp(config) + if len(DeepSpeedMLP._inter_w_buffers) == 0: + DeepSpeedMLP._inter_w_buffers = [ + torch.empty(self.intm_w_sz_per_partition, self.config.hidden_size, dtype=data_type, device=device), + torch.empty(self.intm_w_sz_per_partition, dtype=data_type_fp, device=device) + ] + + def _merge_inter_w(self): + inter_w = DeepSpeedMLP._inter_w_buffers[0] + inter_w[:self.intm_w_sz_per_partition // 2, :] = self.inter_up_w # type: ignore + inter_w[self.intm_w_sz_per_partition // 2:, :] = self.inter_gate_w # type: ignore + if self.inter_up_b is not None: + inter_b = DeepSpeedMLP._inter_w_buffers[1] + inter_b[:self.intm_w_sz_per_partition // 2] = self.inter_up_b # type: ignore + inter_b[self.intm_w_sz_per_partition // 2:] = self.inter_gate_b # type: ignore + return DeepSpeedMLP._inter_w_buffers + def forward(self, input, residual, residual_norm, bias): + if self.inter_w is None: + self._inter_w, self._inter_b = self._merge_inter_w() + else: + self._inter_w = self.inter_w + self._inter_b = self.inter_b + residual_add = None if self.attn_nw is None: output = self.fused_gemm_gelu(input=residual_norm, - weight=self.inter_w, - bias=self.inter_b, + weight=self._inter_w, + bias=self._inter_b, weight_out=self.output_w) else: output, residual_add = self.mlp_gemm_func(input=input, residual=residual, - input_bias=bias, - weight_interm=self.inter_w, + weight_interm=self._inter_w, weight_out=self.output_w, - bias=self.inter_b, + input_bias=bias, + bias=self._inter_b, gamma=self.attn_nw, beta=self.attn_nb) + residual = self.residual_add_func(hidden_state=output, residual=residual, + add_bias=bias is not None, attention_output=input, attention_bias=bias if bias is not None else self.output_b, final_bias=self.output_b, - add_bias=bias is not None, residual_add=residual_add) if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: dist.all_reduce(residual, group=self.mp_group) + return residual diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py index bf14a5fc36b2..3a9785985d19 100644 --- a/deepspeed/ops/transformer/inference/moe_inference.py +++ b/deepspeed/ops/transformer/inference/moe_inference.py @@ -7,17 +7,16 @@ import math import torch from torch.autograd import Function -#from ...inference.engine import inference_cuda_module, specialized_mode -# Cuda modules will be imported if needed -inference_cuda_module = None -specialized_mode = None import torch.nn as nn from .ds_attention import DeepSpeedSelfAttention from .config import DeepSpeedInferenceConfig +from .op_binding import SoftmaxOp, VectorMatMulOp, GELUGemmOp +from .op_binding.bias_residual import BiasResidualOp +from .op_binding.einsum_sec_sm_ecm import EinsumSecSmEcmOp +from .op_binding.layer_norm import LayerNormOp from ....moe.sharded_moe import TopKGate from deepspeed import comm as dist -from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import InferenceBuilder +from .op_binding.moe_res_matmul import MoEResMatmulOp class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): @@ -35,6 +34,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): using model-parallel architecture. If the client model already takes care of this, there is no need to pass this argument. fp16: Enable half-precision computation + bf16: Enable bf16 floating point computation pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture stochastic_mode: Enable for high performance, please note that this flag has some level of non-determinism and can produce different results on different runs. However, we have seen @@ -55,6 +55,7 @@ def __init__(self, local_rank=-1, mp_size=1, fp16=False, + bf16=False, q_int8=False, pre_layer_norm=True, stochastic_mode=False, @@ -76,9 +77,9 @@ def __init__(self, scale_attn_by_inverse_layer_idx=False): super(DeepSpeedMoEInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, - num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, q_int8, pre_layer_norm, - stochastic_mode, scale_attention, triangular_masking, local_attention, window_size, - return_tuple) + num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, bf16, q_int8, + pre_layer_norm, stochastic_mode, scale_attention, triangular_masking, local_attention, + window_size, return_tuple) self.moe_experts = moe_experts self.k = k self.capacity_factor = capacity_factor @@ -109,18 +110,13 @@ class DeepSpeedMLPFunction(Function): @staticmethod def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups, merge_count, mp_group, - async_op): + async_op, gelu_gemm_func, vector_matmul_func): if config.q_int8: - intermediate = inference_cuda_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon, - q_scales[2], (q_groups * (2**merge_count)), - config.pre_layer_norm) - output = inference_cuda_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups, - (merge_count)) + intermediate = gelu_gemm_func(input, inter_w, inter_b, config.epsilon, q_scales[2], + (q_groups * (2**merge_count)), config.pre_layer_norm) + output = vector_matmul_func(intermediate, output_w, q_scales[3], q_groups, (merge_count)) else: - mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \ - inference_cuda_module.fused_gemm_gelu_fp32 - - output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op) + output = gelu_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op) if mp_group is not None and dist.get_world_size(group=mp_group) > 1: dist.all_reduce(output, group=mp_group, async_op=async_op) @@ -151,10 +147,13 @@ def __init__(self, config, q_scales=None, q_groups=1, merge_count=1, mlp_extra_g self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups self.merge_count = int(math.log2(merge_count)) self.mp_group = mp_group + self.gelu_gemm_func = GELUGemmOp(self.config) + self.vector_matmul_func = VectorMatMulOp(self.config) def forward(self, input, async_op=False): return DeepSpeedMLPFunction.apply(input, self.inter_w, self.inter_b, self.config, self.output_b, self.output_w, - self.q_scales, self.q_groups, self.merge_count, self.mp_group, async_op) + self.q_scales, self.q_groups, self.merge_count, self.mp_group, async_op, + self.gelu_gemm_func, self.vector_matmul_func) class DeepSpeedMoEInference(nn.Module): @@ -188,18 +187,8 @@ def __init__(self, self.config = config self.config.layer_id = DeepSpeedMoEInference.layer_id - global inference_cuda_module - global specialized_mode - if inference_cuda_module is None: - specialized_mode = False - # InferenceSpecializedBuilder is not among DeepSpeed provided builder yet, so we infer by builder name string - builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder") - if builder != None and builder.is_compatible(): - inference_cuda_module = builder.load() - specialized_mode = True - else: - inference_cuda_module = InferenceBuilder().load() - self.config.specialized_mode = specialized_mode + + assert self.config.dtype != torch.bfloat16, "DeepSpeed MoE Transformer Inference not yet tested for bfloat support" DeepSpeedMoEInference.layer_id += 1 self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count) @@ -213,10 +202,8 @@ def __init__(self, self.res_mlp = DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping, mp_group) self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2)) - self.coef_func = inference_cuda_module.softmax_fp16 if self.config.fp16 or self.config.q_int8 else \ - inference_cuda_module.softmax_fp32 - self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ - inference_cuda_module.vector_matmul_fp32 + self.coef_func = SoftmaxOp(self.config) + self.vector_matmul_func = VectorMatMulOp(self.config) config.mp_size = 1 self.mlp = nn.ModuleList( @@ -226,7 +213,7 @@ def __init__(self, self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k, self.config.capacity_factor, self.config.eval_capacity_factor, self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens, - self.config.use_rts) + self.config.use_rts, self.ep_group) self.ep_group = ep_group self.mp_group = mp_group @@ -234,12 +221,10 @@ def __init__(self, print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__) - self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \ - inference_cuda_module.bias_residual_fp32 - self.ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \ - inference_cuda_module.layer_norm_fp32 - self.einsum_sec_sm_ecm = inference_cuda_module.einsum_sec_sm_ecm_fp16 if self.config.fp16 or self.config.q_int8 else \ - inference_cuda_module.einsum_sec_sm_ecm_fp32 + self.bias_residual_func = BiasResidualOp(self.config) + self.ds_layernorm = LayerNormOp(self.config) + self.einsum_sec_sm_ecm = EinsumSecSmEcmOp(self.config) + self.moe_res_matmul = MoEResMatmulOp(self.config) def res_coef_func(self, inp, async_op): inp = self.vector_matmul_func(inp, self.res_coef, async_op) @@ -302,8 +287,7 @@ def forward(self, input_mask = input_mask if attention_mask is None else attention_mask input_type = input.dtype - if (self.config.fp16 or self.config.q_int8) \ - and input.dtype == torch.float: + if (self.config.dtype in [torch.float16, torch.int8]) and input_type == torch.float: input = input.half() with torch.no_grad(): @@ -327,12 +311,12 @@ def forward(self, res_coef_out = self.res_coef_func(attention_output, async_op=True) if self.expert_mp_group is not None: - tensor_list = [ - torch.empty_like(attention_output) for _ in range(dist.get_world_size(group=self.expert_mp_group)) - ] - tensor_list[dist.get_rank(group=self.expert_mp_group)] = attention_output - dist.all_gather(tensor_list, attention_output, group=self.expert_mp_group) - attention_output = torch.cat(tensor_list).contiguous() + world_size = dist.get_world_size(group=self.expert_mp_group) + gather_buffer = torch.empty(world_size * attention_output.numel(), + dtype=attention_output.dtype, + device=attention_output.device) + dist.all_gather_into_tensor(gather_buffer, attention_output, group=self.expert_mp_group) + attention_output = gather_buffer.view(-1, *attention_output.size()[1:]) ############## MoE Gating + Experts ############### dispatched_attention, combined_weights = self.moe_gate_einsum(attention_output) @@ -347,7 +331,7 @@ def forward(self, dim=0)[dist.get_rank(group=self.expert_mp_group)] if self.config.mlp_type == 'residual': - inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out, output) + self.moe_res_matmul(res_mlp_out, res_coef_out, output) output = self.bias_residual_func(output, residual_add, torch.empty(1)) diff --git a/deepspeed/ops/transformer/inference/op_binding/base.py b/deepspeed/ops/transformer/inference/op_binding/base.py index 1bdfdeeb5fb1..5a997f95d5cc 100644 --- a/deepspeed/ops/transformer/inference/op_binding/base.py +++ b/deepspeed/ops/transformer/inference/op_binding/base.py @@ -10,11 +10,11 @@ class BaseOp(torch.nn.Module): - inference_cuda_module = None + inference_module = None def __init__(self, config: DeepSpeedInferenceConfig): super(BaseOp, self).__init__() self.config = config - if BaseOp.inference_cuda_module is None: + if BaseOp.inference_module is None: builder = InferenceBuilder() - BaseOp.inference_cuda_module = builder.load() + BaseOp.inference_module = builder.load() diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_add.py b/deepspeed/ops/transformer/inference/op_binding/bias_add.py new file mode 100644 index 000000000000..d2ae38f546eb --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_add.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasAddOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasAddOp, self).__init__(config) + + try: + if self.config.dtype == torch.float16: + self.bias_add_func = self.inference_module.bias_add_fp16 + elif self.config.dtype == torch.bfloat16: + self.bias_add_func = self.inference_module.bias_add_bf16 + else: + self.bias_add_func = self.inference_module.bias_add_fp32 + except AttributeError: + self.bias_add_func = self.bias_add_fallback + + @classmethod + def bias_add_fallback(cls, input, bias): + return torch.add(input, bias) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor): + return self.bias_add_func(activation, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_gelu.py b/deepspeed/ops/transformer/inference/op_binding/bias_gelu.py new file mode 100644 index 000000000000..f0fee0b0d06e --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_gelu.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasGeluOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasGeluOp, self).__init__(config) + + try: + if self.config.dtype == torch.float16: + self.bias_gelu_func = self.inference_module.bias_gelu_fp16 + elif self.config.dtype == torch.bfloat16: + self.bias_gelu_func = self.inference_module.bias_gelu_bf16 + else: + self.bias_gelu_func = self.inference_module.bias_gelu_fp32 + except AttributeError: + self.bias_gelu_func = self.bias_gelu_fallback + + @classmethod + def bias_gelu_fallback(cls, activations, bias): + # Expected behavior is that of casting to float32 internally and using the tanh approximation + return F.gelu(activations.to(torch.float32) + bias.to(torch.float32), approximate='tanh').to(activations.dtype) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor): + return self.bias_gelu_func(activation, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_relu.py b/deepspeed/ops/transformer/inference/op_binding/bias_relu.py new file mode 100644 index 000000000000..ccfade1d9524 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_relu.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasReluOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasReluOp, self).__init__(config) + + try: + if self.config.dtype == torch.float16: + self.bias_relu_func = self.inference_module.bias_relu_fp16 + elif self.config.dtype == torch.bfloat16: + self.bias_relu_func = self.inference_module.bias_relu_bf16 + else: + self.bias_relu_func = self.inference_module.bias_relu_fp32 + except AttributeError: + self.bias_relu_func = self.bias_relu_fallback + + @classmethod + def bias_relu_fallback(cls, activations, bias): + # Expected behavior is that of casting to float32 internally + return F.relu(activations.to(torch.float32) + bias.to(torch.float32)).to(activations.dtype) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor): + return self.bias_relu_func(activation, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_residual.py b/deepspeed/ops/transformer/inference/op_binding/bias_residual.py new file mode 100644 index 000000000000..ecad50e10ffe --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_residual.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasResidualOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasResidualOp, self).__init__(config) + + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.bias_residual_func = self.inference_module.bias_residual_fp16 + else: + self.bias_residual_func = self.inference_module.bias_residual_fp32 + except AttributeError: + self.bias_residual_func = self.bias_residual_fallback + + @classmethod + def bias_residual_fallback(cls, output, residual, bias): + raise NotImplementedError("bias residual fallback isn't implemented") + + def forward(self, output, residual, bias): + return self.bias_residual_func(output, residual, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/einsum_sec_sm_ecm.py b/deepspeed/ops/transformer/inference/op_binding/einsum_sec_sm_ecm.py new file mode 100644 index 000000000000..f34b10f786d1 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/einsum_sec_sm_ecm.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class EinsumSecSmEcmOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(EinsumSecSmEcmOp, self).__init__(config) + + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.einsum_sec_sm_ecm_func = self.inference_module.einsum_sec_sm_ecm_fp16 + else: + self.einsum_sec_sm_ecm_func = self.inference_module.einsum_sec_sm_ecm_fp32 + except AttributeError: + self.einsum_sec_sm_ecm_func = self.einsum_sec_sm_ecm_fallback + + @classmethod + def einsum_sec_sm_ecm_fallback(cls, Q, W): + raise NotImplementedError("einsum sec sm ecm fallback isn't implemented") + + def forward(self, Q, W): + return self.einsum_sec_sm_ecm_func(Q, W) diff --git a/deepspeed/ops/transformer/inference/op_binding/gated_activation.py b/deepspeed/ops/transformer/inference/op_binding/gated_activation.py new file mode 100644 index 000000000000..d28d818ce4b3 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/gated_activation.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from deepspeed.utils.types import ActivationFuncType +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class GatedActivationOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(GatedActivationOp, self).__init__(config) + try: + self.gated_activation_func = self.inference_module.gated_activation + except AttributeError: + self.gated_activation_func = self.gated_activation_fallback + + @classmethod + def gated_activation_fallback(cls, activation, bias, activation_func_type): + # Expected behavior is that of casting to float32 internally + # Explicitly using the default GeLU + activation_func = None + activations = activation + bias.reshape(1, 1, -1) + hidden_states, gate = activations.chunk(2, dim=-1) + + if activation_func_type == ActivationFuncType.GATED_SILU: + activation_func = F.silu + elif activation_func_type == ActivationFuncType.GATED_GELU: + activation_func = F.gelu + + return hidden_states * activation_func(gate.to(torch.float32)).to(activations.dtype) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor, activation_func_type: ActivationFuncType): + return self.gated_activation_func(activation, bias, activation_func_type) diff --git a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py index 89ef0b517c49..60bbb4b48bdb 100644 --- a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py @@ -4,26 +4,49 @@ # DeepSpeed Team import torch +import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp +import deepspeed class GELUGemmOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(GELUGemmOp, self).__init__(config) - if self.config.fp16: - self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_fp16 - else: - self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_fp32 - - def forward(self, - input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - weight_out: torch.Tensor, - async_op: bool = False): - output = self.fused_gemm_gelu(input, weight, weight.scale, bias, weight_out, weight_out.scale, - self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op, - self.config.transposed_mode) + try: + if self.config.dtype == torch.int8: + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_int8 + elif self.config.dtype == torch.float16: + if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: + from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu + self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore + else: + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore + elif self.config.dtype == torch.bfloat16: + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore + else: + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore + except AttributeError: + self.fused_gemm_gelu = self.gelu_gemm_fallback + + def gelu_gemm_fallback(self, input, weight, scale, bias, out, out_scale, dtype, transpose): + tmp = torch.matmul(input, weight) + tmp = F.gelu(tmp.to(torch.float32) + bias.to(torch.float32), approximate="tanh").to(tmp.dtype) + output = torch.matmul(tmp, out) + + return output + + def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, weight_out: torch.Tensor): + + output = self.fused_gemm_gelu( + input, + weight, + weight.scale if hasattr(weight, 'scale') else torch.empty(1), # type: ignore + bias, + weight_out, + weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), # type: ignore + self.config.dtype == torch.int8, + self.config.transposed_mode) + return output diff --git a/deepspeed/ops/transformer/inference/op_binding/layer_norm.py b/deepspeed/ops/transformer/inference/op_binding/layer_norm.py new file mode 100644 index 000000000000..31219a58ac3c --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/layer_norm.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class LayerNormOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + super(LayerNormOp, self).__init__(config) + try: + if config is None: + self.layer_norm_func = self.inference_module.layer_norm + elif self.config.dtype in [torch.float16, torch.int8]: + self.layer_norm_func = self.inference_module.layer_norm_fp16 + else: + self.layer_norm_func = self.inference_module.layer_norm_fp32 + except AttributeError: + self.layer_norm_func = self.layer_norm_fallback + + @classmethod + def layer_norm_residual(cls, vals, bias, res, gamma, beta, epsilon): + channels = gamma.shape[0] + dtype = gamma.dtype + vals_f = vals.to(torch.float32) + bias_f = bias.to(torch.float32).reshape(1, 1, -1) + res_f = res.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return F.layer_norm(vals_f + bias_f + res_f, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) + + @classmethod + def layer_norm_residual_store_pre_ln_res(cls, vals, bias, res, gamma, beta, epsilon): + channels = gamma.shape[0] + dtype = gamma.dtype + vals_f = vals.to(torch.float32) + bias_f = bias.to(torch.float32).reshape(1, 1, -1) + res_f = res.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + res_output = vals_f + bias_f + res_f + norm_output = F.layer_norm(res_output, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) + return norm_output, res_output.to(dtype) + + @classmethod + def layer_norm_fallback(cls, vals, gamma, beta, epsilon): + channels = gamma.shape[0] + dtype = gamma.dtype + vals_f = vals.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return F.layer_norm(vals_f, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) + + def forward(self, vals, gamma, beta, epsilon): + return self.layer_norm_func(vals, gamma, beta, epsilon) diff --git a/deepspeed/ops/transformer/inference/op_binding/linear.py b/deepspeed/ops/transformer/inference/op_binding/linear.py index 9178c5f1fc5b..b8decb6dc5ea 100644 --- a/deepspeed/ops/transformer/inference/op_binding/linear.py +++ b/deepspeed/ops/transformer/inference/op_binding/linear.py @@ -6,16 +6,33 @@ import torch from ..config import DeepSpeedInferenceConfig from .base import BaseOp +import deepspeed class LinearOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(LinearOp, self).__init__(config) - if self.config.fp16: - self.linear_func = self.inference_cuda_module.linear_layer_fp16 - else: - self.linear_func = self.inference_cuda_module.linear_layer_fp32 + try: + if self.config.dtype in [torch.float16, torch.int8]: + if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: + from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func + self.linear_func = _triton_linear_func + triton_autotune = config.triton_autotune and config.layer_id == 0 + if triton_autotune: + __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size) + else: + self.linear_func = self.inference_module.linear_layer_fp16 + self.linear_func = self.inference_module.linear_layer_fp16 + elif self.config.dtype == torch.bfloat16: + self.linear_func = self.inference_module.linear_layer_bf16 + else: + self.linear_func = self.inference_module.linear_layer_fp32 + except AttributeError: + self.linear_func = self.linear_fallback + + def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose, rope_theta): + raise NotImplementedError def forward(self, input: torch.Tensor, @@ -27,5 +44,17 @@ def forward(self, external_cache: bool = None, num_layers: int = None): qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, - self.config.transposed_mode) + self.config.transposed_mode, self.config.rope_theta) return qkv_out + + @staticmethod + def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + A = torch.randn((N, hidden_size), dtype=dtype, device='cuda') + B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda') + matmul(A, B) + Fp16Matmul._update_autotune_table() diff --git a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py index e7ca40219c34..5f1f915ec021 100644 --- a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py @@ -3,27 +3,122 @@ # DeepSpeed Team +from typing import Optional + import torch +import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp +from deepspeed.utils.types import NormType +from .pre_rms_norm import PreRMSNormOp class MLPGemmOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(MLPGemmOp, self).__init__(config) - if self.config.fp16: - self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_fp16 + try: + if self.config.norm_type == NormType.LayerNorm: + if self.config.dtype in [ + torch.float16, torch.int8 + ]: # non-triton cuda kernel has a higher performance in MLP than mlp_gemm_func in triton.ops + self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore + elif self.config.dtype == torch.bfloat16: + self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16 + else: + self.mlp_gemm_func = self.inference_module.mlp_gemm_fp32 # type: ignore + elif self.config.norm_type == NormType.RMSNorm: + if self.config.dtype in [torch.float16, torch.int8]: + self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_fp16 # type: ignore + elif self.config.dtype == torch.bfloat16: + self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_bf16 + else: + self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_fp32 # type: ignore + except AttributeError: + if self.config.norm_type == NormType.LayerNorm: + self.mlp_gemm_func = self.mlp_gemm_fallback + elif self.config.norm_type == NormType.RMSNorm: + self.mlp_gemm_func = self.rms_mlp_gemm_fallback + self.pre_rms_norm = PreRMSNormOp() + + def mlp_gemm_fallback(self, input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, + pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type, + transpose): + if mlp_after_attn: + residual_add = F.layer_norm(input + residual + input_bias, (input.shape[2], ), gamma, beta, eps) + tmp = torch.matmul(residual_add, weight_interm.t() if transpose else weight_interm) + tmp = F.gelu(tmp + bias) + output = torch.matmul(tmp, weight_out.t() if transpose else weight_out) + + return output, residual_add + else: + raise NotImplementedError + + def rms_mlp_gemm_fallback(self, input, residual, weight_interm, weight_out, gamma, eps, interm_scale, out_scale, + dtype, mlp_act_func_type, transpose): + inp_norm, residual = self.pre_rms_norm(input, residual, gamma, eps) + tmp = torch.matmul(inp_norm.view([-1, inp_norm.size(2)]), weight_interm.t() if transpose else weight_interm) + up_proj, gate_proj = tmp.chunk(2, dim=1) + + from deepspeed.utils.types import ActivationFuncType + if mlp_act_func_type == ActivationFuncType.GELU: + intermediate = F.gelu(gate_proj) + elif mlp_act_func_type == ActivationFuncType.ReLU: + intermediate = F.relu(gate_proj) + elif mlp_act_func_type == ActivationFuncType.GATED_GELU: + intermediate = F.gelu(gate_proj) + elif mlp_act_func_type == ActivationFuncType.GATED_SILU: + intermediate = F.silu(gate_proj) + else: + raise f"rms_mlp_gemm_fallback not implemented for activation type {mlp_act_func_type}" + + intermediate = intermediate * up_proj + + output = torch.matmul(intermediate, weight_out.t() if transpose else weight_out) + output = output.view([input.size(0), input.size(1), -1]) + + return [output, residual] + + def forward(self, + input: torch.Tensor, + residual: torch.Tensor, + weight_interm: torch.Tensor, + weight_out: torch.Tensor, + input_bias: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + gamma: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None): + if self.config.norm_type == NormType.LayerNorm: + output, residual_add = self.mlp_gemm_func( + input, + residual, + input_bias, + weight_interm, + weight_out, + bias, + gamma, + beta, + self.config.epsilon, + self.config.pre_layer_norm, + self.config.mlp_after_attn, + weight_interm.scale if hasattr(weight_interm, 'scale') else torch.empty(1), # type: ignore + weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), # type: ignore + self.config.dtype == torch.int8, + self.config.mlp_act_func_type, + self.config.transposed_mode) else: - self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_fp32 - - def forward(self, input: torch.Tensor, residual: torch.Tensor, input_bias: torch.Tensor, - weight_interm: torch.Tensor, weight_out: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor, - beta: torch.Tensor): - output, residual_add = self.mlp_gemm_func( - input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, self.config.epsilon, - self.config.pre_layer_norm, self.config.mlp_after_attn, - weight_interm.scale if hasattr(weight_interm, 'scale') else torch.empty(1), - weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), self.config.q_int8, - self.config.mlp_act_func_type, self.config.transposed_mode) + if input_bias is not None: + input += input_bias + output, residual_add = self.mlp_gemm_func( + input, + residual, + weight_interm, + weight_out, + gamma, + self.config.epsilon, + weight_interm.scale if hasattr(weight_interm, 'scale') else torch.empty(1), # type: ignore + weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), # type: ignore + self.config.dtype == torch.int8, + self.config.mlp_act_func_type, + self.config.transposed_mode) return output, residual_add diff --git a/deepspeed/ops/transformer/inference/op_binding/moe_res_matmul.py b/deepspeed/ops/transformer/inference/op_binding/moe_res_matmul.py new file mode 100644 index 000000000000..ef3558c8bc88 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/moe_res_matmul.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class MoEResMatmulOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(MoEResMatmulOp, self).__init__(config) + try: + self.moe_res_matmul_func = self.inference_module.moe_res_matmul + except AttributeError: + self.moe_res_matmul_func = self.moe_res_matmul_fallback + + @classmethod + def moe_res_matmul_fallback(cls, residual, coef, output): + coef_t = coef.transpose(1, 2).contiguous() + coef1, coef2 = torch.split(coef_t, split_size_or_sections=coef_t.shape[len(coef_t.shape) - 1] // 2, dim=-1) + return residual * coef1 + output * coef2 + + def forward(self, residual, coef, output): + return self.moe_res_matmul_func(residual, coef, output) diff --git a/deepspeed/ops/transformer/inference/op_binding/pad_transform.py b/deepspeed/ops/transformer/inference/op_binding/pad_transform.py new file mode 100644 index 000000000000..876fefc3bcfb --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/pad_transform.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class PadTransformOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(PadTransformOp, self).__init__(config) + try: + self.pad_transform_func = self.inference_module.pad_transform_fp16 + except AttributeError: + self.pad_transform_func = self.pad_transform_fallback + + @staticmethod + def pad_transform_fallback(query, key, value, heads, do_flash_attn): + raise NotImplementedError("pad_transform fallback is not implemented.") + + def forward(self, query, key, value, heads, do_flash_attn): + return self.pad_transform_func(query, key, value, heads, do_flash_attn) diff --git a/deepspeed/ops/transformer/inference/op_binding/pre_rms_norm.py b/deepspeed/ops/transformer/inference/op_binding/pre_rms_norm.py new file mode 100644 index 000000000000..7969d20f0527 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/pre_rms_norm.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp +from .rms_norm import RMSNormOp + + +class PreRMSNormOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(PreRMSNormOp, self).__init__(config) + try: + self.pre_rms_norm_func = self.inference_module.pre_rms_norm + except AttributeError: + self.pre_rms_norm_func = self.pre_rms_norm_fallback + + @staticmethod + def pre_rms_norm_fallback(vals, residual, gamma, epsilon): + residual = vals.to(torch.float32) + residual.to(torch.float32) + vals = residual + + return RMSNormOp.rms_norm_fallback(vals, gamma, epsilon), residual.to(gamma.dtype) + + def forward(self, vals, residual, gamma, epsilon): + return self.pre_rms_norm_func(vals, residual, gamma, epsilon) diff --git a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py index 6b338b9041d9..9ff5366fae5d 100644 --- a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py @@ -4,35 +4,91 @@ # DeepSpeed Team import torch +import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp -from deepspeed import comm as dist +from .rms_norm import RMSNormOp +import deepspeed +from deepspeed.utils.types import NormType class QKVGemmOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(QKVGemmOp, self).__init__(config) - if self.config.fp16: - self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_fp16 - else: - self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_fp32 - - def forward(self, - input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - gamma: torch.Tensor, - beta: torch.Tensor, - add_bias: bool, - num_layers: int, - num_heads: int = None, - max_out_tokens: int = None): - q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1) - external_cache = self.config.bigscience_bloom - rank = dist.get_rank() if dist.is_initialized() else 0 - q_int8 = self.config.q_int8 - output = self.qkv_gemm_func(input, weight, q_scale, bias, gamma, beta, self.config.epsilon, add_bias, - num_layers, external_cache, self.config.mp_size, rank, q_int8, - self.config.transposed_mode) + try: + if self.config.norm_type == NormType.LayerNorm: + if self.config.dtype in [torch.float16, torch.int8]: + if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: + from deepspeed.ops.transformer.inference.triton.ops import qkv_gemm_func as _triton_qkv_gemm_func + self.qkv_gemm_func = _triton_qkv_gemm_func + triton_autotune = config.triton_autotune and config.layer_id == 0 + if triton_autotune: + __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size) + else: + self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore + elif self.config.dtype == torch.bfloat16: + self.qkv_gemm_func = self.inference_module.qkv_gemm_bf16 + else: + self.qkv_gemm_func = self.inference_module.qkv_gemm_fp32 # type: ignore + elif self.config.norm_type == NormType.RMSNorm: + if self.config.dtype in [torch.float16, torch.int8]: + self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_fp16 # type: ignore + elif self.config.dtype == torch.bfloat16: + self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_bf16 + else: + self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_fp32 # type: ignore + except AttributeError: + if self.config.norm_type == NormType.LayerNorm: + self.qkv_gemm_func = self.qkv_gemm_fallback + elif self.config.norm_type == NormType.RMSNorm: + self.qkv_gemm_func = self.rms_qkv_gemm_fallback + + @staticmethod + def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + A = torch.randn((N, hidden_size), dtype=dtype, device='cuda') + B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda') + matmul(A, B) + Fp16Matmul._update_autotune_table() + + @staticmethod + def qkv_gemm_fallback(input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): + inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps) + tmp = torch.matmul(inp_norm, weight.t() if transpose else weight) + if add_bias: + tmp += bias + output = [tmp, inp_norm] + + return output + + @staticmethod + def rms_qkv_gemm_fallback(input, weight, q_scale, gamma, eps, q_int8, transpose): + inp_norm = RMSNormOp.rms_norm_fallback(input, gamma, eps) + tmp = torch.matmul(inp_norm, weight.t() if transpose else weight) + output = [tmp, inp_norm] + return output + + def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor): + + add_bias = bias is not None + bias = bias if add_bias else torch.empty(1) # type: ignore + q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1) # type: ignore + q_int8 = self.config.dtype == torch.int8 + + if self.config.norm_type == NormType.LayerNorm: + output, norm = self.qkv_gemm_func(input, weight, q_scale, bias, gamma, beta, self.config.epsilon, add_bias, + q_int8, self.config.transposed_mode) + else: + output, norm = self.qkv_gemm_func(input, weight, q_scale, gamma, self.config.epsilon, q_int8, + self.config.transposed_mode) + if add_bias: + output += bias + + return output, norm diff --git a/deepspeed/ops/transformer/inference/op_binding/residual_add.py b/deepspeed/ops/transformer/inference/op_binding/residual_add.py index e79f5dee5c54..93b229c5d1ac 100644 --- a/deepspeed/ops/transformer/inference/op_binding/residual_add.py +++ b/deepspeed/ops/transformer/inference/op_binding/residual_add.py @@ -4,6 +4,9 @@ # DeepSpeed Team import torch +from typing import Optional + +from .vector_add import VectorAddOp from ..config import DeepSpeedInferenceConfig from .base import BaseOp @@ -12,18 +15,59 @@ class ResidualAddOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(ResidualAddOp, self).__init__(config) - if self.config.fp16 or self.config.q_int8: - self.residual_add_func = self.inference_cuda_module.residual_add_bias_fp16 + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.residual_add_func = self.inference_module.residual_add_bias_fp16 + elif self.config.dtype == torch.bfloat16: + self.residual_add_func = self.inference_module.residual_add_bias_bf16 + else: + self.residual_add_func = self.inference_module.residual_add_bias_fp32 + except AttributeError: + self.residual_add_func = self.residual_add_fallback + self.vector_add = VectorAddOp() + + @staticmethod + def res_add_bias(hidden_state, residual, attn_output, attn_bias, final_bias, add_attn_bias, mp_size): + hidden_state += attn_output + (residual + final_bias) / mp_size + if add_attn_bias: + hidden_state += attn_bias / mp_size + + return hidden_state + + @staticmethod + def residual_add_fallback(hidden_state, residual, attention_output, attention_bias, final_bias, mp_size, + mlp_after_attn, add_bias, pre_layer_norm): + if mlp_after_attn: + if pre_layer_norm: + tmp = (residual.float() + attention_output.float() + attention_bias.float() + + final_bias.float()) / mp_size + hidden_state.float() + else: + tmp = residual.float() + hidden_state.float() + final_bias.float() else: - self.residual_add_func = self.inference_cuda_module.residual_add_bias_fp32 + tmp = ResidualAddOp.res_add_bias(hidden_state, residual, attention_output, attention_bias, final_bias, + add_bias, mp_size) + residual.copy_(tmp.to(hidden_state.dtype)) - def forward(self, hidden_state: torch.Tensor, residual: torch.Tensor, attention_output: torch.Tensor, - attention_bias: torch.Tensor, final_bias: torch.Tensor, add_bias: bool, residual_add: torch.Tensor): + return residual + + def forward(self, + hidden_state: torch.Tensor, + residual: torch.Tensor, + add_bias: bool, + attention_output: Optional[torch.Tensor] = None, + residual_add: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + final_bias: Optional[torch.Tensor] = None): + + if final_bias is None and attention_bias is None: + residual = self.vector_add(residual + attention_output, hidden_state, 1.0 / self.config.mp_size) + else: + if not self.config.pre_layer_norm and residual_add is not None: + # only use residual add if its set and we are not pre layer norm + residual = residual_add - if not self.config.pre_layer_norm and residual_add is not None: - # only use residual add if its set and we are not pre layer norm - residual = residual_add + self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias, + self.config.mp_size, self.config.mlp_after_attn, add_bias, + self.config.pre_layer_norm) - self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias, - self.config.mp_size, self.config.mlp_after_attn, add_bias, self.config.pre_layer_norm) return residual diff --git a/deepspeed/ops/transformer/inference/op_binding/rms_norm.py b/deepspeed/ops/transformer/inference/op_binding/rms_norm.py new file mode 100644 index 000000000000..128883ce5d43 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/rms_norm.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class RMSNormOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(RMSNormOp, self).__init__(config) + try: + self.rms_norm_func = self.inference_module.rms_norm + except AttributeError: + self.rms_norm_func = self.rms_norm_fallback + + @staticmethod + def rms_norm_fallback(vals, gamma, epsilon): + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals + + def forward(self, vals, gamma, epsilon): + return self.rms_norm_func(vals, gamma, epsilon) diff --git a/deepspeed/ops/transformer/inference/op_binding/softmax.py b/deepspeed/ops/transformer/inference/op_binding/softmax.py index 529df9ed6181..2e08541596fa 100644 --- a/deepspeed/ops/transformer/inference/op_binding/softmax.py +++ b/deepspeed/ops/transformer/inference/op_binding/softmax.py @@ -4,25 +4,68 @@ # DeepSpeed Team import torch +import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp +from deepspeed.ops.transformer.inference.op_binding.workspace import InferenceContext class SoftmaxOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(SoftmaxOp, self).__init__(config) - if self.config.fp16: - self.softmax_func = self.inference_cuda_module.softmax_fp16 - else: - self.softmax_func = self._not_implemented + self.num_attention_heads_per_partition = config.heads // config.mp_size + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.softmax_func = self.inference_module.softmax_fp16 + elif self.config.dtype == torch.bfloat16: + self.softmax_func = self.inference_module.softmax_bf16 + else: + self.softmax_func = self.inference_module.softmax_fp32 + except AttributeError: + self.softmax_func = self.softmax_fallback - def _not_implemented(self, *args, **kwargs): - raise NotImplementedError + @staticmethod + def softmax_fallback(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size, async_op, + layer_scale, head_offset, mp_size): + scores_len = len(attn_scores.size()) + heads = 1 + if scores_len > 1: + heads = attn_scores.size()[1] + num_attention_heads_per_partition = heads // mp_size + + if alibi is not None: + if len(alibi.shape) == 1: + alibi = None + else: + alibi = alibi[head_offset:head_offset + num_attention_heads_per_partition] + if attn_mask is not None and len(attn_mask.shape) == 1: + attn_mask = None + input_dtype = attn_scores.dtype + attn_scores *= layer_scale + + if alibi is not None: + attn_scores += alibi + if attn_mask is not None: + # expand atten_mask from two dim into 4 dim, insert two dims in the middle + if len(attn_mask.shape) == 2: + attn_mask = attn_mask[:, None, None, :] + attn_scores += attn_mask + if triangular: + if attn_scores.shape[2] == 1: # query using kv cache + token_idx = InferenceContext.Instance().current_tokens() + tri = torch.arange(attn_scores.shape[2], device=attn_scores.device).ge(token_idx) + else: + tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool) + attn_scores = torch.masked_fill(attn_scores, tri, float('-inf')) + output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype) + + return output def forward(self, attn_scores: torch.Tensor, attn_mask: torch.Tensor, alibi: torch.Tensor, triangular: bool, recompute: bool, local_attention: bool, window_size: int, async_op: bool, layer_scale: float, head_offset: int): output = self.softmax_func(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size, async_op, layer_scale, head_offset, self.config.mp_size) + return output diff --git a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py index 1a132982aba6..d745df678e93 100644 --- a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py +++ b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py @@ -7,19 +7,126 @@ from deepspeed import comm as dist from ..config import DeepSpeedInferenceConfig from .base import BaseOp +from .softmax import SoftmaxOp +from deepspeed.ops.transformer.inference.op_binding.workspace import InferenceContext class SoftmaxContextOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(SoftmaxContextOp, self).__init__(config) - if self.config.fp16: - self.softmax_context_func = self.inference_cuda_module.softmax_context_fp16 - else: - self.softmax_context_func = self.inference_cuda_module.softmax_context_fp32 + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.softmax_context_func = self.inference_module.softmax_context_fp16 + elif self.config.dtype == torch.bfloat16: + self.softmax_context_func = self.inference_module.softmax_context_bf16 + else: + self.softmax_context_func = self.inference_module.softmax_context_fp32 + except AttributeError: + self.softmax_context_func = self.softmax_context_fallback + + @staticmethod + def transform4d_0213(x, seq_length): + assert x.dim() == 3, F"Dim {x.dim()} is not supported" + batch_size, num_heads, seq_length_head_dim = x.shape + head_dim = seq_length_head_dim // seq_length + x = x.view(batch_size, num_heads, seq_length, head_dim) + x = x.permute(0, 2, 1, 3) + + return x + + @staticmethod + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep <= 1 or num_key_value_heads == 1: + return hidden_states + + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + @staticmethod + def bias_add_transform_0213(input, bias, num_heads, trans_count, perform_bias=False): + assert trans_count == 1 or trans_count == 3, F"Trans count {trans_count} is not supported" + assert input.dim() == 3, F"Dim {input.dim()} is not supported" + input_biased = torch.add(input, bias) if perform_bias else input + batch_size, seq_length, value_size = input_biased.shape + hid_dim = value_size // trans_count + head_dim = hid_dim // num_heads + + if trans_count == 1: + query_layer = input.view(batch_size, seq_length, num_heads, head_dim) + query_layer = query_layer.permute(0, 2, 1, 3) + key_layer = torch.zeros_like(query_layer) + value_layer = torch.zeros_like(query_layer) + return query_layer, key_layer, value_layer + + qkv_layers = input.view(batch_size, seq_length, 3, num_heads, head_dim) + query_layer, key_layer, value_layer = qkv_layers[..., 0, :, :], qkv_layers[..., 1, :, :], qkv_layers[..., + 2, :, :] + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + return query_layer, key_layer, value_layer + + def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads, + num_kv, norm_factor, triangular_masking, local_attention, window_size, no_masking, + layer_id, num_layers, alibi, rope_theta, is_prompt, token_idx, position_ids): + bat_0213_query, bat_0213_key, bat_0213_value = self.bias_add_transform_0213( + query_key_value, None, heads, 3, False) + + if rotary_dim > 0 and rotate_half: + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + rotary = InferenceContext.Instance().get_rotary(rotary_dim, rope_theta, bat_0213_value.device) + cos, sin = rotary(bat_0213_value, InferenceContext.Instance().get_max_tokens_num()) + bat_0213_query, bat_0213_key = apply_rotary_pos_emb(bat_0213_query, bat_0213_key, cos, sin, position_ids) - def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, norm_factor: float, - no_masking: bool, layer_id: int, num_layers: int, alibi: torch.Tensor): + bat_0213_key, bat_0213_value = InferenceContext.Instance().update_cache(layer_id, token_idx, is_prompt, + bat_0213_key, bat_0213_value) + + bat_0213_key = self.repeat_kv(bat_0213_key, num_kv) + bat_0213_value = self.repeat_kv(bat_0213_value, num_kv) + + bsz = query_key_value.shape[0] + head_dim = query_key_value.shape[2] // (heads * 3) + + bmm_output = torch.bmm(bat_0213_query.reshape(bsz * heads, bat_0213_query.shape[2], head_dim), + bat_0213_key.reshape(bsz * heads, bat_0213_key.shape[2], head_dim).transpose(1, 2)) + + layer_scale = 1.0 + if alibi is not None and len(alibi.shape) > 1: + layer_scale = max(1, layer_id).to(float) + + alpha = norm_factor * norm_factor / layer_scale + bmm_output *= alpha + bmm_output_reshape = bmm_output.reshape(bsz, heads, bmm_output.shape[1], bmm_output.shape[2]) + + recompute = is_prompt + if attn_mask is not None and len(attn_mask.shape) > 1 and attn_mask.shape[-1] < bmm_output_reshape.shape[3]: + attn_mask = torch.nn.functional.pad(attn_mask, (0, bmm_output_reshape.shape[3] - attn_mask.shape[-1]), + value=torch.finfo(attn_mask.dtype).min) + softmax_output = SoftmaxOp.softmax_fallback(bmm_output_reshape, attn_mask, alibi, triangular_masking, + recompute, local_attention, window_size, None, layer_scale, 0, 1) + + output = torch.bmm(softmax_output.reshape(bsz * heads, softmax_output.shape[2], softmax_output.shape[3]), + bat_0213_value.reshape(bsz * heads, bat_0213_value.shape[2], head_dim)) + + output = output.reshape(bsz, heads, output.shape[1], head_dim) + output = output.reshape(bsz, heads, output.shape[2] * head_dim) + input_seq_len = query_key_value.shape[1] + t4d_0123_output = self.transform4d_0213(output, input_seq_len) + t4d_0123_output = t4d_0123_output.reshape(bsz, t4d_0123_output.shape[1], heads * head_dim) + + if layer_id == num_layers - 1: + InferenceContext.Instance().advance_tokens() + + return t4d_0123_output, bat_0213_key, bat_0213_value + + def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, num_kv: int, + norm_factor: float, no_masking: bool, layer_id: int, num_layers: int, alibi: torch.Tensor, + is_prompt: bool, token_idx: torch.Tensor, position_ids: torch.Tensor): if alibi is not None: batch_heads = query_key_value.shape[0] * heads @@ -29,7 +136,9 @@ def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: alibi = torch.empty(1) output = self.softmax_context_func(query_key_value, attn_mask, self.config.rotary_dim, self.config.rotate_half, - self.config.rotate_every_two, heads, norm_factor, + self.config.rotate_every_two, heads, num_kv, norm_factor, self.config.triangular_masking, self.config.local_attention, - self.config.window_size, no_masking, layer_id, num_layers, alibi) + self.config.window_size, no_masking, layer_id, num_layers, alibi, + self.config.rope_theta, is_prompt, token_idx, position_ids) + return output diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_add.py b/deepspeed/ops/transformer/inference/op_binding/vector_add.py new file mode 100644 index 000000000000..015340a1084b --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/vector_add.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class VectorAddOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(VectorAddOp, self).__init__(config) + try: + self.vector_add_func = self.inference_module._vector_add + except AttributeError: + self.vector_add_func = self.vector_add_fallback + + @classmethod + def vector_add_fallback(cls, a, b, gamma): + """Based on csrc/transformer/inference/csrc/pt_binding.cpp code of _vector_add""" + dtype = a.dtype + return (gamma * a.float() + b.float()).to(dtype) + + def forward(self, a, b, gamma): + return self.vector_add_func(a, b, gamma) diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py index f916020baa9e..cabab8d8c4ab 100644 --- a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py +++ b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py @@ -6,19 +6,49 @@ import torch from ..config import DeepSpeedInferenceConfig from .base import BaseOp +import deepspeed class VectorMatMulOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(VectorMatMulOp, self).__init__(config) - if self.config.fp16: - self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp16 - else: - self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp32 + try: + if self.config.dtype == torch.float16: + if deepspeed.HAS_TRITON and config.use_triton: + from deepspeed.ops.transformer.inference.triton.ops import vector_matmul_func as _triton_vector_matmul_func + self.vector_matmul_func = _triton_vector_matmul_func + triton_autotune = config.triton_autotune and config.layer_id == 0 + if triton_autotune: + __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size) + else: + self.vector_matmul_func = self.inference_module.vector_matmul_fp16 + elif self.config.dtype == torch.int8: + self.vector_matmul_func = self.inference_module.vector_matmul_int8 + elif self.config.dtype == torch.bfloat16: + self.vector_matmul_func = self.inference_module.vector_matmul_bf16 + else: + self.vector_matmul_func = self.inference_module.vector_matmul_fp32 + except AttributeError: + self.vector_matmul_func = self.vector_matmul_fallback + + def vector_matmul_fallback(self, input, weight, async_op, q_scale, q_int8, transpose): + return torch.matmul(input, weight.t() if transpose else weight) def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False): q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1) - q_int8 = self.config.q_int8 + q_int8 = self.config.dtype == torch.int8 output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode) return output + + @staticmethod + def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + A = torch.randn((N, hidden_size), dtype=dtype, device='cuda') + B = torch.randn((hidden_size, hidden_size), dtype=dtype, device='cuda') + matmul(A, B) + Fp16Matmul._update_autotune_table() diff --git a/deepspeed/ops/transformer/inference/op_binding/workspace.py b/deepspeed/ops/transformer/inference/op_binding/workspace.py new file mode 100644 index 000000000000..19de7d9576af --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/workspace.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + +minus_inf = -10000.0 +key_idx = 0 +value_idx = 1 + + +class InferenceContext: + + __instance = None + + def __init__(self): + self.kv_cache = None + self.kv_cache_elem_dtype = None + self.num_tokens = 1 + self.kv_cache_num_layers = None + self.kv_cache_size = None + self.max_out_tokens = None + self.rotary = None + self.allocate_called = False + self.static_shapes = True + + @classmethod + def Instance(cls): + if InferenceContext.__instance is None: + InferenceContext.__instance = InferenceContext() + return InferenceContext.__instance + + def gen_workspace(self, num_layers, num_heads, batch_size, prompt_len, hidden_dim, mp_size, external_cache, + elem_dtype, rank, max_out_tokens, min_out_tokens): + self.allocate_called = True + self.kv_cache = None + if not external_cache: + self.kv_cache_num_layers = num_layers + self.max_out_tokens = max_out_tokens + head_size = hidden_dim // num_heads + self.kv_cache_size = torch.Size([batch_size, (num_heads // mp_size), max_out_tokens, head_size]) + self.kv_cache_elem_dtype = elem_dtype + self.num_tokens = 0 + self.static_shapes = True + return True + + def retake_workspace(self): + return True + + def _retake_workspace(self): + assert self.allocate_called, "retake workspace called before allocate workspace" + + import deepspeed.accelerator as accelerator + if self.kv_cache is None: + self.kv_cache = [] + for layer in range(self.kv_cache_num_layers): + self.kv_cache.append((torch.zeros(self.kv_cache_size, + dtype=self.kv_cache_elem_dtype, + device=accelerator.get_accelerator().device_name()), + torch.zeros(self.kv_cache_size, + dtype=self.kv_cache_elem_dtype, + device=accelerator.get_accelerator().device_name()))) + + return True + + def update_cache(self, layer_id, token_idx, is_prompt, bat_0213_key, bat_0213_value): + has_workspace = self._retake_workspace() + assert has_workspace, "Could not allocate workspace" + + # Update current token + if is_prompt: + self.static_shapes = True + if token_idx is None: + self.static_shapes = False + InferenceContext.Instance().reset_tokens(bat_0213_key.shape[2]) + else: + InferenceContext.Instance().reset_tokens(token_idx) + + if token_idx is None: + token_idx = InferenceContext.Instance().current_tokens() + + bsz = bat_0213_key.shape[0] + + # Update cache content + if is_prompt: + cache_max_seq = self.kv_cache_size[2] + cache_max_head_dim = self.kv_cache_size[3] + seq = bat_0213_key.shape[2] + + mask = torch.arange(cache_max_seq, device=bat_0213_key.device) + mask = mask.ge(token_idx) + mask = mask.unsqueeze(-1) + mask = mask.expand([cache_max_seq, cache_max_head_dim]) + + self.kv_cache[layer_id][key_idx][:bsz, :, :seq, :].copy_(bat_0213_key) + self.kv_cache[layer_id][key_idx][:bsz, :].masked_fill_(mask, 0) + self.kv_cache[layer_id][value_idx][:bsz, :, :seq, :].copy_(bat_0213_value) + self.kv_cache[layer_id][value_idx][:bsz, :].masked_fill_(mask, 0) + else: + if self.static_shapes: + assert type(token_idx) == torch.Tensor, "token_idx is expected to be torch.Tensor" + self.kv_cache[layer_id][key_idx][:bsz].index_copy_(2, token_idx - 1, bat_0213_key) + self.kv_cache[layer_id][value_idx][:bsz].index_copy_(2, token_idx - 1, bat_0213_value) + else: + assert type(token_idx) == int, "token_idx is expected to be int" + self.kv_cache[layer_id][key_idx][:bsz, :, token_idx - 1:token_idx, :] = bat_0213_key + self.kv_cache[layer_id][value_idx][:bsz, :, token_idx - 1:token_idx, :] = bat_0213_value + + bat_0213_key = self.kv_cache[layer_id][key_idx][:bsz] + bat_0213_value = self.kv_cache[layer_id][value_idx][:bsz] + + if not self.static_shapes: + bat_0213_key = bat_0213_key[:, :, :token_idx, :] + bat_0213_value = bat_0213_value[:, :, :token_idx, :] + + return bat_0213_key, bat_0213_value + + def release_workspace(self): + self.kv_cache = None + self.rotary = None + + def reset_tokens(self, initial_tokens=1): + self.num_tokens = initial_tokens + + def current_tokens(self): + return self.num_tokens + + def advance_tokens(self): + self.num_tokens = self.num_tokens + 1 + + def get_kv_cache(self): + return self.kv_cache + + def get_rotary(self, rotary_dim, rope_theta, device=None): + if self.rotary is None: + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + + self.rotary = LlamaRotaryEmbedding(rotary_dim, base=rope_theta, device=device) + + return self.rotary + + def get_max_tokens_num(self): + return self.max_out_tokens + + +class WorkspaceOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + self.inference_context = InferenceContext.Instance() + self._is_allocated = False + try: + super(WorkspaceOp, self).__init__(config) + if config.dtype == torch.float32: + self.allocate_workspace_func = self.inference_module.allocate_workspace_fp32 + elif config.dtype == torch.bfloat16: + self.allocate_workspace_func = self.inference_module.allocate_workspace_bf16 + else: + self.allocate_workspace_func = self.inference_module.allocate_workspace_fp16 + self.release_workspace_func = self.inference_module.release_workspace + self.retake_workspace_func = self.inference_module.retake_workspace + self.reset_cache_func = self.inference_module.reset_cache + except (ValueError, AttributeError) as e: + print(f"Using fallback functions in workspace because of {e}") + if config.dtype == torch.float32: + self.allocate_workspace_func = self.allocate_workspace_fp32_fallback + elif config.dtype == torch.bfloat16: + self.allocate_workspace_func = self.allocate_workspace_bf16_fallback + else: + self.allocate_workspace_func = self.allocate_workspace_fp16_fallback + self.release_workspace_func = self.release_workspace_fallback + self.retake_workspace_func = self.retake_workspace_fallback + self.reset_cache_func = self.reset_cache_fallback + + def allocate_workspace(self, *args, **kwargs): + self._is_allocated = True + return self.allocate_workspace_func(*args, **kwargs) + + def release_workspace(self): + self._is_allocated = False + return self.release_workspace_func() + + def reset_cache(self): + return self.reset_cache_func() if self.reset_cache_func else None + + def retake_workspace(self): + return self.retake_workspace_func() if self.retake_workspace_func else None + + def allocate_workspace_fp32_fallback(self, hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size, + external_cache, rank, max_out_tokens, min_out_tokens): + return self.inference_context.gen_workspace(num_layers, num_heads, batch_size, prompt_length, hidden_dim, + mp_size, external_cache, torch.float, rank, max_out_tokens, + min_out_tokens) + + def allocate_workspace_bf16_fallback(self, hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size, + external_cache, rank, max_out_tokens, min_out_tokens): + return self.inference_context.gen_workspace(num_layers, num_heads, batch_size, prompt_length, hidden_dim, + mp_size, external_cache, torch.bfloat16, rank, max_out_tokens, + min_out_tokens) + + def allocate_workspace_fp16_fallback(self, hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size, + external_cache, rank, max_out_tokens, min_out_tokens): + return self.inference_context.gen_workspace(num_layers, num_heads, batch_size, prompt_length, hidden_dim, + mp_size, external_cache, torch.half, rank, max_out_tokens, + min_out_tokens) + + def reset_cache_fallback(self): + return self.inference_context.reset_tokens() + + def release_workspace_fallback(self): + return self.inference_context.release_workspace() + + def retake_workspace_fallback(self): + return self.inference_context.retake_workspace() + + def is_allocated(self): + return self._is_allocated diff --git a/deepspeed/ops/transformer/inference/triton/__init__.py b/deepspeed/ops/transformer/inference/triton/__init__.py new file mode 100755 index 000000000000..b7d1968df62a --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .residual_add import residual_add_bias +from .layer_norm import layer_norm, layer_norm_residual +from .gelu import gelu +from .softmax import softmax +from .ops import * +from .matmul_ext import fp16_matmul, matmul_4d, score_4d_matmul, context_4d_matmul diff --git a/deepspeed/ops/transformer/inference/triton/attention.py b/deepspeed/ops/transformer/inference/triton/attention.py new file mode 100644 index 000000000000..023ef767a7b2 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/attention.py @@ -0,0 +1,397 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import math +import torch +import torch.nn as nn +import triton +import triton.language as tl +from deepspeed.accelerator import get_accelerator +from deepspeed import comm as dist +from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp +from deepspeed.ops.transformer.inference.triton import ( + softmax, + score_4d_matmul, + context_4d_matmul, +) + +minus_inf = -10000.0 + + +class TritonSelfAttention(nn.Module): + num_layers = 0 + + def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, qkv_merging=False): + super(TritonSelfAttention, self).__init__() + self.config = config + data_type = self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype + assert data_type_fp == torch.half, "triton supports fp16 data_type_fp" + + self.config.layer_id = TritonSelfAttention.num_layers + TritonSelfAttention.num_layers = TritonSelfAttention.num_layers + 1 + device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu' + + assert config.mp_size == 1, "mp_size has to be 1 with triton attention yet" + if self.config.set_empty_params: + self.attn_qw = None + self.attn_qb = None + self.attn_kw = None + self.attn_kb = None + self.attn_vw = None + self.attn_vb = None + self.attn_qkvw = None + self.attn_qkvb = None + self.attn_ow = None + self.attn_ob = None + else: + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device), + requires_grad=False) + # self-ouput weights + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + + self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size + self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size + self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads + + self.mp_group = mp_group + self.use_flash = False + # triton flash attention is enabled when the compute capability >= 8.0 + if get_accelerator().is_triton_supported(): + self.use_flash = True + + # used for quantization + self.q_scales = q_scales + self.q_groups = q_groups + self.merge_count = int(math.log2(merge_count)) + + self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads) + if not config.use_mup: + self.norm_factor = math.sqrt(self.norm_factor) + + if self.config.scale_attn_by_inverse_layer_idx is True: + self.norm_factor *= math.sqrt(self.config.layer_id + 1) + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191 + + triton_autotune = self.config.triton_autotune and self.config.layer_id == 0 + self.qkv_func = QKVGemmOp(config) + self.score_context_func = SoftmaxContextOp(config) + self.linear_func = LinearOp(config) + self.vector_matmul_func = VectorMatMulOp(config) + + self.hidden_size = config.hidden_size + self.head_size = config.hidden_size // config.heads + self.scale = (1 / self.norm_factor / self.norm_factor if self.config.scale_attention else 1.0 + ) # making it back to 1/sqrt(head_size) + self.triangular_masking = self.config.triangular_masking + + # triton autotune table update for score/context matmul + if triton_autotune: + print("running triton autotune for regular attention kernel") + __class__._triton_autotune(2, self.config.max_out_tokens, self.head_size, self.config.hidden_size, + self.triangular_masking, self.scale) + + @staticmethod + def _triton_autotune(min_seqlen, + max_seqlen, + head_size, + hidden_size, + triangular_masking, + scale, + dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, score_4d_matmul, context_4d_matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + qkv = torch.randn((1, N, 3 * hidden_size), dtype=dtype, device='cuda') + output = score_4d_matmul(qkv, head_size, triangular_masking, scale) + context_4d_matmul(output, qkv, head_size) + Fp16Matmul._update_autotune_table() + + def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids): + if isinstance(qkv_out, list): + qkv_out = qkv_out[0] + + no_masking = input_mask is None + + if no_masking: + input_mask = torch.empty(1) + + attn_key_value = self.score_context_func( + query_key_value=qkv_out, + attn_mask=((1 - input_mask).to(qkv_out.dtype) * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, + heads=self.num_attention_heads_per_partition, + norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0), + no_masking=no_masking, + layer_id=self.config.layer_id, + num_layers=TritonSelfAttention.num_layers, + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) + + context_layer, key_layer, value_layer = attn_key_value + return context_layer, key_layer, value_layer + + def forward( + self, + input, + input_mask, + head_mask=None, + layer_past=None, + get_present=False, # not used + encoder_hidden_states=None, # not used + encoder_attention_mask=None, # not used + triangularutput_attentions=False, # not used + norm_w=None, + norm_b=None, + alibi=None, + use_triton_attention=True, + **kwargs): + + if not self.config.pre_layer_norm: + qkv_out = self.linear_func(input=input, + weight=self.attn_qkvw, + bias=self.attn_qkvb, + add_bias=self.attn_qkvb is not None, + do_flash_attn=False, + num_heads=self.num_attention_heads_per_partition, + num_layers=TritonSelfAttention.num_layers) + qkv = qkv_out + else: + qkv_out = self.qkv_func(input=input, + weight=self.attn_qkvw, + bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b), + gamma=norm_w, + beta=norm_b) + qkv = qkv_out[0] + + if use_triton_attention and (alibi is None): + context_layer = _triton_attention(qkv=qkv, + input_mask=input_mask, + scale=self.scale, + layer_past=layer_past, + alibi=alibi, + head_size=self.head_size, + use_triton_flash=self.use_flash, + use_cuda_flash=False, + triangular=self.triangular_masking) + key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:] + else: + is_prompt = kwargs.get("first_token", qkv_out[0].shape[1] > 1) + token_idx = kwargs.get("token_idx", None) + position_ids = kwargs.get("position_ids", None) + context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out, + input_mask=input_mask, + layer_past=layer_past, + alibi=alibi, + is_prompt=is_prompt, + toke_idx=token_idx, + position_ids=position_ids) + output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) + + inp_norm = qkv_out[-1] + + if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: + dist.all_reduce(output, group=self.mp_group) + + return (output, key_layer, value_layer, context_layer, inp_norm) + + +global inference_module + + +def _triton_attention(qkv, + input_mask, + layer_past, + alibi, + scale, + head_size, + triangular=False, + use_cuda_flash=False, + use_triton_flash=False, + use_ds_attention=False): + if isinstance(qkv, list): + qkv = qkv[0] + + assert alibi is None, "layer_past not supported in alibi yet" + + if use_triton_flash: + output = _triton_packed_flash(qkv, + head_size, + input_mask, + scale, + causal=triangular, + add_mask=(not triangular and input_mask is not None)) + else: + output = score_4d_matmul(qkv, head_size, triangular, scale) + if triangular: + output = softmax(output) + else: + output = softmax(output, input_mask) + output = context_4d_matmul(output, qkv, head_size) + + return output + + +''' +flash attention 2 +modified the triton kernel in +https://github.com/openai/triton/blob/08c16589573621fcb8cd5a9c3b8a0537077f876d/python/tutorials/06-fused-attention.py +''' + + +@triton.jit +def _flash_packed_kernel( + QKV, + mask, + ADD_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + sm_scale, + Out, + stride_qz, + stride_qn, + stride_qm, + stride_mz, + stride_oz, + stride_on, + Z, + H, + N_CTX, + P_SEQ, + hidden_size, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + batch = off_hz // H + head = off_hz % H + + q_offset = batch * stride_qz + head * BLOCK_DMODEL + k_offset = q_offset + hidden_size + v_offset = k_offset + hidden_size + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qn + offs_d[None, :] + k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :] + v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :] + + # mask + off_mask = batch * stride_mz + offs_n[None, :] + mask_ptrs = mask + off_mask + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0) + q = (q * qk_scale).to(tl.float16) + # loop over k, v and update accumulator + lo = 0 + hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(k_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0) + v = tl.load(v_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16) + + if ADD_MASK: + mask_val = tl.load(mask_ptrs) + mask_ptrs += BLOCK_N + qk = qk + mask_val.to(tl.float32) + + if IS_CAUSAL: + qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16) + qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.float16), v.to(tl.float16)) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back l and m + acc = acc / l_i[:, None] + o_offset = batch * stride_oz + head * BLOCK_DMODEL + out_ptrs = Out + o_offset + (offs_m[:, None] * stride_on + offs_d[None, :]) + tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX) + + +def _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True): + heads = qkv.shape[-1] // 3 // head_size + hidden_size = qkv.shape[-1] // 3 + + BLOCK_M = 128 + BLOCK_N = 64 if head_size <= 64 else 32 + + o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half) + if mask is None: + mask = torch.empty(0) + add_mask = False + + grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1) + num_stages = 4 if head_size <= 64 else 3 + num_warps = 4 + P_SEQ = 0 + + _flash_packed_kernel[grid](qkv, + mask, + add_mask, + causal, + sm_scale, + o, + qkv.stride(0), + qkv.stride(1), + qkv.stride(2), + mask.stride(1) if add_mask else 0, + o.stride(0), + o.stride(1), + qkv.shape[0], + heads, + qkv.shape[1], + P_SEQ, + hidden_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=head_size, + num_warps=num_warps, + num_stages=num_stages) + + return o diff --git a/deepspeed/ops/transformer/inference/triton/gelu.py b/deepspeed/ops/transformer/inference/triton/gelu.py new file mode 100644 index 000000000000..738d7d96a1c9 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/gelu.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl +from deepspeed.accelerator import get_accelerator + + +@triton.jit +def gelu_functor(x): + # Using approximation introduces greater parity errors. + # return tl.sigmoid(1.702 * x) * x + return x * 0.5 * (1.0 + tl.math.erf(x / 1.41421356237)) + + +@triton.jit +def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = gelu_functor(x) + tl.store(output_ptr + offsets, output, mask=mask) + + +def gelu(activations: torch.Tensor) -> torch.Tensor: + assert activations.is_contiguous() + assert get_accelerator().on_accelerator(activations) + + output = torch.empty_like(activations) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024) + return output diff --git a/deepspeed/ops/transformer/inference/triton/layer_norm.py b/deepspeed/ops/transformer/inference/triton/layer_norm.py new file mode 100644 index 000000000000..d3f313d2ac3d --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/layer_norm.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl +''' +layer-normalization +modified the triton kernel in +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/05-layer-norm.py +''' + + +@triton.jit +def layer_norm_kernel( + Out, + A, + Weight, + Bias, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + # position of elements processed by this program + row = tl.program_id(0) + Out += row * stride + A += row * stride + # compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + a = tl.where(cols < N, a - mean, 0.0) + _var += a * a + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # multiply by weight and add bias + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + weight = tl.load(Weight + cols, mask=mask) + bias = tl.load(Bias + cols, mask=mask) + a = tl.load(A + cols, mask=mask, other=0.0).to(tl.float32) + a_hat = (a - mean) * rstd + out = a_hat * weight + bias + # # write-back + tl.store(Out + cols, out, mask=mask) + + +@triton.jit +def layer_norm_residual_kernel( + Out, + A, + Residual, + ln_input, + Weight, + Bias, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + # position of elements processed by this program + row = tl.program_id(0) + Out += row * stride + A += row * stride + Residual += row * stride + ln_input += row * stride + # compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32) + a = a + res + tl.store(ln_input + cols, a, mask=cols < N) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32) + a = tl.where(cols < N, a - mean, 0.0) + _var += a * a + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # multiply by weight and add bias + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + weight = tl.load(Weight + cols, mask=mask) + bias = tl.load(Bias + cols, mask=mask) + a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32) + a_hat = (a - mean) * rstd + out = a_hat * weight + bias + # write-back + tl.store(Out + cols, out, mask=mask) + + +@triton.jit +def layer_norm_residual_bias_kernel( + Out, + A, + Residual, + InputBias, + ln_input, + Weight, + Bias, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + # position of elements processed by this program + row = tl.program_id(0) + Out += row * stride + A += row * stride + Residual += row * stride + ln_input += row * stride + # compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32) + b = tl.load(InputBias + cols, mask=cols < N, other=0.0).to(tl.float32) + a = a + b + res + tl.store(ln_input + cols, a, mask=cols < N) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32) + a = tl.where(cols < N, a - mean, 0.0) + _var += a * a + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # multiply by weight and add bias + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + weight = tl.load(Weight + cols, mask=mask) + bias = tl.load(Bias + cols, mask=mask) + a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32) + a_hat = (a - mean) * rstd + out = a_hat * weight + bias + # write-back + tl.store(Out + cols, out, mask=mask) + + +def layer_norm(a, weight, bias, eps): + assert a.is_contiguous() + assert weight.is_contiguous() + assert bias.is_contiguous() + + # allocate output + out = torch.empty_like(a) + # reshape input data into 2D tensor + a_arg = a.view(-1, a.shape[-1]) + M, N = a_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // a.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192 + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + layer_norm_kernel[(M, )]( + out, + a_arg, + weight, + bias, + a_arg.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return out + + +def layer_norm_residual(a, input_bias, residual, weight, bias, eps): + assert a.is_contiguous() + assert weight.is_contiguous() + assert bias.is_contiguous() + assert residual.is_contiguous() + + # allocate output and scratch-pad for residual addition + out = torch.empty_like(a) + ln_input = torch.empty_like(a) + # reshape input data into 2D tensor + a_arg = a.view(-1, a.shape[-1]) + residual = residual.view(-1, residual.shape[-1]) + M, N = a_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // a.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192 + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + if input_bias is None: + layer_norm_residual_kernel[(M, )]( + out, + a_arg, + residual, + ln_input, + weight, + bias, + a_arg.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + else: + layer_norm_residual_bias_kernel[(M, )]( + out, + a_arg, + residual, + input_bias, + ln_input, + weight, + bias, + a_arg.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return out diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py new file mode 100644 index 000000000000..f8d11ce6fdf3 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py @@ -0,0 +1,489 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import os +from filelock import FileLock +import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel +import pickle +from io import open +import deepspeed +from pathlib import Path +import atexit +import subprocess + + +# ----------------------------------------------------------------------------- +# util class/functions for triton +def is_nfs_path(path): + if os.name == 'nt': + return False + + # Normalize the path to get the absolute path + path = os.path.abspath(path) + + # Walk up to the nearest existing ancestor so 'df' does not fail + # when the target directory has not been created yet (see #7642). + while not os.path.exists(path): + parent = os.path.dirname(path) + if parent == path: + break + path = parent + + # Use the 'df' command to find the file system type for the given path + try: + output = subprocess.check_output(['df', '-T', path], encoding='utf-8', stderr=subprocess.DEVNULL) + except (subprocess.CalledProcessError, FileNotFoundError): + return False # Command failed or 'df' not available + + # Process the output of 'df -T' to check for 'nfs' in the filesystem type column + lines = output.strip().split('\n') + if len(lines) > 1: # The first line is headers + fs_type = lines[1].split()[1].lower() # File system type is the second column + return 'nfs' in fs_type + return False + + +class TritonCacheDir: + _warning_printed = False + + @staticmethod + def warn_if_nfs(cache_dir): + if is_nfs_path(cache_dir) and not TritonCacheDir._warning_printed: + print( + f"Warning: The cache directory for DeepSpeed Triton autotune, {cache_dir}, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path." + ) + TritonCacheDir._warning_printed = True + return + + @staticmethod + def default_cache_dir(): + tt_home = os.environ.get('TRITON_HOME') or os.path.join(Path.home(), ".triton") + tmp_path = os.path.join(tt_home, "autotune") + return tmp_path + + +def bias_add_activation(C, bias=None, activation=""): + if bias is not None: + C += bias + # activation + if activation == "relu": + relu = torch.nn.Relu() + C = relu(C) + elif activation == "leaky_relu": + leaky_relu = torch.nn.LeakyReLU(0.01) + C = leaky_relu(C) + elif activation == "gelu": + sigmoid = torch.nn.Sigmoid() + C = sigmoid(1.702 * C) * C + elif activation == "sigmoid": + sigmoid = torch.nn.Sigmoid() + C = sigmoid(C) + return C + + +class AutotuneCacheManager: + """ + Cache manager for autotune + """ + + def __init__(self, key): + self.key = key + self.file_path = None + self.lock_path = None + # if caching is enabled, get the lock and bin path + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', TritonCacheDir.default_cache_dir()) + TritonCacheDir.warn_if_nfs(self.cache_dir) + if self.cache_dir: + os.makedirs(self.cache_dir, exist_ok=True) + self.file_path = os.path.join(self.cache_dir, self.key + ".pickle") + self.lock_path = self.file_path + ".lock" + + def has_file(self): + return self.file_path and os.path.exists(self.file_path) + + def put(self, table): + if self.file_path: + assert self.lock_path is not None + with FileLock(self.lock_path): + with open(self.file_path + ".tmp", 'wb') as handle: + pickle.dump(table, handle) + os.replace(self.file_path + ".tmp", self.file_path) + + def load(self): + if os.path.exists(self.file_path): + with open(self.file_path, 'rb') as handle: + loaded_dict = pickle.load(handle) + return loaded_dict + else: + return None + + +# ----------------------------------------------------------------------------- +# triton matmul class + + +class MatmulExt(torch.autograd.Function): + """ + a wrapper class that can call different triton matmul kernels depending on the input parameters + """ + + @staticmethod + def forward(A, B, bias=None, activation="", use_triton=True, update_autotune_table=False): + """ + A: input, activation matrix A + B: input, weight matrix B + """ + matmul = None + quantize_activation = False + Batch = 0 + + if len(A.shape) == 3: # if A is 3d-tensor where batch index is given as 0-axis + assert A.is_contiguous(), "matrix A must be contiguous" + Batch, M, K = A.shape + A = A.view(-1, K) + + # fp16 activation and fp16 weight matmul into fp16 output + matmul = fp16_matmul + C = matmul.forward(A, B, use_triton=use_triton, bias=bias, activation=activation) + + if matmul and update_autotune_table: + matmul._update_autotune_table() + + if Batch > 0: + C = C.view(Batch, M, -1) + + return C + + +class TritonMatmul(torch.autograd.Function): + """ + triton matmul kernel superclass + """ + + def __init__(self): + pass + + @staticmethod + def _ref_forward(A, B, ref_dtype=torch.float32): + C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype)) + return C + + @staticmethod + def _read_autotune_table(cache_key, triton_kernel): + cache_manager = AutotuneCacheManager(cache_key) + table = cache_manager.load() + if table: + triton_kernel.cache = table + + @staticmethod + def _write_autotune_table(cache_key, triton_kernel): + cache_manager = AutotuneCacheManager(cache_key) + cache_manager.put(triton_kernel.cache) + + @staticmethod + def _update_autotune_table(cache_key, triton_kernel): + cache_manager = AutotuneCacheManager(cache_key) + autotune_table = cache_manager.load() + if autotune_table is None: + autotune_table = dict() + autotune_table.update(triton_kernel.cache) # always overwrite with the new autotune results + cache_manager = AutotuneCacheManager(cache_key) + cache_manager.put(autotune_table) + + @staticmethod + def forward( + A, + B, + ref_dtype=torch.float32, # fp32 only + bias=None, + activation=""): + C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype)) + C = bias_add_activation(C, bias, activation) + return C + + +class Fp16Matmul(TritonMatmul): + """ + fp16 matrix multiplication kernel + dtypes: fp16 x fp16 = fp16 + """ + + _2d_kernel = triton_matmul_kernel._fp_matmul + _4d_kernel = triton_matmul_kernel.matmul_4d_kernel + _cache_stride = 32 + + def __init__(self, read_cache=True): + super().__init__() + if read_cache: + __class__._read_autotune_table() + + def skip_autotune(self): + __class__._2d_kernel.configs = [__class__._2d_kernel.configs[0]] + __class__._4d_kernel.configs = [__class__._4d_kernel.configs[0]] + + @staticmethod + def forward(A, B, use_triton=True, bias=None, activation=""): + if use_triton: + device = A.device + # handle non-contiguous inputs if necessary + if A.stride(0) > 1 and A.stride(1) > 1: + A = A.contiguous() + if B.stride(0) > 1 and B.stride(1) > 1: + B = B.contiguous() + # checks constraints + assert A.shape[1] == B.shape[0], "incompatible dimensions" + M, K = A.shape + _, N = B.shape + # allocates output + C = torch.empty((M, N), device=device, dtype=A.dtype) + # accumulator types + ACC_TYPE = triton.language.float32 if A.dtype in [torch.float16, torch.bfloat16, torch.float32 + ] else triton.language.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + __class__._2d_kernel[grid](A, + B, + C, + M, + N, + K, + bias, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + BIAS_ADD=(0 if bias is None else 1), + ACTIVATION=activation) + else: + C = torch.matmul(A, B) + return C + + @staticmethod + def _matmul_4d(a, b): + assert a.shape[-1] == b.shape[-2], "incompatible dimensions" + assert a.is_contiguous(), "matrix A must be contiguous" + assert b.is_contiguous(), "matrix B must be contiguous" + + B, H, M, K = a.shape + B, H, K, N = b.shape + + assert K > 1, "inner-product dimension K should be larger than 1" + + c = torch.empty((B, H, M, N), device=a.device, dtype=a.dtype) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + H, + B, + ) + + __class__._4d_kernel[grid]( + a, + b, + c, + M, + N, + K, + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + a.stride(0), + a.stride(1), + a.stride(2), + a.stride(3), + b.stride(0), + b.stride(1), + b.stride(2), + b.stride(3), + c.stride(0), + c.stride(1), + c.stride(2), + c.stride(3), + scale=-1.0, + MASK=False, + ) + return c + + @staticmethod + def _score_4d_matmul(input, head_size, input_mask, scale=-1.0): + assert input.is_contiguous(), "matrix input must be contiguous" + + batches = input.shape[0] + d_model = input.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = input[:, :, :d_model] + k = input[:, :, d_model:d_model * 2] + + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + + # checks constraints + assert q.shape == k.shape, "incompatible dimensions" + B, M, H, K = q.shape + B, N, H, K = k.shape + + assert K > 1, "inner-product dimension K should be larger than 1" + + # allocates output + output = torch.empty((B, H, M, N), device=q.device, dtype=q.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + H, + B, + ) + __class__._4d_kernel[grid]( + q, + k, + output, + M, + N, + K, + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + scale=scale, + MASK=False, + ) + return output + + @staticmethod + def _context_4d_matmul(prob, input, head_size): + assert prob.is_contiguous(), "matrix prob must be contiguous" + assert input.is_contiguous(), "matrix input must be contiguous" + + batches = input.shape[0] + d_model = input.shape[-1] // 3 + num_of_heads = d_model // head_size + + v = input[:, :, d_model * 2:] + + v = v.view(batches, -1, num_of_heads, head_size) + + # checks constraints + assert (prob.shape[0] == v.shape[0] and prob.shape[1] == v.shape[2] and prob.shape[2] == v.shape[1] + and prob.shape[3] == v.shape[1]), "incompatible dimensions" + B, H, M, K = prob.shape + B, K, H, N = v.shape + + assert K > 1, "inner-product dimension K should be larger than 1" + + # allocates output + output = torch.empty((B, M, H, N), device=v.device, dtype=v.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + H, + B, + ) + + __class__._4d_kernel[grid]( + prob, + v, + output, + M, + N, + K, + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + prob.stride(0), + prob.stride(1), + prob.stride(2), + prob.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + # Here we also transpose the output when writing to memory. + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + scale=-1, + MASK=False, + ) + return output.view(batches, -1, d_model) + + @staticmethod + def _ref_forward(A, B, ref_dtype=torch.float32, bias=None, activation=""): + C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype)) + C = bias_add_activation(C, bias, activation) + return C + + @staticmethod + def _check_parity(A, + B, + output_dtype, + SA=None, + SB=None, + qblock_size=None, + ref_dtype=torch.float32, + tol=0.01, + use_triton=True, + bias=None, + activation=""): + torch_output = __class__._ref_forward(A, B, ref_dtype=ref_dtype, bias=bias, activation=activation) + triton_output = __class__.forward(A, B, use_triton=use_triton, bias=bias, activation=activation) + assert torch.allclose(triton_output.cpu().type(torch_output.dtype), torch_output.cpu(), rtol=tol) + print(f"{__class__.__name__}: PASSed the parity check") + return triton_output, torch_output + + @staticmethod + def _read_autotune_table(): + TritonMatmul._read_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel) + TritonMatmul._read_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel) + + @staticmethod + def _write_autotune_table(): + TritonMatmul._write_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel) + TritonMatmul._write_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel) + + @staticmethod + def _update_autotune_table(): + TritonMatmul._update_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel) + TritonMatmul._update_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel) + + +# ----------------------------------------------------------------------------- +# mapping +if deepspeed.HAS_TRITON: + fp16_matmul = Fp16Matmul() + matmul = MatmulExt.forward + matmul_4d = fp16_matmul._matmul_4d + score_4d_matmul = fp16_matmul._score_4d_matmul + context_4d_matmul = fp16_matmul._context_4d_matmul +else: + fp16_matmul = None + matmul = None + matmul_4d = None + score_4d_matmul = None + context_4d_matmul = None + + +@atexit.register +def matmul_ext_update_autotune_table(): + if deepspeed.HAS_TRITON: + fp16_matmul._update_autotune_table() diff --git a/deepspeed/ops/transformer/inference/triton/mlp.py b/deepspeed/ops/transformer/inference/triton/mlp.py new file mode 100644 index 000000000000..1708080b27ef --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/mlp.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import math +import torch.nn as nn +from deepspeed.accelerator import get_accelerator +from deepspeed import comm as dist +from ..op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp + + +class TritonMLP(nn.Module): + + def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False): + super(TritonMLP, self).__init__() + + self.config = config + data_type = self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype + device = get_accelerator().current_device_name() + self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + intm_size_per_partition = self.config.intermediate_size // self.config.mp_size + self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, + intm_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), + requires_grad=False) + self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + + # used for quantization + self.q_scales = q_scales + self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups + self.merge_count = int(math.log2(merge_count)) + self.mp_group = mp_group + + self.mlp_gemm_func = MLPGemmOp(config) + self.vector_matmul_func = VectorMatMulOp(config) + self.fused_gemm_gelu = GELUGemmOp(config) + self.residual_add_func = ResidualAddOp(config) + + def forward(self, input, residual, residual_norm, bias): + residual_add = None + if self.attn_nw is None: + output = self.fused_gemm_gelu(input=residual_norm, + weight=self.inter_w, + bias=self.inter_b, + weight_out=self.output_w) + else: + output, residual_add = self.mlp_gemm_func(input=input, + residual=residual, + input_bias=bias, + weight_interm=self.inter_w, + weight_out=self.output_w, + bias=self.inter_b, + gamma=self.attn_nw, + beta=self.attn_nb) + residual = self.residual_add_func(hidden_state=output, + residual=residual, + attention_output=input, + attention_bias=bias if bias is not None else self.output_b, + final_bias=self.output_b, + add_bias=bias is not None, + residual_add=residual_add) + + if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: + dist.all_reduce(residual, group=self.mp_group) + + return residual diff --git a/deepspeed/ops/transformer/inference/triton/ops.py b/deepspeed/ops/transformer/inference/triton/ops.py new file mode 100644 index 000000000000..dbed45313780 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/ops.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext +from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp +from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual +from deepspeed.utils.types import ActivationFuncType + + +def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode): + assert not transposed_mode and not async_op and not q_int8 + return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True) + + +def fused_gemm_gelu(input, + weight, + weight_scale, + bias, + weight_out, + weight_out_scale, + epsilon, + pre_layer_norm, + q_int8, + async_op, + transposed_mode, + use_triton_ln=True): + assert not transposed_mode + + # activation + activation = "gelu" + + # intermediate fc in FF + intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True) + + # output fc in FF + ff_out = matmul_ext.matmul( + intm_out, + weight_out, + bias=None, + activation="", # bias added layer with residual_add + bias + layerNorm layer + use_triton=True) + return ff_out + + +def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False): + assert not transposed_mode and not do_flash_attn + qkv_out = matmul_ext.matmul(input, weight, bias=(bias if add_bias else None), activation="", use_triton=True) + + return qkv_out + + +def mlp_gemm_func(input, + residual, + input_bias, + weight_interm, + weight_out, + bias, + gamma, + beta, + epsilon, + pre_layer_norm, + mlp_after_attn, + weight_interm_scale, + weight_out_scale, + q_int8, + mlp_act_func_type, + transposed_mode, + use_triton_ln=True): + assert not transposed_mode + + # residual add and layerNorm after attention + if use_triton_ln: + mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon) + else: + mlp_input = LayerNormOp.layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon) + + # activation + if ActivationFuncType(mlp_act_func_type) == ActivationFuncType.GELU: + activation = "gelu" + elif ActivationFuncType(mlp_act_func_type) == ActivationFuncType.ReLU: + activation = "relu" + else: + activation = "" + + # intermediate fc in FF + intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True) + # output fc in FF + ff_out = matmul_ext.matmul( + intm_out, + weight_out, + bias=None, + activation="", # bias added layer with residual_add + bias + layerNorm layer + use_triton=True) + + return ff_out, mlp_input + + +def qkv_gemm_func( + input, + weight, + q_scale, + bias, + gamma, + beta, + epsilon, + add_bias, + q_int8, + transposed_mode=False, + use_triton_ln=True, +): + + assert not transposed_mode + # residual add and layerNorm after attention + if use_triton_ln: + qkv_input = layer_norm(input, gamma, beta, epsilon) + else: + qkv_input = LayerNormOp()(input, gamma, beta, epsilon) + + qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True) + + return qkv_out, qkv_input diff --git a/deepspeed/ops/transformer/inference/triton/residual_add.py b/deepspeed/ops/transformer/inference/triton/residual_add.py new file mode 100644 index 000000000000..063e7a7e4a2d --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/residual_add.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl +from deepspeed.accelerator import get_accelerator + + +@triton.jit +def residual_add_bias_kernel( + hidden_state_ptr, + residual_ptr, + attn_output_ptr, + hidden_state_size, + attn_bias_ptr, + final_bias_ptr, + bias_size, + output_ptr, + mp_size: tl.constexpr, + mlp_after_attn: tl.constexpr, + pre_attn_norm: tl.constexpr, + add_attn_bias: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_state_size + + bias_offsets = offsets % bias_size + bias_mask = bias_offsets < bias_size + + tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask) + tl_residual = tl.load(residual_ptr + offsets, mask=mask) + tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask) + tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask) + tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask) + + if mlp_after_attn: + if pre_attn_norm: + output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size + else: + output = tl_hidden_state + tl_residual + tl_final_bias + else: + output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size + if add_attn_bias: + output += tl_attn_bias / mp_size + + tl.store(output_ptr + offsets, output, mask=mask) + + +def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor, + attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool, + add_attn_bias: bool, pre_attn_norm: bool): + # check that all tensors are on the same device + assert get_accelerator().on_accelerator(hidden_state) \ + and get_accelerator().on_accelerator(residual) \ + and get_accelerator().on_accelerator(attn_output) \ + and get_accelerator().on_accelerator(attn_bias) \ + and get_accelerator().on_accelerator(final_bias) + + # check that all tensors have the same dtype + assert hidden_state.dtype == residual.dtype == attn_output.dtype \ + == attn_bias.dtype == final_bias.dtype + + # check that all tensors have the right shape + assert hidden_state.shape == residual.shape == attn_output.shape + assert attn_bias.shape == final_bias.shape + assert attn_bias.shape[0] == hidden_state.shape[2] + + output = torch.empty_like(hidden_state) + + hidden_state_size = output.numel() + bias_size = attn_bias.numel() + + grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), ) + + residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,\ + attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, pre_attn_norm, \ + add_attn_bias, \ + BLOCK_SIZE=1024) + + return output diff --git a/deepspeed/ops/transformer/inference/triton/softmax.py b/deepspeed/ops/transformer/inference/triton/softmax.py new file mode 100644 index 000000000000..1ee10d63e6cf --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/softmax.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl +''' +softmax +modified the triton kernel in +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py +''' + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr): + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +@triton.jit +def masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr): + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride # mask_stride is 0 for 1d mask + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + row_minus_max = row_minus_max + mask + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + assert input.is_contiguous() + assert (dim == -1) or (dim == len(input.shape) - 1), "Only dim=-1 is supported" + + use_mask = False if mask is None else True + input_arg = input.view(-1, input.shape[-1]) + n_rows, n_cols = input_arg.shape + BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + # Allocate output + output = torch.empty_like(input) + if use_mask: + assert mask.is_contiguous() + mask = mask.view(-1, mask.shape[-1]) + mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0 + masked_softmax_kernel[(n_rows, )]( + output, + input, + input_arg.stride(0), + mask, + mask_stride, + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + else: + softmax_kernel[(n_rows, )]( + output, + input, + input_arg.stride(0), + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output diff --git a/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py b/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py new file mode 100644 index 000000000000..e2128e046df0 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py @@ -0,0 +1,398 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import triton +import triton.language as tl +from .gelu import gelu_functor +import torch + +AUTOTUNE_TOP_K = 10 +SKIP_AUTOTUNE = False + + +def _triton_ops_matmul_early_config_prune(configs, named_args): + device = torch.cuda.current_device() #ignore-cuda + capability = torch.cuda.get_device_capability() #ignore-cuda + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + + max_shared_memory = triton.runtime.driver.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + + return pruned_configs + + +def _fp16_matmul_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE): + if skip_autotune: + configs = [configs[0]] + else: + configs = _triton_ops_matmul_early_config_prune(configs, named_args) + return configs + + +""" +fp16 matmul implementation is adapted from triton matmul: +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/triton/ops/matmul.py +""" + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 256, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 256, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 128, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 32, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=5, num_warps=2), + ], + key=['CACHE_M', 'CACHE_N', 'CACHE_K'], + prune_configs_by={ + 'early_config_prune': _fp16_matmul_prune_config, + 'perf_model': None, + 'top_k': AUTOTUNE_TOP_K + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _fp_matmul( + A, + B, + C, + M, + N, + K, + bias, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + CACHE_M, + CACHE_N, + CACHE_K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + BIAS_ADD: tl.constexpr, + ACTIVATION: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K * SPLIT_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + # bias addition + if BIAS_ADD: + bias_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + bias_ptr = bias + bias_offset + b = tl.load(bias_ptr, mask=bias_offset < N) + acc = acc + b[None, :] + # activation + if ACTIVATION == "relu": + acc = tl.where(acc >= 0, acc, 0) + elif ACTIVATION == "leaky_relu": + acc = tl.where(acc >= 0, acc, 0.01 * acc) + elif ACTIVATION == "gelu": + #acc = tl.sigmoid(1.702 * acc) * acc + acc = gelu_functor(acc) + elif ACTIVATION == "sigmoid": + acc = tl.sigmoid(acc) # sigmoid + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def matmul_4d_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE): + if skip_autotune: + configs = [configs[0]] + else: + device = torch.cuda.current_device() #ignore-cuda + capability = torch.cuda.get_device_capability() #ignore-cuda + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['a_ptr'].element_size() + dtype = named_args['a_ptr'].dtype + + # make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_SIZE_M'], kw['BLOCK_SIZE_N'], kw['BLOCK_SIZE_K'], config.num_stages + + max_shared_memory = triton.runtime.driver.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + return configs + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + num_stages=1, # this is mainly for unit test, to minimize the share memory usage + num_warps=8), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=['CACHE_M', 'CACHE_N', 'CACHE_K'], + prune_configs_by={ + 'early_config_prune': matmul_4d_prune_config, + 'perf_model': None, + 'top_k': AUTOTUNE_TOP_K + }, +) +@triton.jit +def matmul_4d_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + CACHE_M, + CACHE_N, + CACHE_K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MASK: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + pid = tl.program_id(axis=0) + head = tl.program_id(axis=1) + batch = tl.program_id(axis=2) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if MASK: + if (pid_m + 1) * BLOCK_SIZE_M - 1 < pid_n * BLOCK_SIZE_N: + c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.dtype.element_ty) - float("inf") + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :]) + tl.store(c_ptrs, c) + return + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + c = c * scale.to(c_ptr.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if MASK: + c += tl.where(offs_cm[:, None] >= offs_cn[None, :], 0, float("-inf")) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py index 0c9c53ab1de1..f98f45ef638e 100644 --- a/deepspeed/ops/transformer/inference/triton_ops.py +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -4,7 +4,7 @@ # DeepSpeed Team """ Inspired by original Triton implementation: -https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py +https://github.com/openai/triton/blob/release/2.1.x/python/tutorials/06-fused-attention.py """ import torch @@ -18,7 +18,6 @@ def _fwd_kernel( K, V, sm_scale, - TMP, Out, stride_qz, stride_qh, @@ -45,63 +44,79 @@ def _fwd_kernel( ): start_m = tl.program_id(0) off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0)) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk - # Initialize pointers to Q, K, V - q_ptrs = Q + off_q - k_ptrs = K + off_k - v_ptrs = V + off_v # initialize pointer to m and l - t_ptrs = TMP + off_hz * N_CTX + offs_m m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(tl.float16) # loop over k, v and update accumulator - for start_n in range(0, N_CTX, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + start_n * stride_kn) - + lo = 0 + #hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + hi = N_CTX + #hi = (start_m + 1) * BLOCK_M + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - qk *= sm_scale - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + start_n * stride_vk) - p = p.to(tl.float16) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new + #if IS_CAUSAL: + #qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.float16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) m_i = m_i_new - # initialize pointers to output - offs_n = tl.arange(0, BLOCK_DMODEL) - off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m + acc = acc / l_i[:, None] + #l_ptrs = L + off_hz * N_CTX + offs_m + #tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + tl.store(O_block_ptr, acc.to(tl.float16)) class triton_flash_attn(torch.nn.Module): @@ -115,7 +130,6 @@ def forward(self, q, k, v, sm_scale, block_128=True): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) - tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( @@ -123,7 +137,6 @@ def forward(self, q, k, v, sm_scale, block_128=True): k, v, sm_scale, - tmp, o, q.stride(0), q.stride(1), diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index bfd4d60dcb1c..1170ae40ca43 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -335,27 +335,29 @@ def __init__(self, config, initial_weights=None, initial_biases=None): self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.init_transformer_weights(self.config.adjust_init_range) else: - # For testing only. - q = initial_weights[0].data - k = initial_weights[1].data - v = initial_weights[2].data - - self.attn_qkvw = nn.Parameter(torch.cat((q, k, v))) - #self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \ - # initial_weights[i].clone() - #torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data) - self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3)) - self.attn_qkvb.data.zero_() - self.attn_ow = initial_weights[3] - self.attn_ob = initial_biases[3] - self.attn_nw = initial_weights[4] - self.attn_nb = initial_biases[4] - self.inter_w = initial_weights[5] - self.inter_b = initial_biases[5] - self.output_w = initial_weights[6] - self.output_b = initial_biases[6] - self.norm_w = initial_weights[7] - self.norm_b = initial_biases[7] + if initial_weights is not None: + # For testing only. + q = initial_weights[0].data + k = initial_weights[1].data + v = initial_weights[2].data + + self.attn_qkvw = nn.Parameter(torch.cat((q, k, v))) + #self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \ + # initial_weights[i].clone() + #torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data) + self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3)) + self.attn_qkvb.data.zero_() + self.attn_ow = initial_weights[3] + self.attn_nw = initial_weights[4] + self.inter_w = initial_weights[5] + self.output_w = initial_weights[6] + self.norm_w = initial_weights[7] + if initial_biases is not None: + self.attn_ob = initial_biases[3] + self.attn_nb = initial_biases[4] + self.inter_b = initial_biases[5] + self.output_b = initial_biases[6] + self.norm_b = initial_biases[7] # Load cuda modules if needed global transformer_cuda_module, stochastic_transformer_cuda_module diff --git a/deepspeed/profiling/config.py b/deepspeed/profiling/config.py index 7533fc299f0e..e4f06630ea6f 100644 --- a/deepspeed/profiling/config.py +++ b/deepspeed/profiling/config.py @@ -13,6 +13,7 @@ def __init__(self, param_dict): super(DeepSpeedFlopsProfilerConfig, self).__init__() self.enabled = None + self.recompute_fwd_factor = None self.profile_step = None self.module_depth = None self.top_modules = None @@ -27,6 +28,9 @@ def __init__(self, param_dict): def _initialize(self, flops_profiler_dict): self.enabled = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED_DEFAULT) + self.recompute_fwd_factor = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_RECOMPUTE_FWD_FACTOR, + FLOPS_PROFILER_RECOMPUTE_FWD_FACTOR_DEFAULT) + self.profile_step = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_PROFILE_STEP, FLOPS_PROFILER_PROFILE_STEP_DEFAULT) diff --git a/deepspeed/profiling/constants.py b/deepspeed/profiling/constants.py index e16baea27ded..0374303d7d96 100644 --- a/deepspeed/profiling/constants.py +++ b/deepspeed/profiling/constants.py @@ -13,6 +13,7 @@ "session_params": { "flops_profiler": { "enabled": true, + "recompute_fwd_factor": 0.0, "profile_step": 1, "module_depth": -1, "top_modules": 3, @@ -27,6 +28,9 @@ FLOPS_PROFILER_ENABLED = "enabled" FLOPS_PROFILER_ENABLED_DEFAULT = False +FLOPS_PROFILER_RECOMPUTE_FWD_FACTOR = "recompute_fwd_factor" +FLOPS_PROFILER_RECOMPUTE_FWD_FACTOR_DEFAULT = 0.0 + FLOPS_PROFILER_PROFILE_STEP = "profile_step" FLOPS_PROFILER_PROFILE_STEP_DEFAULT = 1 diff --git a/deepspeed/profiling/flops_profiler/README.md b/deepspeed/profiling/flops_profiler/README.md index af23d56ee76a..68ac3dc285c7 100644 --- a/deepspeed/profiling/flops_profiler/README.md +++ b/deepspeed/profiling/flops_profiler/README.md @@ -166,6 +166,7 @@ When using DeepSpeed for model training, the profiler can be configured in the d { "flops_profiler": { "enabled": true, + "recompute_fwd_factor": 0.0, "profile_step": 1, "module_depth": -1, "top_modules": 1, @@ -177,7 +178,7 @@ When using DeepSpeed for model training, the profiler can be configured in the d #### Example: Megatron-LM -For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). +For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/Megatron-LM). An example output of 12-layer Megatron-LM model (`hidden_size = 8192, num_attention_heads = 32, batch_size = 1024, seq_length = 1024`) is shown below. diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index f39f25ce87b1..1f051077c36c 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -12,6 +12,11 @@ from collections import OrderedDict import numpy as np from deepspeed.accelerator import get_accelerator +from deepspeed.utils import logger +from deepspeed.moe.layer import MoE +from deepspeed.utils.timer import FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER +from deepspeed.utils.torch import required_torch_version +import einops Tensor = torch.Tensor @@ -19,6 +24,8 @@ module_mac_count = [] old_functions = {} +DEFAULT_PRECISION = 2 + class FlopsProfiler(object): """Measures the latency, number of estimated floating-point operations and parameters of each module in a PyTorch model. @@ -57,9 +64,10 @@ class FlopsProfiler(object): object (torch.nn.Module): The PyTorch model to profile. """ - def __init__(self, model, ds_engine=None): + def __init__(self, model, ds_engine=None, recompute_fwd_factor=0.0): self.model = model self.ds_engine = ds_engine + self.recompute_fwd_factor = recompute_fwd_factor self.started = False self.func_patched = False @@ -71,9 +79,11 @@ def start_profile(self, ignore_list=None): Args: ignore_list (list, optional): the list of modules to ignore while profiling. Defaults to None. """ + logger.info("Flops profiler started") self.reset_profile() _patch_functionals() _patch_tensor_methods() + _patch_miscellaneous_operations() def register_module_hooks(module, ignore_list): if ignore_list and type(module) in ignore_list: @@ -107,7 +117,7 @@ def start_time_hook(module, input): get_accelerator().synchronize() module.__start_time__ = time.time() - if not hasattr(module, "__start_time_hook_handle"): + if not hasattr(module, "__start_time_hook_handle__"): module.__start_time_hook_handle__ = module.register_forward_pre_hook(start_time_hook) def end_time_hook(module, input, output): @@ -129,6 +139,7 @@ def stop_profile(self): if self.started and self.func_patched: _reload_functionals() _reload_tensor_methods() + _reload_miscellaneous_operations() self.func_patched = False def remove_profile_attrs(module): @@ -156,10 +167,34 @@ def reset_profile(self): Adds or resets the extra attributes. """ + def get_param_count_and_ep(param): + """ + Return the number of parameters in the layer, whether the layer is an MoE layer, + and its expert parallelism size if so + """ + prefix = 'ep_size_' + offset = len(prefix) + expert_parallelism = 0 + if getattr(param, "group_name", "").startswith(prefix): + try: + expert_parallelism = int(param.group_name[offset:]) + except ValueError: + pass + return param.numel(), expert_parallelism, param.element_size() + def add_or_reset_attrs(module): module.__flops__ = 0 module.__macs__ = 0 - module.__params__ = sum(p.numel() for p in module.parameters()) + module.__params__ = module.__expert_params__ = module.__model_expert_params__ = 0 + parameters = (get_param_count_and_ep(p) for p in module.parameters()) + for num_params, expert_parallelism, per_param_size in parameters: + params = num_params if not expert_parallelism else 0 + expert_params = num_params if expert_parallelism else 0 + # number of expert parameters taking into account other expert parallel groups + model_expert_params = num_params * expert_parallelism + module.__params__ += params + module.__expert_params__ += expert_params + module.__model_expert_params__ += model_expert_params module.__start_time__ = 0 module.__duration__ = 0 @@ -182,12 +217,17 @@ def remove_profile_attrs(module): del module.__macs__ if hasattr(module, "__params__"): del module.__params__ + if hasattr(module, "__expert_params__"): + del module.__expert_params__ + if hasattr(module, "__model_expert_params__"): + del module.__model_expert_params__ if hasattr(module, "__start_time__"): del module.__start_time__ if hasattr(module, "__duration__"): del module.__duration__ self.model.apply(remove_profile_attrs) + logger.info("Flops profiler finished") def get_total_flops(self, as_string=False): """Returns the total flops of the model. @@ -199,7 +239,7 @@ def get_total_flops(self, as_string=False): The number of multiply-accumulate operations of the model forward pass. """ total_flops = get_module_flops(self.model) - return num_to_string(total_flops) if as_string else total_flops + return number_to_string(total_flops) if as_string else total_flops def get_total_macs(self, as_string=False): """Returns the total MACs of the model. @@ -226,15 +266,22 @@ def get_total_duration(self, as_string=False): return duration_to_string(total_duration) if as_string else total_duration def get_total_params(self, as_string=False): - """Returns the total parameters of the model. + """Returns the total number of parameters stored per rank. Args: as_string (bool, optional): whether to output the parameters as string. Defaults to False. Returns: - The number of parameters in the model. + The total number of parameters stored per rank. """ - return params_to_string(self.model.__params__) if as_string else self.model.__params__ + total_params = self.model.__expert_params__ + self.model.__params__ + return params_to_string(total_params) if as_string else total_params + + def is_expert_tensor_parallelism_enabled(self): + for _, module in self.model.named_modules(): + if isinstance(module, MoE) and hasattr(module, 'enable_expert_tensor_parallelism'): + return module.enable_expert_tensor_parallelism + return False def print_model_profile(self, profile_step=1, module_depth=-1, top_modules=1, detailed=True, output_file=None): """Prints the model graph with the measured profile attached to each module. @@ -264,6 +311,14 @@ def print_model_profile(self, profile_step=1, module_depth=-1, top_modules=1, de total_macs = self.get_total_macs() total_duration = self.get_total_duration() total_params = self.get_total_params() + expert_tensor_parallelism = None # silence the linters + total_model_expert_params = total_model_nonexpert_params = 0 + if self.ds_engine: + total_model_nonexpert_params = self.model.__params__ * self.ds_engine.mp_world_size + if self.ds_engine.has_moe_layers: + expert_tensor_parallelism = self.ds_engine.mp_world_size if self.is_expert_tensor_parallelism_enabled( + ) else 1 + total_model_expert_params = self.model.__model_expert_params__ * expert_tensor_parallelism self.flops = total_flops self.macs = total_macs @@ -271,70 +326,92 @@ def print_model_profile(self, profile_step=1, module_depth=-1, top_modules=1, de print("\n-------------------------- DeepSpeed Flops Profiler --------------------------") print(f'Profile Summary at step {profile_step}:') - print( - "Notations:\ndata parallel size (dp_size), model parallel size(mp_size),\nnumber of parameters (params), number of multiply-accumulate operations(MACs),\nnumber of floating-point operations (flops), floating-point operations per second (FLOPS),\nfwd latency (forward propagation latency), bwd latency (backward propagation latency),\nstep (weights update latency), iter latency (sum of fwd, bwd and step latency)\n" - ) + print("Notations:\n" + "data parallel size (dp_size), model parallel size(mp_size),\n" + "number of parameters (params), number of multiply-accumulate operations(MACs),\n" + "number of floating-point operations (flops), floating-point operations per second (FLOPS),\n" + "fwd latency (forward propagation latency), bwd latency (backward propagation latency),\n" + "step (weights update latency), iter latency (sum of fwd, bwd and step latency)\n") + line_fmt = '{:<70} {:<8}' if self.ds_engine: - print('{:<60} {:<8}'.format('world size: ', self.ds_engine.world_size)) - print('{:<60} {:<8}'.format('data parallel size: ', self.ds_engine.dp_world_size)) - print('{:<60} {:<8}'.format('model parallel size: ', self.ds_engine.mp_world_size)) - print('{:<60} {:<8}'.format('batch size per GPU: ', self.ds_engine.train_micro_batch_size_per_gpu())) - - print('{:<60} {:<8}'.format('params per gpu: ', params_to_string(total_params))) - print('{:<60} {:<8}'.format( - 'params of model = params per GPU * mp_size: ', - params_to_string(total_params * ((self.ds_engine.mp_world_size) if self.ds_engine else 1)))) + print(line_fmt.format('world size: ', self.ds_engine.world_size)) + print(line_fmt.format('data parallel size: ', self.ds_engine.dp_world_size)) + print(line_fmt.format('model parallel size: ', self.ds_engine.mp_world_size)) + print(line_fmt.format('batch size per GPU: ', self.ds_engine.train_micro_batch_size_per_gpu())) + if self.ds_engine.has_moe_layers: + print(line_fmt.format('expert tensor parallelism enabled: ', expert_tensor_parallelism > 1)) + + print(line_fmt.format('params per GPU: ', params_to_string(total_params))) + if total_model_expert_params > 0: + print( + line_fmt.format('params of model: ', + params_to_string(total_model_nonexpert_params + total_model_expert_params))) + print(line_fmt.format(' non-expert params of model: ', params_to_string(total_model_nonexpert_params))) + print(line_fmt.format(' expert params of model: ', params_to_string(total_model_expert_params))) + else: + print( + line_fmt.format('params of model = params per GPU * mp_size: ', + params_to_string(total_model_nonexpert_params))) - print('{:<60} {:<8}'.format('fwd MACs per GPU: ', macs_to_string(total_macs))) + print(line_fmt.format('fwd MACs per GPU: ', macs_to_string(total_macs))) - print('{:<60} {:<8}'.format('fwd flops per GPU: ', num_to_string(total_flops))) + print(line_fmt.format('fwd flops per GPU: ', number_to_string(total_flops))) - print('{:<60} {:<8}'.format( - 'fwd flops of model = fwd flops per GPU * mp_size: ', - num_to_string(total_flops * ((self.ds_engine.mp_world_size) if self.ds_engine else 1)))) + print( + line_fmt.format('fwd flops of model = fwd flops per GPU * mp_size: ', + number_to_string(total_flops * (self.ds_engine.mp_world_size if self.ds_engine else 1)))) fwd_latency = self.get_total_duration() if self.ds_engine and self.ds_engine.wall_clock_breakdown(): - fwd_latency = self.ds_engine.timers('forward').elapsed(False) / 1000.0 - print('{:<60} {:<8}'.format('fwd latency: ', duration_to_string(fwd_latency))) - print('{:<60} {:<8}'.format('fwd FLOPS per GPU = fwd flops per GPU / fwd latency: ', - flops_to_string(total_flops / fwd_latency))) + fwd_latency = self.ds_engine.timers(FORWARD_GLOBAL_TIMER).elapsed(False) / 1000.0 + print(line_fmt.format('fwd latency: ', duration_to_string(fwd_latency))) + print( + line_fmt.format('fwd FLOPS per GPU = fwd flops per GPU / fwd latency: ', + flops_to_string(total_flops / fwd_latency))) if self.ds_engine and self.ds_engine.wall_clock_breakdown(): - bwd_latency = self.ds_engine.timers('backward').elapsed(False) / 1000.0 - step_latency = self.ds_engine.timers('step').elapsed(False) / 1000.0 - print('{:<60} {:<8}'.format('bwd latency: ', duration_to_string(bwd_latency))) - print('{:<60} {:<8}'.format('bwd FLOPS per GPU = 2 * fwd flops per GPU / bwd latency: ', - flops_to_string(2 * total_flops / bwd_latency))) - print('{:<60} {:<8}'.format('fwd+bwd FLOPS per GPU = 3 * fwd flops per GPU / (fwd+bwd latency): ', - flops_to_string(3 * total_flops / (fwd_latency + bwd_latency)))) + bwd_factor = 2 + self.recompute_fwd_factor + bwd_latency = self.ds_engine.timers(BACKWARD_GLOBAL_TIMER).elapsed(False) / 1000.0 + step_latency = self.ds_engine.timers(STEP_GLOBAL_TIMER).elapsed(False) / 1000.0 + print(line_fmt.format('bwd latency: ', duration_to_string(bwd_latency))) + print( + line_fmt.format(f'bwd FLOPS per GPU = {bwd_factor:g} * fwd flops per GPU / bwd latency: ', + flops_to_string(bwd_factor * total_flops / bwd_latency))) + print( + line_fmt.format( + f'fwd+bwd FLOPS per GPU = {bwd_factor + 1:g} * fwd flops per GPU / (fwd+bwd latency): ', + flops_to_string((bwd_factor + 1) * total_flops / (fwd_latency + bwd_latency)))) - print('{:<60} {:<8}'.format('step latency: ', duration_to_string(step_latency))) + print(line_fmt.format('step latency: ', duration_to_string(step_latency))) iter_latency = fwd_latency + bwd_latency + step_latency - print('{:<60} {:<8}'.format('iter latency: ', duration_to_string(iter_latency))) - print('{:<60} {:<8}'.format('FLOPS per GPU = 3 * fwd flops per GPU / iter latency: ', - flops_to_string(3 * total_flops / iter_latency))) + print(line_fmt.format('iter latency: ', duration_to_string(iter_latency))) + print( + line_fmt.format(f'FLOPS per GPU = {bwd_factor + 1:g} * fwd flops per GPU / iter latency: ', + flops_to_string((bwd_factor + 1) * total_flops / iter_latency))) samples_per_iter = self.ds_engine.train_micro_batch_size_per_gpu() * self.ds_engine.world_size - print('{:<60} {:<8.2f}'.format('samples/second: ', samples_per_iter / iter_latency)) + print(line_fmt.format('samples/second: ', round(samples_per_iter / iter_latency, DEFAULT_PRECISION))) def flops_repr(module): - params = module.__params__ + params = module.__params__ + module.__expert_params__ flops = get_module_flops(module) macs = get_module_macs(module) + duration = get_module_duration(module) items = [ - params_to_string(params), - "{:.2%} Params".format(params / total_params if total_params else 0), - macs_to_string(macs), - "{:.2%} MACs".format(0.0 if total_macs == 0 else macs / total_macs), + "{} = {:g}% Params".format( + params_to_string(params), + round(100 * params / total_params, DEFAULT_PRECISION) if total_params else 0), + "{} = {:g}% MACs".format(macs_to_string(macs), + round(100 * macs / total_macs, DEFAULT_PRECISION) if total_macs else 0), + "{} = {:g}% latency".format( + duration_to_string(duration), + round(100 * duration / total_duration, DEFAULT_PRECISION) if total_duration else 0), + flops_to_string(round(flops / duration, DEFAULT_PRECISION) if duration else 0), ] - duration = get_module_duration(module) - - items.append(duration_to_string(duration)) - items.append("{:.2%} latency".format(0.0 if total_duration == 0 else duration / total_duration)) - items.append(flops_to_string(0.0 if duration == 0 else flops / duration)) - items.append(module.original_extra_repr()) + original_extra_repr = module.original_extra_repr() + if original_extra_repr: + items.append(original_extra_repr) return ", ".join(items) def add_extra_repr(module): @@ -394,7 +471,7 @@ def walk_module(module, curr_depth, info): 0, ] # macs, params, time info[curr_depth][module.__class__.__name__][0] += get_module_macs(module) - info[curr_depth][module.__class__.__name__][1] += module.__params__ + info[curr_depth][module.__class__.__name__][1] += module.__params__ + module.__expert_params__ info[curr_depth][module.__class__.__name__][2] += get_module_duration(module) has_children = len(module._modules.items()) != 0 if has_children: @@ -495,9 +572,20 @@ def _conv_flops_compute(input, weight, bias=None, stride=1, padding=0, dilation= length = len(input_dims) - paddings = padding if type(padding) is tuple else (padding, ) * length strides = stride if type(stride) is tuple else (stride, ) * length dilations = dilation if type(dilation) is tuple else (dilation, ) * length + if isinstance(padding, str): + if padding == 'valid': + paddings = (0, ) * length + elif padding == 'same': + paddings = () + for d, k in zip(dilations, kernel_dims): + total_padding = d * (k - 1) + paddings += (total_padding // 2, ) + elif isinstance(padding, tuple): + paddings = padding + else: + paddings = (padding, ) * length output_dims = [] for idx, input_dim in enumerate(input_dims): @@ -530,7 +618,7 @@ def _conv_trans_flops_compute( ): batch_size = input.shape[0] in_channels = input.shape[1] - out_channels = weight.shape[0] + out_channels = weight.shape[1] kernel_dims = list(weight.shape[2:]) input_dims = list(input.shape[2:]) @@ -619,21 +707,33 @@ def _instance_norm_flops_compute( return input.numel() * (5 if has_affine else 4), 0 -def _upsample_flops_compute(input, **kwargs): +def _upsample_flops_compute(*args, **kwargs): + input = args[0] size = kwargs.get('size', None) + if size is None and len(args) > 1: + size = args[1] + if size is not None: if isinstance(size, tuple) or isinstance(size, list): return int(_prod(size)), 0 else: return int(size), 0 + scale_factor = kwargs.get('scale_factor', None) + if scale_factor is None and len(args) > 2: + scale_factor = args[2] assert scale_factor is not None, "either size or scale_factor should be defined" + flops = input.numel() - if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): - flops * int(_prod(scale_factor)) + if isinstance(scale_factor, (list, tuple)): + # see documentation of `F.interpolate` + # the spatial dims are defined as the last `n-2` dims of the tensor + assert len(scale_factor) == input.ndim - 2 + flops *= _prod(scale_factor) else: - flops * scale_factor**len(input) - return flops, 0 + flops *= scale_factor**(input.ndim - 2) + + return int(flops), 0 def _softmax_flops_compute(input, dim=None, _stacklevel=3, dtype=None): @@ -694,6 +794,29 @@ def _einsum_flops_compute(equation, *operands): raise NotImplementedError("Unsupported einsum operation.") +def _einops_einsum_flops_compute(*args): + """ + Count flops for the einops.einsum operation. + """ + *operands, equation = args + input_shapes = [o.shape for o in operands] + + # Re-map equation so that same equation with different alphabet + # representations will look the same. + letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() + mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} + equation = equation.translate(mapping) + + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + flop = int(float(line.split(":")[-1])) + return flop, 0 + + raise NotImplementedError("Unsupported einops.einsum operation.") + + def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None): """ Count flops for the tensor addmm operation. @@ -735,6 +858,15 @@ def _elementwise_flops_compute(input, other): return flops, 0 +def _attn_flops_compute(q, k, v, *args, **kwargs): + """ + Count flops for the scaled_dot_product_attention operation. + """ + macs = _prod(q.shape) * k.shape[-2] + macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1] + return 2 * macs, macs + + def wrapFunc(func, funcFlopCompute): oldFunc = func name = func.__str__ @@ -807,10 +939,15 @@ def _patch_functionals(): # embedding F.embedding = wrapFunc(F.embedding, _embedding_flops_compute) + # attn - scaled_dot_product_attention added in torch 2.0+ + if required_torch_version(min_version=2.0): + F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute) + def _patch_tensor_methods(): torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute) torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute) + torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute) torch.mm = wrapFunc(torch.mm, _matmul_flops_compute) torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute) torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) @@ -830,6 +967,10 @@ def _patch_tensor_methods(): torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute) +def _patch_miscellaneous_operations(): + einops.einsum = wrapFunc(einops.einsum, _einops_einsum_flops_compute) + + def _reload_functionals(): # torch.nn.functional does not support importlib.reload() F.linear = old_functions[F.linear.__str__] @@ -888,11 +1029,16 @@ def _reload_tensor_methods(): torch.baddbmm = old_functions[torch.baddbmm.__str__] +def _reload_miscellaneous_operations(): + einops.einsum = old_functions[einops.einsum.__str__] + + def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): + gates_size = w_ih.shape[0] # matrix matrix mult ih state and internal state - flops += w_ih.shape[0] * w_ih.shape[1] + flops += 2 * w_ih.shape[0] * w_ih.shape[1] - gates_size # matrix matrix mult hh state and internal state - flops += w_hh.shape[0] * w_hh.shape[1] + flops += 2 * w_hh.shape[0] * w_hh.shape[1] - gates_size if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): # add both operations flops += rnn_module.hidden_size @@ -969,118 +1115,59 @@ def _rnn_cell_forward_hook(rnn_cell_module, input, output): } -def num_to_string(num, precision=2): - if num // 10**9 > 0: - return str(round(num / 10.0**9, precision)) + " G" - elif num // 10**6 > 0: - return str(round(num / 10.0**6, precision)) + " M" - elif num // 10**3 > 0: - return str(round(num / 10.0**3, precision)) + " K" - else: - return str(num) - - -def macs_to_string(macs, units=None, precision=2): - if units is None: - if macs // 10**9 > 0: - return str(round(macs / 10.0**9, precision)) + " GMACs" - elif macs // 10**6 > 0: - return str(round(macs / 10.0**6, precision)) + " MMACs" - elif macs // 10**3 > 0: - return str(round(macs / 10.0**3, precision)) + " KMACs" - else: - return str(macs) + " MACs" - else: - if units == "GMACs": - return str(round(macs / 10.0**9, precision)) + " " + units - elif units == "MMACs": - return str(round(macs / 10.0**6, precision)) + " " + units - elif units == "KMACs": - return str(round(macs / 10.0**3, precision)) + " " + units - else: - return str(macs) + " MACs" +def macs_to_string(macs, units=None, precision=DEFAULT_PRECISION): + return f"{number_to_string(macs, units=units, precision=precision)}MACs" -def number_to_string(num, units=None, precision=2): +def number_to_string(num, units=None, precision=DEFAULT_PRECISION): if units is None: - if num // 10**9 > 0: - return str(round(num / 10.0**9, precision)) + " G" - elif num // 10**6 > 0: - return str(round(num / 10.0**6, precision)) + " M" - elif num // 10**3 > 0: - return str(round(num / 10.0**3, precision)) + " K" + if num >= 1e12: + magnitude, units = 1e12, "T" + elif num >= 1e9: + magnitude, units = 1e9, "G" + elif num >= 1e6: + magnitude, units = 1e6, "M" + elif num >= 1e3: + magnitude, units = 1e3, "K" + elif num >= 1 or num == 0: + magnitude, units = 1, "" + elif num >= 1e-3: + magnitude, units = 1e-3, "m" else: - return str(num) + " " + magnitude, units = 1e-6, "u" else: - if units == "G": - return str(round(num / 10.0**9, precision)) + " " + units + if units == "T": + magnitude = 1e12 + elif units == "G": + magnitude = 1e9 elif units == "M": - return str(round(num / 10.0**6, precision)) + " " + units + magnitude = 1e6 elif units == "K": - return str(round(num / 10.0**3, precision)) + " " + units + magnitude = 1e3 + elif units == "m": + magnitude = 1e-3 + elif units == "u": + magnitude = 1e-6 else: - return str(num) + " " + magnitude = 1 + return f"{round(num / magnitude, precision):g} {units}" -def flops_to_string(flops, units=None, precision=2): - if units is None: - if flops // 10**12 > 0: - return str(round(flops / 10.0**12, precision)) + " TFLOPS" - if flops // 10**9 > 0: - return str(round(flops / 10.0**9, precision)) + " GFLOPS" - elif flops // 10**6 > 0: - return str(round(flops / 10.0**6, precision)) + " MFLOPS" - elif flops // 10**3 > 0: - return str(round(flops / 10.0**3, precision)) + " KFLOPS" - else: - return str(flops) + " FLOPS" - else: - if units == "TFLOPS": - return str(round(flops / 10.0**12, precision)) + " " + units - if units == "GFLOPS": - return str(round(flops / 10.0**9, precision)) + " " + units - elif units == "MFLOPS": - return str(round(flops / 10.0**6, precision)) + " " + units - elif units == "KFLOPS": - return str(round(flops / 10.0**3, precision)) + " " + units - else: - return str(flops) + " FLOPS" +def flops_to_string(flops, units=None, precision=DEFAULT_PRECISION): + return f"{number_to_string(flops, units=units, precision=precision)}FLOPS" -def params_to_string(params_num, units=None, precision=2): - if units is None: - if params_num // 10**6 > 0: - return str(round(params_num / 10**6, 2)) + " M" - elif params_num // 10**3: - return str(round(params_num / 10**3, 2)) + " k" - else: - return str(params_num) - else: - if units == "M": - return str(round(params_num / 10.0**6, precision)) + " " + units - elif units == "K": - return str(round(params_num / 10.0**3, precision)) + " " + units - else: - return str(params_num) +def bytes_to_string(b, units=None, precision=DEFAULT_PRECISION): + return f"{number_to_string(b, units=units, precision=precision)}B" -def duration_to_string(duration, units=None, precision=2): - if units is None: - if duration > 1: - return str(round(duration, precision)) + " s" - elif duration * 10**3 > 1: - return str(round(duration * 10**3, precision)) + " ms" - elif duration * 10**6 > 1: - return str(round(duration * 10**6, precision)) + " us" - else: - return str(duration) - else: - if units == "us": - return str(round(duration * 10.0**6, precision)) + " " + units - elif units == "ms": - return str(round(duration * 10.0**3, precision)) + " " + units - else: - return str(round(duration, precision)) + " s" +def params_to_string(params_num, units=None, precision=DEFAULT_PRECISION): + units = units.replace("B", "G") if units else units + return number_to_string(params_num, units=units, precision=precision).replace("G", "B").strip() + + +def duration_to_string(duration, units=None, precision=DEFAULT_PRECISION): + return f"{number_to_string(duration, units=units, precision=precision)}s" # can not iterate over all submodules using self.model.modules() @@ -1105,24 +1192,23 @@ def get_module_duration(module): duration = module.__duration__ if duration == 0: # e.g. ModuleList for m in module.children(): - duration += m.__duration__ + duration += get_module_duration(m) return duration -def get_model_profile( - model, - input_shape=None, - args=[], - kwargs={}, - print_profile=True, - detailed=True, - module_depth=-1, - top_modules=1, - warm_up=1, - as_string=True, - output_file=None, - ignore_modules=None, -): +def get_model_profile(model, + input_shape=None, + args=[], + kwargs={}, + print_profile=True, + detailed=True, + module_depth=-1, + top_modules=1, + warm_up=1, + as_string=True, + output_file=None, + ignore_modules=None, + mode='forward'): """Returns the total floating-point operations, MACs, and parameters of a model. Example: @@ -1169,17 +1255,30 @@ def get_model_profile( args = [input] assert (len(args) > 0) or (len(kwargs) > 0), "args and/or kwargs must be specified if input_shape is None" + logger.info("Flops profiler warming-up...") for _ in range(warm_up): if kwargs: - _ = model(*args, **kwargs) + if mode == 'forward': + _ = model(*args, **kwargs) + if mode == 'generate': + _ = model.generate(*args, **kwargs) else: - _ = model(*args) + if mode == 'forward': + _ = model(*args) + if mode == 'generate': + _ = model.generate(*args) prof.start_profile(ignore_list=ignore_modules) if kwargs: - _ = model(*args, **kwargs) + if mode == 'forward': + _ = model(*args, **kwargs) + if mode == 'generate': + _ = model.generate(*args, **kwargs) else: - _ = model(*args) + if mode == 'forward': + _ = model(*args) + if mode == 'generate': + _ = model.generate(*args) flops = prof.get_total_flops() macs = prof.get_total_macs() diff --git a/deepspeed/runtime/__init__.py b/deepspeed/runtime/__init__.py index 347ff7993d82..208299fb8c50 100644 --- a/deepspeed/runtime/__init__.py +++ b/deepspeed/runtime/__init__.py @@ -2,11 +2,3 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - - -class DeepSpeedOptimizer(object): - pass - - -class ZeROOptimizer(DeepSpeedOptimizer): - pass diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index c305ade64621..fae0148ba887 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -18,23 +18,28 @@ import torch import contextlib from deepspeed import comm as dist +import weakref import mmap from torch import _C from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.utils import logger -from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank -from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers +from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage +from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.accelerator import get_accelerator +from deepspeed.runtime import compiler # DeepSpeed Checkpointing Enabled or Disabled deepspeed_checkpointing_enabled = False # MP parameters mpu = None -mp_rank = None -mp_size = None + +#set default values +mp_rank = 0 +mp_size = 1 mp_group = None # Model Parameters @@ -58,8 +63,6 @@ # Default name for the model parallel rng tracker. _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' -transport_stream = None -cuda_device = None def detach_variable(inputs, device=None): @@ -236,6 +239,14 @@ def model_parallel_cuda_manual_seed(seed): _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) +def model_parallel_reconfigure_tp_seed(seed): + global mpu + tp_rank = bwc_tensor_model_parallel_rank(mpu) + model_parallel_seed = seed + 2718 + tp_rank + with _CUDA_RNG_STATE_TRACKER.fork(): + get_accelerator().manual_seed(model_parallel_seed) + + def get_partition_start(item): global mp_rank, mp_size, mp_group size = item.numel() @@ -269,6 +280,8 @@ def gather_partitioned_activations(tensors, device=None): # don't need to do all_gather if model parallel is not enabled if mp_group is None or mp_size == 1: item = item.view(list(size.numpy())) + if device is not None: + item = item.to(device) inputs.append(item) continue @@ -278,13 +291,9 @@ def gather_partitioned_activations(tensors, device=None): flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device) else: flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device) - partitions = [] - for i in range(mp_size): - part_i = flat_tensor.narrow(0, partition_size * i, partition_size) - if i == mp_rank: - part_i.copy_(item) - partitions.append(part_i) - dist.all_gather(partitions, partitions[mp_rank], group=mp_group) + part = flat_tensor.narrow(0, partition_size * mp_rank, partition_size) + part.copy_(item) + dist.all_gather_into_tensor(flat_tensor, part, group=mp_group) input_tensor = flat_tensor.view(list(size.numpy())) item.data = input_tensor.data @@ -360,7 +369,9 @@ def is_activation_to_checkpoint(item): Is an activation to be checkpointed """ global mp_size - return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size + extra_flag = (not hasattr(item, 'no_checkpointing')) or (hasattr(item, 'no_checkpointing') + and item.no_checkpointing == False) + return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size and extra_flag def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): @@ -432,7 +443,9 @@ def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint num_non_fp_tensors += 1 continue - arg.data = inp.data + arg.data = torch.empty([], device=arg.device).data + arg.saved_data = inp.data + new_args.append(arg) i = arg_index - num_non_fp_tensors @@ -465,7 +478,8 @@ def get_cpu_activations_for_backward(args, inputs): new_args.append(arg) continue - arg.data = inp.data + arg.data = torch.empty([], device=arg.device).data + arg.saved_data = inp.data new_args.append(arg) return new_args @@ -499,42 +513,17 @@ def save_args_for_backward(*all_args): timers = Timers() if PROFILE_TIME: - timers('forward').start() + timers(FORWARD_GLOBAL_TIMER).start() ctx.run_function = run_function global num_layers global mp_rank, mp_size, mp_group global contiguous_data_buffers, contiguous_size_buffers global data_offsets, size_offsets - if mp_rank is None: - if mpu is not None: - if hasattr(mpu, 'get_tensor_model_parallel_rank'): - mp_rank = mpu.get_tensor_model_parallel_rank() - mp_size = mpu.get_tensor_model_parallel_world_size() - mp_group = mpu.get_tensor_model_parallel_group() - else: - mp_rank = mpu.get_model_parallel_rank() - mp_size = mpu.get_model_parallel_world_size() - mp_group = mpu.get_model_parallel_group() - else: - mp_rank = 0 - mp_size = 1 - mp_group = None - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset - - if cuda_device is None: - see_memory_usage("First Forward Beginning", force=False) - if dist.get_rank() == 0: - logger.info(f"Activation Checkpointing Information") - logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}") - logger.info( - f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers") - logger.info(f"----Synchronization {SYNCHRONIZE}") - logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") + global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset - cuda_device = get_accelerator().current_device_name() - transport_stream = get_accelerator().Stream(device=cuda_device) + cuda_device = get_accelerator().current_device_name() + transport_stream = get_accelerator().Stream(device=cuda_device) if PARTITION_ACTIVATIONS: inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING) @@ -568,8 +557,8 @@ def save_args_for_backward(*all_args): save_args_for_backward(*args) if PROFILE_TIME: - timers('forward').stop() - timers.log(['forward']) + timers(FORWARD_GLOBAL_TIMER).stop() + timers.log([FORWARD_GLOBAL_TIMER]) if SYNCHRONIZE: get_accelerator().synchronize() @@ -619,7 +608,14 @@ def backward(ctx, *grads): raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible") - global cuda_device, transport_stream, PARTITION_ACTIVATIONS + global PARTITION_ACTIVATIONS + cuda_device = get_accelerator().current_device_name() + transport_stream = get_accelerator().Stream(device=cuda_device) + # Rebuild deepspeed_saved_tensors + for t in ctx.deepspeed_saved_tensors: + if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None: + t.data = t.saved_data.to(t.device) + t.saved_data = None if PARTITION_ACTIVATIONS: # with get_accelerator().stream(transport_stream): @@ -705,6 +701,250 @@ def backward(ctx, *grads): return tuple(ret_list) +def non_reentrant_checkpoint(function, *args): + """This function is union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module + + This function is aim to solve the back probagation error raised from all input requires no grad. + * has already been implemented in pytorch for a while, the solution is stable at most time except for jit module mode. + * can help to solve the issue which is hacked by `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable` + + Main modifications compared to the implementation of torch: + 1. adapt to the signature of `checkpoint` function in this module + 2. solve the non-deterministic by random state management consistent with deepspeed `CheckpointFunction` + 3. when there is partition or cpu checkpointing, gather them in the unpack_hook during back probagation + 4. make all after backward blocks in the hook which will executed after all leaf nodes backward execution. + 5. above 4. is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only implemented after 2.0.0 + """ + global mpu, timers, SYNCHRONIZE, PROFILE_TIME + + deepspeed_saved_tensors = None + non_tensor_args = None + tensor_flags = None + + def save_args_for_backward(*all_args): + """keep this function to reduce the modification from original implementation""" + nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags + tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args) + deepspeed_saved_tensors = tensor_args + non_tensor_args = non_tensor_args + tensor_flags = tensor_flags + + if SYNCHRONIZE: + get_accelerator().synchronize() + + if timers is None and PROFILE_TIME: + timers = Timers() + + if PROFILE_TIME: + timers(FORWARD_GLOBAL_TIMER).start() + + global num_layers + global mp_rank, mp_size, mp_group + global contiguous_data_buffers, contiguous_size_buffers + global data_offsets, size_offsets + global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset + + cuda_device = get_accelerator().current_device_name() + transport_stream = get_accelerator().Stream(device=cuda_device) + + if PARTITION_ACTIVATIONS: + inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING) + elif CPU_CHECKPOINT: + inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint) + + # just in case something funky is happening such as reuse of inputs + inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint) + + # Copy the rng states. + fwd_cpu_rng_state = torch.get_rng_state() + fwd_cuda_rng_state = get_accelerator().get_rng_state() + fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + if PARTITION_ACTIVATIONS: + new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING) + assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}' + save_args_for_backward(*new_args) + elif CPU_CHECKPOINT: + new_args = get_cpu_activations_for_backward(args, inputs) + save_args_for_backward(*new_args) + else: + save_args_for_backward(*args) + + class Holder(): + """the place holder object used as activations to save memory""" + pass + + # weakref seems utilized to discover the tensor deletion before a whole + # forward backward pair loop finished + storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + weak_holder_list = [] + leaf_tensors = [] + backward_visited_leaf_nodes = 0 + + def checkpoint_pack(tensor_from_forward): + """used to record the activation order in the `weak_holder_list` + + the activation order in holder list is consistent between the first forward and recomputing forward. + * the jit compiled forward will break the order consistency * + """ + res = Holder() + weak_holder_list.append(weakref.ref(res)) + + # if this is a leaf tensor, save it for backward progression trace + # leaf tensor used to be input or parameters, which is not activations and + # has no memory overhead + if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf: + leaf_tensors.append(tensor_from_forward) + return res + + def checkpoint_unpack(holder_from_backward): + """retrieve the activations from recompute""" + nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags + + # if this is the first step of backward probagation, recompute the graph and save + # all the activations with the same order as `checkpoint_pack` does + if len(storage) == 0: + unpack_counter = 0 + + def replay_pack(tensor_from_replay): + """save recompute activations""" + nonlocal unpack_counter + unpack_counter += 1 + + if weak_holder_list[unpack_counter - 1]() is None: + return + + detached_activations = tensor_from_replay.detach() + storage[weak_holder_list[unpack_counter - 1]()] = detached_activations + + return + + def replay_unpack(none_value): + """recompute graph need not to backward""" + raise RuntimeError("You are calling backwards on a tensor that is never exposed.") + + global timers + see_memory_usage("In backward", force=False) + # removing pointers to the contiguous buffer memory + # so that they can be garbage collected once the checkpoints + # have been used + if SYNCHRONIZE: + get_accelerator().synchronize() + if PROFILE_TIME: + timers('backward').start() + + if CONTIGUOUS_CHECKPOINTING: + global data_offsets, size_offsets + global contiguous_data_buffers, contiguous_size_buffers + + for buffers in contiguous_data_buffers: + buffers = [] + + # frees up all the pointers to the checkpoints except for the ones + # stored by save for backward + contiguous_data_buffers = [] + contiguous_size_buffers = [] + data_offsets = [] + size_offsets = [] + + see_memory_usage("In backward checkpointing code", force=False) + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), " + "please use .backward() if possible") + + global PARTITION_ACTIVATIONS + cuda_device = get_accelerator().current_device_name() + transport_stream = get_accelerator().Stream(device=cuda_device) + + # gather inputs which is partitioned or checkpointed before first forward + if PARTITION_ACTIVATIONS: + # with get_accelerator().stream(transport_stream): + inputs = gather_partitioned_activations(deepspeed_saved_tensors, + device=cuda_device if CPU_CHECKPOINT else None) + detached_inputs = detach_variable(inputs) + elif CPU_CHECKPOINT: + inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint) + detached_inputs = detach_variable(inputs) + else: + inputs = deepspeed_saved_tensors + detached_inputs = detach_variable(inputs) + + # Add non tensor input args + detached_inputs = merge_tensors(tensor_objects=detached_inputs, + non_tensor_objects=non_tensor_args, + tensor_flags=tensor_flags) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = get_accelerator().get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(fwd_cpu_rng_state) + _set_cuda_rng_state(fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker) + + see_memory_usage("In backward checkpointing code before forward", force=False) + with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack): + _unused = function(*detached_inputs) + + see_memory_usage("In backward checkpointing code after forward", force=False) + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + deepspeed_saved_tensors = None + non_tensor_args = None + tensor_flags = None + + if holder_from_backward not in storage: + raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported.") + + return storage[holder_from_backward] + + def after_backward_hook(_nonuse_grads): + """the hook registered to all leaf tensors""" + nonlocal leaf_tensors, backward_visited_leaf_nodes + backward_visited_leaf_nodes += 1 + + if backward_visited_leaf_nodes == len(leaf_tensors): + see_memory_usage("After backward checkpointing code after backward", force=False) + + if PROFILE_TIME: + timers('backward').stop() + timers.log(['backward']) + if SYNCHRONIZE: + get_accelerator().synchronize() + + with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack): + outputs = function(*inputs_cuda) + if PROFILE_TIME or SYNCHRONIZE: + for leaf_tensor in leaf_tensors: + leaf_tensor.register_hook(after_backward_hook) + + see_memory_usage("After running forward on the layer", force=False) + + if PROFILE_TIME: + timers(FORWARD_GLOBAL_TIMER).stop() + timers.log([FORWARD_GLOBAL_TIMER]) + if SYNCHRONIZE: + get_accelerator().synchronize() + + all_outputs = [] + if torch.is_tensor(outputs): + all_outputs += [outputs] + else: + all_outputs += outputs + + if len(all_outputs) == 1: + return all_outputs[0] + else: + return tuple(all_outputs) + + +@compiler.disable # WA from Pytorch repo for compile + zero 3 accuracy issue def checkpoint(function, *args): """Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint. """ @@ -867,6 +1107,27 @@ def configure( if CONTIGUOUS_CHECKPOINTING: assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing" + global mp_rank, mp_size, mp_group + + if mpu is not None: + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + mp_rank = mpu.get_tensor_model_parallel_rank() + mp_size = mpu.get_tensor_model_parallel_world_size() + mp_group = mpu.get_tensor_model_parallel_group() + else: + mp_rank = mpu.get_model_parallel_rank() + mp_size = mpu.get_model_parallel_world_size() + mp_group = mpu.get_model_parallel_group() + + #print configuration only once + see_memory_usage("After configuration", force=False) + if dist.get_rank() == 0: + logger.info("Activation Checkpointing Information") + logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}") + logger.info(f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers") + logger.info(f"----Synchronization {SYNCHRONIZE}") + logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") + def is_configured(): """True if deepspeed activation checkpointing has been configured diff --git a/deepspeed/runtime/activation_checkpointing/config.py b/deepspeed/runtime/activation_checkpointing/config.py index 13df4b981298..dc07388a95da 100755 --- a/deepspeed/runtime/activation_checkpointing/config.py +++ b/deepspeed/runtime/activation_checkpointing/config.py @@ -17,7 +17,7 @@ "partitioned_activations": [true|false], "number_checkpoints": 100, "contiguous_memory_optimization": [true|false], - "cpu_checkpointing": [true|false] + "cpu_checkpointing": [true|false], "profile": [true|false], "synchronize_checkpoint_boundary": [true|false], } diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py new file mode 100644 index 000000000000..c9dbfd0a4e81 --- /dev/null +++ b/deepspeed/runtime/base_optimizer.py @@ -0,0 +1,491 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch +from typing import Any + +from deepspeed.utils import logger +from deepspeed.utils.tensor_fragment import map_to_flat_opt_states +from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage +from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized +from deepspeed.runtime.utils import maybe_loss_for_backward + + +class DeepSpeedOptimizer(object): + pass + + +class BackwardHookStateManager: + """Manages backward pass state for ZeRO optimizers. + + This class handles the complex state management needed for gradient accumulation hooks + to work correctly with: + + 1. **Reentrant Gradient Checkpointing** (use_reentrant=True): + With reentrant checkpointing, gradient hooks fire in multiple phases within a + single backward() call. For example, with model: linear1 (checkpointed) -> linear2: + - Phase 1: Hooks for linear2 fire (non-checkpointed params) + - Checkpoint recomputes linear1's forward + - Phase 2: Hooks for linear1 fire (checkpointed params) + + The challenge is that `count_used_parameters_in_backward()` only sees params + currently in the backward graph. During Phase 1, it returns 2 (linear2's params), + but after checkpoint recomputation, it returns 4 (all params). We must NOT run + the epilogue prematurely after Phase 1. + + Solution: Queue a post-backward callback on the autograd engine at the start of + backward and run the epilogue when the graph task completes. This avoids premature + epilogues across reentrant phases. The `_max_expected_hooks_seen` counter remains + as a fallback when the callback API is unavailable. + + 2. **TiledFusedLogitsLoss and Similar Custom Autograd Functions**: + Some custom autograd functions call `torch.autograd.backward()` from their + forward pass BEFORE the user calls `engine.backward(loss)`. These internal + backward calls trigger ZeRO's gradient hooks, but we must NOT run the epilogue + until the user's actual backward pass. + + Solution: Track `_backward_active_depth` which is only incremented when + `enter_backward()` is called (from engine.backward or user code). Hooks check + this depth before running the epilogue. + + 3. **Multiple Backward Phases with Exit/Re-entry**: + When the epilogue runs after Phase 1 (with reentrant checkpointing), it calls + `exit_backward()`, setting `_backward_active_depth` to 0. When Phase 2's hooks + fire, we need to re-enter the backward context. + + Solution: `_backward_seen_this_step` flag tracks if backward was ever active + this step. Combined with `_backward_active_depth == 0`, this detects Phase 2 + and calls `enter_backward()` again. + + Attributes: + remaining_grad_acc_hooks: Count of hooks remaining before epilogue should run + backward_active_depth: Nesting depth of backward() calls (0 = not in backward) + backward_seen_this_step: True if enter_backward() was called this step + epilogue_ran_this_backward: True if epilogue ran (for micro_step_id management) + hooks_fired_this_backward: Count of gradient hooks that have fired + max_expected_hooks_seen: Maximum expected hook count seen (grows with reentrant) + post_backward_callback_queued: True if a post-backward callback is queued + post_backward_callback_graph_task_id: Graph task id for the queued callback + """ + + def __init__(self): + self.remaining_grad_acc_hooks = 0 + self._grad_acc_post_hooks = [] + self.backward_active_depth = 0 + self.backward_seen_this_step = False + self.epilogue_ran_this_backward = False + self.hooks_fired_this_backward = 0 + self.max_expected_hooks_seen = 0 + self.post_backward_callback_queued = False + self.post_backward_callback_graph_task_id = None + + def register_grad_acc_post_hook(self, hook): + """Register a callback to run when all gradient hooks have fired.""" + self._grad_acc_post_hooks.append(hook) + + def unregister_grad_acc_post_hooks(self): + """Remove all registered gradient accumulation post hooks.""" + self._grad_acc_post_hooks = [] + + def run_grad_acc_post_hooks(self): + """Run all registered post hooks if backward is active. + + Custom autograd Functions (e.g., TiledFusedLogitsLoss) can invoke + `torch.autograd.backward()` from their *forward* pass before the user + ever calls `engine.backward(loss)`. Those early backward calls still + trigger ZeRO's grad hooks, but we must not run the engine's + post-backward logic (which reduces/clears grads) until the outer/user + backward is active. The depth guard filters out only those pre-user + invocations while still allowing backward calls that happen during + the real user backward. + """ + if self.backward_active_depth == 0: + return + for hook in self._grad_acc_post_hooks: + hook() + + def enter_backward(self): + """Enter backward context. Call at the start of backward pass.""" + # On first real backward entry of a step, reset counters that may have been + # polluted by pre-user-backward hooks (e.g. TiledFusedLogitsLoss calling + # torch.autograd.backward() from forward). Do NOT reset on reentrant + # phase re-entry (backward_seen_this_step == True) so phase-to-phase + # state remains intact. + if self.backward_active_depth == 0 and not self.backward_seen_this_step: + self.hooks_fired_this_backward = 0 + self.max_expected_hooks_seen = 0 + self.remaining_grad_acc_hooks = 0 + self.post_backward_callback_queued = False + self.post_backward_callback_graph_task_id = None + self.backward_active_depth += 1 + # Track that backward has been active at some point in this step. + # This is used to detect subsequent gradient hook phases with reentrant checkpointing. + self.backward_seen_this_step = True + + def exit_backward(self): + """Exit backward context. Call at the end of backward pass.""" + if self.backward_active_depth > 0: + self.backward_active_depth -= 1 + + def reset_for_new_step(self): + """Reset state at the start of each forward/backward step.""" + self.backward_seen_this_step = False + self.hooks_fired_this_backward = 0 + self.max_expected_hooks_seen = 0 + self.epilogue_ran_this_backward = False + self.post_backward_callback_queued = False + self.post_backward_callback_graph_task_id = None + + def should_refresh_expected_hook_count(self): + """Return True when count_used_parameters_in_backward() should be re-evaluated. + + Refresh is needed in two cases: + 1. First hook of a backward (or backward phase): hooks_fired == 0. + 2. A new reentrant phase started: remaining hooks exhausted, we exited + backward, but backward was active earlier this step. + + The predicate must be evaluated BEFORE reenter_backward_if_needed() + because re-entering changes backward_active_depth and hides the + phase-boundary signal. + """ + return (self.hooks_fired_this_backward == 0 + or (self.remaining_grad_acc_hooks == 0 and self.backward_active_depth == 0 + and self.backward_seen_this_step)) + + def reenter_backward_if_needed(self): + """Re-enter backward context for subsequent phases in reentrant checkpointing. + + With reentrant gradient checkpointing, gradient hooks can fire in multiple phases + within a single backward call. When the epilogue runs after a phase, it calls + exit_backward(), setting backward_active_depth to 0. When the next phase starts, + we need to re-enter backward. + + We detect subsequent phases by checking: + 1. remaining_grad_acc_hooks == 0 (epilogue ran or new backward) + 2. backward_active_depth == 0 (we've exited from previous phase) + 3. backward_seen_this_step == True (backward was active earlier) + + This distinguishes from TiledFusedLogitsLoss which calls backward() during forward - + in that case backward_seen_this_step is False because enter_backward() was never called. + """ + if self.remaining_grad_acc_hooks == 0: + if self.backward_active_depth == 0 and self.backward_seen_this_step: + self.enter_backward() + + def queue_post_backward_callback(self): + """Queue post-backward hooks to run after the current graph finishes.""" + if self.post_backward_callback_queued: + return True + if self.backward_active_depth == 0: + return False + + engine = getattr(torch.autograd.Variable, "_execution_engine", None) + if engine is None or not hasattr(engine, "queue_callback"): + return False + if not hasattr(torch._C, "_current_graph_task_id"): + return False + + graph_task_id = torch._C._current_graph_task_id() + if graph_task_id == -1: + return False + + def _run_post_backward(): + self.run_grad_acc_post_hooks() + + engine.queue_callback(_run_post_backward) + self.post_backward_callback_queued = True + self.post_backward_callback_graph_task_id = graph_task_id + return True + + def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): + """Update hook state after a gradient hook fires and run epilogue if all hooks have fired. + + With reentrant gradient checkpointing, count_used_parameters_in_backward() returns the + count of params that will execute in the current backward graph. This count grows as + checkpointed regions are recomputed. We track the MAXIMUM count seen to ensure we don't + run the epilogue until all params that will ever participate have been processed. + Counters are reset at forward() time via reset_for_new_step(). + + Args: + current_expected_count: The current expected number of hooks, from + count_used_parameters_in_backward() plus any leaf modules. + """ + self.hooks_fired_this_backward += 1 + self.max_expected_hooks_seen = max(self.max_expected_hooks_seen, current_expected_count) + + # Prefer running post-backward hooks via autograd engine callback when available. + # This avoids premature epilogues with reentrant checkpointing. + if self.queue_post_backward_callback(): + self.remaining_grad_acc_hooks = max(self.max_expected_hooks_seen - self.hooks_fired_this_backward, 0) + return + + # Fallback: Run epilogue only when we've processed ALL params that will participate. + # This is the maximum count we've seen (accounts for late-joining params + # from reentrant checkpointing) and also excludes unused params. + if self.hooks_fired_this_backward >= self.max_expected_hooks_seen: + self.remaining_grad_acc_hooks = 0 + self.run_grad_acc_post_hooks() + else: + self.remaining_grad_acc_hooks = self.max_expected_hooks_seen - self.hooks_fired_this_backward + + +class ZeROOptimizer(DeepSpeedOptimizer): + """Base class for ZeRO optimizer implementations (stages 1, 2, and 3).""" + + def __init__(self): + self._backward_hook_state = BackwardHookStateManager() + + # Delegate backward hook state management to the manager. + # These properties provide backward compatibility with code that accesses + # these attributes directly (e.g., in stage3.py and stage_1_and_2.py). + @property + def _remaining_grad_acc_hooks(self): + return self._backward_hook_state.remaining_grad_acc_hooks + + @_remaining_grad_acc_hooks.setter + def _remaining_grad_acc_hooks(self, value): + self._backward_hook_state.remaining_grad_acc_hooks = value + + @property + def _backward_active_depth(self): + return self._backward_hook_state.backward_active_depth + + @_backward_active_depth.setter + def _backward_active_depth(self, value): + self._backward_hook_state.backward_active_depth = value + + @property + def _backward_seen_this_step(self): + return self._backward_hook_state.backward_seen_this_step + + @_backward_seen_this_step.setter + def _backward_seen_this_step(self, value): + self._backward_hook_state.backward_seen_this_step = value + + @property + def _epilogue_ran_this_backward(self): + return self._backward_hook_state.epilogue_ran_this_backward + + @_epilogue_ran_this_backward.setter + def _epilogue_ran_this_backward(self, value): + self._backward_hook_state.epilogue_ran_this_backward = value + + @property + def _hooks_fired_this_backward(self): + return self._backward_hook_state.hooks_fired_this_backward + + @_hooks_fired_this_backward.setter + def _hooks_fired_this_backward(self, value): + self._backward_hook_state.hooks_fired_this_backward = value + + @property + def _max_expected_hooks_seen(self): + return self._backward_hook_state.max_expected_hooks_seen + + @_max_expected_hooks_seen.setter + def _max_expected_hooks_seen(self, value): + self._backward_hook_state.max_expected_hooks_seen = value + + @property + def _grad_acc_post_hooks(self): + return self._backward_hook_state._grad_acc_post_hooks + + @_grad_acc_post_hooks.setter + def _grad_acc_post_hooks(self, value): + self._backward_hook_state._grad_acc_post_hooks = value + + def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: + checkpoint_dir = os.path.join(checkpoint_dir, "zero") + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' + optim_sd = torch.load(optim_state_path, weights_only=False) + + self._load_global_state(optim_sd) + + tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) + if self.mpu is None: + logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.") + tp_world_size = 1 + else: + tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ + else self.mpu.get_tensor_model_parallel_world_size() + + for i, (param_group, + loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])): + # We have an assumption that all params in the same param_group have the same keys + opt_keys = set() + steps = [] + + lp_groups = getattr(self, lp_groups_name) + for lp in lp_groups[i]: + if lp._hp_mapping is not None: + #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") + step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, + tp_world_size) + for key in lp._hp_mapping.get_optim_state_keys(): + opt_keys.add(key) + steps.append(step) + + hp_param = param_group['params'][0] + assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal" + if steps[0] is not None: + self.optimizer.state[hp_param]['step'] = steps[0] + + map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys) + + for key, value in loaded_param_group.items(): + if key == 'params': + continue + param_group[key] = value + + def report_ipg_memory_usage(self, tag, param_elems, dtype=None): + dtypes = self.ipg_buckets.keys() if dtype is None else [dtype] + + for dt in dtypes: + bucket = self.ipg_buckets[dt] + elem_count = bucket.elements + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {dt} {bucket.elements} param {param_elems} max_percent {percent_of_bucket_size}" + ) + + def get_param_comm_dtype(self, param): + if is_autocast_initialized(): + return get_comm_dtype(param) + else: + return self.communication_data_type + + def needs_scaler(self) -> bool: + """ + Check if this optimizer requires loss scaling for correct backward pass. + + Returns True if any of the following conditions are met: + - Custom loss scaler is enabled + - torch.autocast gradient scaler is active (fp16 only) + - Dynamic loss scaling is enabled (fp16 with DeepSpeed's loss scaler) + + Returns False for bf16 or fp32, which don't require gradient scaling. + """ + return (self.custom_loss_scaler or self.torch_autocast_gradscaler is not None + or (hasattr(self, 'dynamic_loss_scale') and self.dynamic_loss_scale)) + + def scale_if_loss(self, value: Any) -> Any: + """ + Applies loss scaling to the input value if it is a loss tensor. + """ + if maybe_loss_for_backward(value): + if self.custom_loss_scaler: + return self.external_loss_scale * value + if self.torch_autocast_gradscaler: + return self.torch_autocast_gradscaler.scale(value) + # Only call loss_scaler if it exists (not present in BF16_Optimizer) + if hasattr(self, 'loss_scaler') and self.loss_scaler is not None: + return self.loss_scaler.scale_loss(value) + + return value + + def backward_prologue(self): + pass + + def backward_epilogue(self, **kwargs): + pass + + def backward(self, loss, **kwargs): + assert maybe_loss_for_backward(loss), "Optimizer's backward() only accepts a scalar tensor" + + scaled_loss = self.backward_prologue(loss) + retain_graph = kwargs.pop('retain_graph', False) + self.enter_backward() + scaled_loss.backward(retain_graph=retain_graph) + self.backward_epilogue() + self.exit_backward() + + def register_grad_acc_post_hook(self, hook): + """Register a callback to run when all gradient hooks have fired.""" + self._backward_hook_state.register_grad_acc_post_hook(hook) + + def unregister_grad_acc_post_hooks(self): + """Remove all registered gradient accumulation post hooks.""" + self._backward_hook_state.unregister_grad_acc_post_hooks() + + def run_grad_acc_post_hooks(self): + """Run all registered post hooks if backward is active.""" + self._backward_hook_state.run_grad_acc_post_hooks() + + def enter_backward(self): + """Enter backward context. Call at the start of backward pass.""" + self._backward_hook_state.enter_backward() + + def exit_backward(self): + """Exit backward context. Call at the end of backward pass.""" + self._backward_hook_state.exit_backward() + + def clear_backward_seen_flag(self): + """Clear the backward seen flag and reset hook counters at the start of each step.""" + self._backward_hook_state.reset_for_new_step() + + def should_refresh_expected_hook_count(self): + """Return True when count_used_parameters_in_backward() should be re-evaluated.""" + return self._backward_hook_state.should_refresh_expected_hook_count() + + def reenter_backward_if_needed(self): + """Re-enter backward context for subsequent phases in reentrant checkpointing.""" + self._backward_hook_state.reenter_backward_if_needed() + + def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): + """Update hook state after a gradient hook fires and run epilogue if all hooks have fired.""" + self._backward_hook_state.update_hook_state_and_maybe_run_epilogue(current_expected_count) + + def queue_post_backward_callback(self): + """Queue post-backward hooks to run after autograd completes.""" + return self._backward_hook_state.queue_post_backward_callback() + + def _configure_master_weights(self, + fp16_master_weights_and_gradients=False, + bf16_master_weights_and_gradients=False, + bf16_optimizer_states=False, + offload_enabled=False, + fp16_offload_validator=None, + bf16_offload_validator=None): + """ + Common validation and dtype selection for ZeRO optimizer master-weight settings. + Optionally accepts callables that enforce backend-specific offload requirements. + ``offload_enabled`` tells this method whether optimizer-state offload is configured, + so the offload requirement is also enforced for the bf16-optimizer-states + offload case. + """ + self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients + self.bf16_master_weights_and_gradients = bf16_master_weights_and_gradients + assert not (self.fp16_master_weights_and_gradients and self.bf16_master_weights_and_gradients), \ + "fp16_master_weights_and_gradients and bf16_master_weights_and_gradients are mutually exclusive." + + self.bf16_optimizer_states = bf16_optimizer_states + if self.bf16_optimizer_states: + assert self.bf16_master_weights_and_gradients, \ + "bf16_optimizer_states requires bf16_master_weights_and_gradients." + + # bf16 master weights require ZeRO-Offload + DeepSpeedCPUAdam whenever the optimizer states + # cannot stay on the GPU: either because they remain fp32 (bf16_optimizer_states disabled), + # or because CPU offload is explicitly requested alongside bf16 optimizer states. + if (self.bf16_master_weights_and_gradients and bf16_offload_validator is not None + and (not self.bf16_optimizer_states or offload_enabled)): + bf16_offload_validator() + # Offloaded bf16 optimizer states need the CPU optimizer to store moments in the + # parameter (bf16) precision; otherwise they would silently expand back to fp32. + if self.bf16_optimizer_states: + assert not getattr(self.optimizer, 'fp32_optimizer_states', True), \ + "bf16_optimizer_states with ZeRO-Offload requires DeepSpeedCPUAdam constructed " \ + "with fp32_optimizer_states=False so optimizer moments are stored in bf16." + + if self.fp16_master_weights_and_gradients and fp16_offload_validator is not None: + fp16_offload_validator() + + if self.fp16_master_weights_and_gradients: + return torch.float16 + elif self.bf16_master_weights_and_gradients: + return torch.bfloat16 + else: + return torch.float32 diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index f36f67900b86..746618fb5bd9 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -6,20 +6,21 @@ from collections import OrderedDict import torch import sys -import os +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed import comm as dist from deepspeed.runtime.constants import PIPE_REPLICATED -from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime import ZeROOptimizer +from deepspeed.runtime.base_optimizer import ZeROOptimizer from packaging import version as pkg_version - from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, - align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, - is_model_parallel_parameter, see_memory_usage) - -from deepspeed.utils import link_hp_params, fragment_address + align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter, + see_memory_usage, graph_process, get_norm_with_moe_layers) +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups +from deepspeed.moe.utils import is_moe_param, is_moe_param_group +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank +from deepspeed.utils.torch import register_grad_hook from deepspeed.checkpoint import enable_universal_checkpoint +from deepspeed.checkpoint.constants import UNIVERSAL_CHECKPOINT_INFO from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, PARAM_SLICE_MAPPINGS) @@ -27,17 +28,26 @@ setattr(sys.modules[__name__], 'fragment_address', fragment_address) +def print_rank_0(message, debug=False, force=False): + if dist.get_rank() == 0 and (debug or force): + print(message) + + class BF16_Optimizer(ZeROOptimizer): def __init__(self, init_optimizer, param_names, + bfloat16_config, mpu=None, clip_grad=0.0, norm_type=2, allgather_bucket_size=5000000000, dp_process_group=None, - timers=None): + timers=None, + grad_acc_dtype=None, + graph_harvesting=False, + has_moe_layers=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -45,18 +55,33 @@ def __init__(self, self.param_names = param_names self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim) + assert bfloat16_config.enabled, "BF16Optimizer: requires bfloat16 to be enabled" + assert grad_acc_dtype in [torch.float32, torch.bfloat16 + ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}" + self.grad_acc_dtype = grad_acc_dtype + + # BF16 doesn't use loss scaling, but these attributes are needed for API compatibility + self.custom_loss_scaler = False + self.external_loss_scale = None + self.torch_autocast_gradscaler = None + + self.immediate_grad_update = bfloat16_config.immediate_grad_update + self.clip_grad = clip_grad self.norm_type = norm_type self.mpu = mpu self.allgather_bucket_size = int(allgather_bucket_size) self.dp_process_group = dp_process_group self.dp_rank = dist.get_rank(group=self.dp_process_group) + self.has_moe_layers = has_moe_layers + self.non_expert_gradients = [] self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))] + if self.has_moe_layers: + self._configure_moe_settings() - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten + # Use torch (un)flatten ops + self.flatten = _flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors #align nccl all-gather send buffers to 4-bye boundary self.nccl_start_alignment_factor = 2 # 4-byte alignment/sizeof(fp16) = 2 @@ -76,40 +101,65 @@ def __init__(self, self.fp32_groups_gradient_flat_partition = [] self.fp32_groups_has_gradients = [] - self.step_count = 0 self.group_paddings = [] - + self.graph_harvesting = graph_harvesting if self.using_real_optimizer: self._setup_for_real_optimizer() - see_memory_usage('end bf16_optimizer', force=True) + see_memory_usage('end bf16_ optimizer', force=True) + + def destroy(self): + if not self.using_real_optimizer: + return + for i, _ in enumerate(self.optimizer.param_groups): + for p in self.bf16_groups[i]: + if getattr(p, '_hp_mapping', None): + p._hp_mapping = None + for hook in self._grad_acc_hooks: + hook.remove() + print_rank_0("Removed grad acc hooks") + + def _configure_moe_settings(self): + assert any( + [is_moe_param_group(group) for group in self.optimizer.param_groups] + ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer" + + for i, group in enumerate(self.optimizer.param_groups): + if is_moe_param_group(group): + assert all([is_moe_param(param) + for param in group['params']]), "All params in MoE group must be MoE params" + self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name']) + self.expert_gradients = {} + if self.has_moe_layers: + for key in groups._get_expert_data_parallel_group_dict().keys(): + self.expert_gradients[key] = [] def _setup_for_real_optimizer(self): - dp_world_size = dist.get_world_size(group=self.dp_process_group) - self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))] + self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group] for i, param_group in enumerate(self.optimizer.param_groups): + real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i]) see_memory_usage(f'before initializing group {i}', force=True) partition_id = dist.get_rank(group=self.real_dp_process_group[i]) # grab the original list - self.bf16_groups.append(param_group['params']) + trainable_parameters = [param for param in param_group['params'] if param.requires_grad] + self.bf16_groups.append(trainable_parameters) # create flat bf16 params self.bf16_groups_flat.append( self._flatten_dense_tensors_aligned(self.bf16_groups[i], - self.nccl_start_alignment_factor * dp_world_size)) - + self.nccl_start_alignment_factor * real_dp_world_size)) # Make bf16 params point to flat tensor storage self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i], flat_tensor=self.bf16_groups_flat[i]) # divide flat weights into equal sized partitions - partition_size = self.bf16_groups_flat[i].numel() // dp_world_size + partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size bf16_dp_partitions = [ self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size) - for dp_index in range(dp_world_size) + for dp_index in range(real_dp_world_size) ] self.bf16_partitioned_groups.append(bf16_dp_partitions) @@ -120,7 +170,12 @@ def _setup_for_real_optimizer(self): num_elem_list = [t.numel() for t in self.bf16_groups[i]] # create fp32 gradients - self.fp32_groups_gradients_flat.append(torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32)) + fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype) + self.fp32_groups_gradients_flat.append(fp32_flat_buffer) + if self.has_moe_layers and is_moe_param_group(param_group): + self.expert_gradients[param_group['name']].append(fp32_flat_buffer) + else: + self.non_expert_gradients.append(fp32_flat_buffer) # track individual fp32 gradients for entire model fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i], @@ -153,19 +208,30 @@ def _setup_for_real_optimizer(self): see_memory_usage(f'after initializing group {i}', force=True) - see_memory_usage('before initialize_optimizer', force=True) - self.initialize_optimizer_states() - see_memory_usage('end initialize_optimizer', force=True) + self._grad_acc_hooks = [] + if self.immediate_grad_update: + self.create_grad_acc_hooks() # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._hp_optimizer_states_linked = False self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() def _enable_universal_checkpoint(self): + self._universal_checkpoint_info = None for lp_param_group in self.bf16_groups: + if self._universal_checkpoint_info is None: + for param in lp_param_group: + autotp_uc_info = getattr(param, UNIVERSAL_CHECKPOINT_INFO, None) + if autotp_uc_info is not None: + self._universal_checkpoint_info = autotp_uc_info + break enable_universal_checkpoint(param_list=lp_param_group) + def _get_universal_checkpoint_info(self): + return getattr(self, '_universal_checkpoint_info', None) + def _create_param_mapping(self): param_mapping = [] for i, _ in enumerate(self.optimizer.param_groups): @@ -179,11 +245,12 @@ def _create_param_mapping(self): return param_mapping def _link_all_hp_params(self): - dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, _ in enumerate(self.optimizer.param_groups): + real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i]) + # Link bf16 and fp32 params in partition partition_id = dist.get_rank(group=self.real_dp_process_group[i]) - partition_size = self.bf16_groups_flat[i].numel() // dp_world_size + partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size flat_hp_partition = self.fp32_groups_flat_partition[i] link_hp_params(lp_param_list=self.bf16_groups[i], flat_hp_partition=flat_hp_partition, @@ -193,23 +260,14 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) - def initialize_optimizer_states(self): - """Take an optimizer step with zero-valued gradients to allocate internal - optimizer state. - - This helps prevent memory fragmentation by allocating optimizer state at the - beginning of training instead of after activations have been allocated. - """ - for param_partition, grad_partition in zip(self.fp32_groups_flat_partition, - self.fp32_groups_gradient_flat_partition): - param_partition.grad = grad_partition - - self.optimizer.step() - - self.clear_hp_grads() + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True def _split_flat_tensor(self, flat_tensor, num_elem_list): assert sum(num_elem_list) <= flat_tensor.numel() @@ -235,9 +293,18 @@ def step(self, closure=None): if closure is not None: raise NotImplementedError(f'{self.__class__} does not support closure.') - all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(), - mpu=self.mpu, - norm_type=self.norm_type) + non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm() + non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm, + mpu=self.mpu, + norm_type=self.norm_type, + use_graph=self.graph_harvesting) + all_groups_norm = non_expert_groups_norm + if self.has_moe_layers: + all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm, + mpu=self.mpu, + expert_tensors=expert_grads_for_norm, + norm_type=self.norm_type) + self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. @@ -245,72 +312,125 @@ def step(self, closure=None): clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True), max_norm=self.clip_grad, global_norm=all_groups_norm, - mpu=self.mpu) + mpu=self.mpu, + use_graph=self.graph_harvesting) + + for param_partition, grad_partition in zip(self.fp32_groups_flat_partition, + self.fp32_groups_gradient_flat_partition): + # In case of grad acc dtype different than FP32, need to cast to high precision. + param_partition.grad = grad_partition.to( + param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition self.optimizer.step() - self.update_lp_params() + if self.grad_acc_dtype is not torch.float32: + for param_partition in self.fp32_groups_flat_partition: + param_partition.grad = None - self.clear_hp_grads() - self.step_count += 1 + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() - def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): - """Perform a backward pass and copy the low-precision gradients to the - high-precision copy. + self.update_lp_params() - We copy/accumulate to the high-precision grads now to prevent accumulating in the - bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1) + self.clear_hp_grads() - The low-precision grads are deallocated during this procedure. - """ + def backward_prologue(self): self.clear_lp_grads() - loss.backward(**bwd_kwargs) + def backward_epilogue(self, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): if update_hp_grads: self.update_hp_grads(clear_lp_grads=clear_lp_grads) + @torch.no_grad() + def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads): + if lp.grad is None: + return + + hp_grad = self.fp32_groups_gradients[group_idx][param_idx] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]' + + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[group_idx][param_idx] = True + + # clear gradients + if clear_lp_grads: + lp.grad.zero_() + + @torch.no_grad() + def _update_hp_grads_func(self, clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + self._update_hp_grad(lp, i, j, clear_lp_grads) + @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): + if self.immediate_grad_update: + return + + if self.graph_harvesting: + graph_process(False, self._update_hp_grads_func, clear_lp_grads) + else: + self._update_hp_grads_func(clear_lp_grads) + #cpu op for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if lp.grad is None: continue - - hp_grad = self.fp32_groups_gradients[i][j] - assert hp_grad is not None, \ - f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - - hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) - lp._hp_grad = hp_grad self.fp32_groups_has_gradients[i][j] = True - # clear gradients - if clear_lp_grads: - lp.grad = None - @torch.no_grad() def get_grads_for_reduction(self): - return self.fp32_groups_gradients_flat + if self.has_moe_layers: + return self.non_expert_gradients, self.expert_gradients + return self.non_expert_gradients, {} @torch.no_grad() def get_grads_for_norm(self, for_clipping=False): - grads = [] + """ + Returns: + tuple[list[Tensor], dict[ep_name, List[Tensor]] | list: + If for_clipping, return all gradients. + Otherwise, separate and return dict of expert_grad and list of non_expert_grad + """ + # (grads, expert_group_name) + expert_grads_for_norm = {} + + # grads + non_expert_grads_for_norm = [] + all_grads_for_clip = [] + tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) + assert len(self.bf16_groups) == len(self.optimizer.param_groups) for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if not for_clipping: if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: continue - if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)): + # skip duplicated parameters. perform norm only on cards with tp_rank=0. + # non-duplicated parameters include: + # - Parameters with tp: Use allreducesum of mp_group. + # - Moe Parameters with ep: Use allreducesum of ep_group. + if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)): continue if not self.fp32_groups_has_gradients[i][j]: continue - - grads.append(self.fp32_groups_gradients[i][j]) - - return grads + if not for_clipping: + param_group = self.optimizer.param_groups[i] + if self.has_moe_layers and is_moe_param_group(param_group): + if param_group['name'] not in expert_grads_for_norm: + expert_grads_for_norm[param_group['name']] = [] + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j]) + else: + non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j]) + else: + all_grads_for_clip.append(self.fp32_groups_gradients[i][j]) + if not for_clipping: + return non_expert_grads_for_norm, expert_grads_for_norm + return all_grads_for_clip @torch.no_grad() def update_lp_params(self): @@ -318,11 +438,9 @@ def update_lp_params(self): fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) bf16_partitions[partition_id].data.copy_(fp32_partition.data) - # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) - # if i == 0: - # print_rank_0(f'{fp32_partition[:10]=}', force=True) - all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups, + all_gather_dp_groups(groups_flat=self.bf16_groups_flat, + partitioned_param_groups=self.bf16_partitioned_groups, dp_process_group=self.real_dp_process_group, start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size) @@ -334,10 +452,27 @@ def clear_hp_grads(self): for i, group in enumerate(self.fp32_groups_gradients): self.fp32_groups_has_gradients[i] = [False] * len(group) - def clear_lp_grads(self): + def clear_lp_grads(self, set_to_none=False): + + # using zero_() fixed memory address for graph replay + if self.graph_harvesting: + assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None" + + zero_grads_list = [] for group in self.bf16_groups: for param in group: - param.grad = None + if set_to_none: + param.grad = None + elif param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + zero_grads_list.append(param.grad) + if not set_to_none and len(zero_grads_list) > 0: + torch._foreach_zero_(zero_grads_list) + + def zero_grad(self, set_to_none=True): + self.clear_lp_grads(set_to_none) + self.clear_hp_grads() def state_dict(self): state_dict = {} @@ -349,23 +484,29 @@ def state_dict(self): state_dict[DS_VERSION] = version state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings + autotp_uc_info = self._get_universal_checkpoint_info() + if autotp_uc_info is not None: + state_dict[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info + return state_dict # Restore base optimizer fp32 weights bfloat16 weights def _restore_from_bit16_weights(self): - for i, group in enumerate(self.bf16_groups): + for i, (bf16_partitions, + fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) - for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition): - fp32_partition.data.copy_(bf16_partitions[partition_id].data) + fp32_partition.data.copy_(bf16_partitions[partition_id].data) def refresh_fp32_params(self): self._restore_from_bit16_weights() def load_state_dict(self, state_dict_list, - checkpoint_folder, + checkpoint_folder=None, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + load_serial=None, + param_shapes=None): if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) else: @@ -377,12 +518,13 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l current_rank_sd = state_dict_list[dp_rank] ckpt_version = current_rank_sd.get(DS_VERSION, False) - assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" + assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed" ckpt_version = pkg_version.parse(ckpt_version) self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) if load_optimizer_states: + print("_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]") self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) if load_from_fp32_weights: @@ -395,24 +537,38 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self._link_all_hp_params() def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): - self._load_hp_checkpoint_state(checkpoint_folder) + self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder) + + def _load_global_state(self, sd): + pass @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups - def _load_hp_checkpoint_state(self, checkpoint_dir): - checkpoint_dir = os.path.join(checkpoint_dir, "zero") - tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - tp_world_size = self.mpu.get_slice_parallel_world_size() + @property + def state(self): + """Forward the wrapped optimizer's states.""" + return self.optimizer.state - for i, _ in enumerate(self.optimizer.param_groups): - for lp in self.bf16_groups[i]: - if lp._hp_mapping is not None: - #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") - lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, - tp_world_size) + def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): + assert self.immediate_grad_update + self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) + + def create_grad_acc_hooks(self): + for i, param_group in enumerate(self.bf16_groups): + for j, param in enumerate(param_group): + if param.requires_grad: + + def wrapper(param, i, j): + + def accumulate_hp_grads_and_remove_lp(*notneeded): + self.accumulate_hp_grads_and_remove_lp(param, i, j) + + self._grad_acc_hooks.append(register_grad_hook(param, accumulate_hp_grads_and_remove_lp)) + + wrapper(param, i, j) def _get_padded_tensor(src_tensor, size): diff --git a/deepspeed/runtime/checkpoint_engine/README.md b/deepspeed/runtime/checkpoint_engine/README.md index a19f54889802..5ee2e365504f 100644 --- a/deepspeed/runtime/checkpoint_engine/README.md +++ b/deepspeed/runtime/checkpoint_engine/README.md @@ -20,7 +20,7 @@ class CheckpointEngine(object): def __init__(self, config_params=None): pass - def create(self, tag): + def create(self, info:CheckpointCommitInfo): # create checkpoint on give tag for save/load. pass @@ -30,7 +30,7 @@ class CheckpointEngine(object): def load(self, path: str, map_location=None): pass - def commit(self, tag): + def commit(self, info:CheckpointCommitInfo): # to tell checkpoint services if all files are readys. pass diff --git a/deepspeed/runtime/checkpoint_engine/__init__.py b/deepspeed/runtime/checkpoint_engine/__init__.py index 6c5067f71c8f..4661658ea257 100644 --- a/deepspeed/runtime/checkpoint_engine/__init__.py +++ b/deepspeed/runtime/checkpoint_engine/__init__.py @@ -3,3 +3,10 @@ # DeepSpeed Team '''Copyright The Microsoft DeepSpeed Team''' + +from .fast_checkpoint_engine import FastCheckpointEngine +from .torch_checkpoint_engine import TorchCheckpointEngine +from .decoupled_checkpoint_engine import DecoupledCheckpointEngine +from .checkpoint_engine import CheckpointCommitInfo +from .datastates_checkpoint_engine import DataStatesCheckpointEngine +from .utils import create_checkpoint_engine diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py index 3f8978df0316..ccd4d7215823 100644 --- a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -5,26 +5,59 @@ import os +import abc +from abc import ABC -class CheckpointEngine(object): +from dataclasses import dataclass + +@dataclass +class CheckpointCommitInfo(object): + tag: str + save_dir: str + save_latest: bool + + +class CheckpointEngine(ABC): # init checkpoint engine for save/load def __init__(self, config_params=None): - pass + self.name = None - def create(self, tag): + @abc.abstractmethod + def create(self, info: CheckpointCommitInfo): # create checkpoint on give tag for save/load. - pass + ... + + @abc.abstractmethod + def save(self, state_dict, path: str): + ... def makedirs(self, path, exist_ok=False): os.makedirs(path, exist_ok=exist_ok) - def save(self, state_dict, path: str): - pass - + @abc.abstractmethod def load(self, path: str, map_location=None): + ... + + @abc.abstractmethod + def commit(self, info: CheckpointCommitInfo): + # to tell checkpoint services if all files are ready. + ... + + def is_data_parallel_writer(self, dp_rank): + return dp_rank == 0 + + def is_decoupled(self): + return False + + def set_commit_info(self, info: CheckpointCommitInfo): pass - def commit(self, tag): - # to tell checkpoint services if all files are readys. + def get_commit_info(self): + return None + + def cleanup(self): pass + + def preserves_storage_sharing(self): + return True diff --git a/deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py new file mode 100644 index 000000000000..f131a0925957 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory. + +# DeepSpeed Team + +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ + CheckpointEngine, CheckpointCommitInfo + +ENGINE_NAME = "DataStatesCheckpointEngine" + + +class DataStatesCheckpointEngine(CheckpointEngine): + + def __init__(self, deepspeed_config, rank): + super().__init__(deepspeed_config) + self.commit_info = None + self.ckpt_engine = None + try: + from datastates import CheckpointEngine as DataStatesEngine + self.ckpt_engine = DataStatesEngine(deepspeed_config, rank) + except ImportError: + raise RuntimeError("Please install DataStates from https://github.com/DataStates/datastates-llm.") + except Exception as e: + raise RuntimeError(f"An error occurred while initializing DataStates Checkpoint Engine: {e}") + + def __del__(self): + self.cleanup() + + def create(self, info: CheckpointCommitInfo): + self.commit_info = info + return None + + def save(self, state_dict, path: str): + return self.ckpt_engine.save(state_dict, path) + + def load(self, path: str, map_location=None): + return self.ckpt_engine.load(path, map_location) + + def commit(self, info: CheckpointCommitInfo): + if info is None: + return + assert info == self.commit_info + self.ckpt_engine.wait(persist=True) + self.commit_info = None + return True + + def cleanup(self): + self.commit(self.commit_info) + if self.ckpt_engine: + self.ckpt_engine.wait(persist=True) + del self.ckpt_engine + + def is_decoupled(self): + return True + + def preserves_storage_sharing(self): + return False diff --git a/deepspeed/runtime/checkpoint_engine/decoupled_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/decoupled_checkpoint_engine.py new file mode 100644 index 000000000000..ceae5f0eae32 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/decoupled_checkpoint_engine.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.multiprocessing as mp +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ + CheckpointEngine, CheckpointCommitInfo +from deepspeed.runtime.checkpoint_engine.fast_checkpoint_engine import FastCheckpointEngine +from deepspeed import comm as dist +from deepspeed.runtime.utils import get_checkpoint_folder_size +from deepspeed.utils import logger + +from enum import Enum + + +class DecoupledEvent(Enum): + SAVE_EVENT = 1 + COMMIT_EVENT = 2 + EXIT_EVENT = 3 + + +class CheckpointSize(object): + + def __init__(self): + self._pre = None + self._post = None + self._gigabytes = None + + def gb_size(self): + return self._gigabytes + + def set_pre_size(self, size): + self._pre = size + + def set_post_size(self, size): + self._post = size + self._gigabytes = (self._post - self._pre) / (1024**3) + + +def init_decoupled_checkpoint(config_params, dp_writer_config, save_event, save_queue, optimize_dp_state): + try: + checkpoint_engine = FastCheckpointEngine(config_params, dp_writer_config, optimize_dp_state) + print('Created FastCheckpointEngine for Decoupled Checkpointing') + save_path_list = [] + while True: + (save_info, event_type) = save_queue.get() + if event_type == DecoupledEvent.SAVE_EVENT and save_info is not None: + state_dict, save_path = save_info + # print(f'Received decoupled checkpoint request for {save_path=}') + save_path_list.append(save_path) + checkpoint_engine.save(state_dict, save_path) + del state_dict + # print(f'Completed decoupled checkpoint request for {save_path=}') + + if event_type == DecoupledEvent.COMMIT_EVENT: + # print(f'Recieved commit request for {save_path_list=}') + save_path_list = [] + save_event.set() + + if event_type == DecoupledEvent.EXIT_EVENT: + # print(f'Received decoupled exit request') + break + except Exception as e: + print(f'[{ENGINE_NAME}] Checkpoint subprocess crashed with error: {e}') + raise + + +ENGINE_NAME = "DecoupledCheckpointEngine" + +# Default timeout for checkpoint operations (5 minutes) +DEFAULT_CHECKPOINT_TIMEOUT_SECONDS = 300 +# Interval for checking process health while waiting +PROCESS_HEALTH_CHECK_INTERVAL_SECONDS = 10 + + +class DecoupledCheckpointEngine(CheckpointEngine): + + def __init__(self, config_params, dp_writer_config, optimize_dp_state): + # Set spawn method if not already set (needed for CUDA tensor sharing) + try: + mp.set_start_method('spawn') + except RuntimeError: + pass # Already set, ignore + super().__init__(config_params) + self.name = ENGINE_NAME + self.dp_writer_config = dp_writer_config + self.commit_info = None + self.checkpoint_size = CheckpointSize() + self.global_rank = dist.get_rank() + self.optimize_dp_state = optimize_dp_state + self._cleanup_called = False + if dp_writer_config is None: + self.save_event = None + self.save_queue = None + self.ckpt_process = None + self.local_rank = None + print( + f'[{ENGINE_NAME}]: No checkpoint process self.global_rank={self.global_rank} self.dp_writer_config={self.dp_writer_config}' + ) + else: + self.save_event = mp.Event() + self.save_queue = mp.SimpleQueue() + engine_args = (config_params, dp_writer_config, self.save_event, self.save_queue, self.optimize_dp_state) + self.ckpt_process = mp.Process(target=init_decoupled_checkpoint, args=engine_args) + self.ckpt_process.start() + self.local_rank = dp_writer_config.local_rank + print( + f'[{ENGINE_NAME}]: Create checkpoint process self.global_rank={self.global_rank} self.ckpt_process.pid={self.ckpt_process.pid} self.dp_writer_config={self.dp_writer_config}' + ) + + def __del__(self): + try: + self.cleanup() + except Exception: + # Suppress exceptions in destructor to avoid crashes during shutdown + pass + + def _check_process_alive(self): + """Check if the checkpoint process is still alive. + + Note: Only call this when self.ckpt_process is not None. + Some ranks don't have a checkpoint process by design (see Figure 6 in paper). + """ + return self.ckpt_process.is_alive() + + def _wait_for_event_with_timeout(self, timeout_seconds=DEFAULT_CHECKPOINT_TIMEOUT_SECONDS): + """Wait for save_event with timeout and process health checks. + + Returns True if event was set, raises RuntimeError if process died or timeout occurred. + """ + elapsed = 0 + while elapsed < timeout_seconds: + if self.save_event.wait(timeout=PROCESS_HEALTH_CHECK_INTERVAL_SECONDS): + return True + elapsed += PROCESS_HEALTH_CHECK_INTERVAL_SECONDS + + # Check if process is still alive + if not self._check_process_alive(): + raise RuntimeError(f"[{ENGINE_NAME}] Checkpoint process died unexpectedly. " + f"Check logs for OOM or other errors in the checkpoint subprocess.") + + raise RuntimeError(f"[{ENGINE_NAME}] Checkpoint commit timed out after {timeout_seconds} seconds. " + f"Process alive: {self._check_process_alive()}") + + def create(self, info: CheckpointCommitInfo): + self.commit_info = info + if self.checkpoint_size.gb_size() is None: + pre_size = get_checkpoint_folder_size(info.save_dir, info.tag, self.local_rank) + self.checkpoint_size.set_pre_size(pre_size) + + def load(self, path: str, map_location=None): + sd = torch.load(path, map_location=map_location) + return sd + + def save(self, state_dict, path: str): + if self.ckpt_process is None: + return + + # Check process health before attempting to save + if not self._check_process_alive(): + return + + save_info = (state_dict, path) + self.save_queue.put((save_info, DecoupledEvent.SAVE_EVENT)) + + def commit(self, info: CheckpointCommitInfo): + # Use proper validation instead of assert (assert is disabled with python -O) + if info != self.commit_info: + raise ValueError(f"[{ENGINE_NAME}] Checkpoint commit info mismatch: " + f"expected {self.commit_info}, got {info}") + + if self.ckpt_process is not None: + # Check process health before waiting + if not self._check_process_alive(): + raise RuntimeError(f"[{ENGINE_NAME}] Cannot commit checkpoint: checkpoint process is not running.") + + self.save_queue.put((None, DecoupledEvent.COMMIT_EVENT)) + # Wait with timeout and health checks instead of blocking forever + self._wait_for_event_with_timeout() + self.save_event.clear() + + self.commit_info = None + + if self.checkpoint_size.gb_size() is None: + dist.barrier() + post_size = get_checkpoint_folder_size(info.save_dir, info.tag, self.local_rank) + self.checkpoint_size.set_post_size(post_size) + + assert self.checkpoint_size.gb_size() is not None, "Checkpoint size should be set after commit" + + if self.global_rank == 0: + print( + f'{self.name} self.global_rank={self.global_rank} created checkpoint of {round(self.checkpoint_size.gb_size(), 2)} GB' + ) + + return True + + def get_commit_info(self): + # print(f'getting commit info {self.commit_info=}') + return self.commit_info + + def is_decoupled(self): + return True + + def cleanup(self): + # Prevent multiple cleanup calls (especially from __del__) + if self._cleanup_called: + return + self._cleanup_called = True + + try: + if self.get_commit_info() is not None: + self.commit(self.commit_info) + except Exception as e: + logger.warning(f"[{ENGINE_NAME}] Error during commit in cleanup: {e}") + + if self.ckpt_process is not None: + try: + self.save_queue.put((None, DecoupledEvent.EXIT_EVENT)) + except Exception: + pass # Queue may be broken if process died + + # Join with timeout to avoid hanging forever + self.ckpt_process.join(timeout=DEFAULT_CHECKPOINT_TIMEOUT_SECONDS) + + # If process didn't exit, terminate it forcefully + if self.ckpt_process.is_alive(): + logger.warning( + f"[{ENGINE_NAME}] Checkpoint process did not exit within timeout, terminating forcefully.") + self.ckpt_process.terminate() + self.ckpt_process.join(timeout=5) # Brief wait after terminate + + # Last resort: kill + if self.ckpt_process.is_alive(): + self.ckpt_process.kill() + + self.ckpt_process = None + self.save_queue = None + + def is_data_parallel_writer(self, dp_rank): + return self.ckpt_process is not None diff --git a/deepspeed/runtime/checkpoint_engine/fast_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/fast_checkpoint_engine.py new file mode 100644 index 000000000000..4bfecf810dc7 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/fast_checkpoint_engine.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ + CheckpointEngine, CheckpointCommitInfo +from deepspeed.runtime.model_checkpointing import ( + CHECKPOINT_WRITER, + CHECKPOINT_SERIALIZATION, + CheckpointWriterFactory, +) + + +class FastCheckpointEngine(CheckpointEngine): + + def __init__(self, config_params, dp_writer_config, optimize_dp_state): + super().__init__(config_params) + self.name = 'FastCheckpointEngine' + self.serialization_enabled = config_params.checkpoint_config[CHECKPOINT_SERIALIZATION] + self.optimize_dp_state = optimize_dp_state + if dp_writer_config is None: + self._writer = None + else: + self._writer = CheckpointWriterFactory(writer_config=config_params.checkpoint_config[CHECKPOINT_WRITER], + aio_config=config_params.aio_config, + dp_writer_config=dp_writer_config) + + def create(self, info: CheckpointCommitInfo): + pass + + def save(self, state_dict, path: str): + if self._writer is None: + return + + torch.save(obj=state_dict, + f=self._writer.create_writer(path, self.optimize_dp_state), + _use_new_zipfile_serialization=self.serialization_enabled) + self._writer.release_writer() + + def load(self, path: str, map_location=None): + sd = torch.load(path, map_location=map_location) + return sd + + def commit(self, info: CheckpointCommitInfo): + return True + + def is_data_parallel_writer(self, dp_rank): + return self._writer is not None diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 4842d73d8a66..9e18c8c24520 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -8,7 +8,7 @@ import torch_nebula from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ - CheckpointEngine + CheckpointEngine, CheckpointCommitInfo from deepspeed.utils import logger, log_dist from deepspeed.nebula.constants import * @@ -21,6 +21,7 @@ class NebulaCheckpointEngine(CheckpointEngine): def __init__(self, config_params=None): super().__init__(config_params) + self.name = "NebulaCheckpointEngine" self.checkpoint = None self.tag_flag = None self.enable_nebula_load = config_params.enable_nebula_load @@ -35,14 +36,14 @@ def __init__(self, config_params=None): } torch_nebula.init(**nebula_config_params) - def create(self, tag): - log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0]) + def create(self, info: CheckpointCommitInfo): + log_dist(f"[Nebula] Start Checkpoint for tag:{info.tag}", ranks=[0]) # -2 means: customer needs to explicitly tell nebula # current checkpoint is complete by commit methond. - self.checkpoint = torch_nebula.Checkpoint(tag, -2) + self.checkpoint = torch_nebula.Checkpoint(info.tag, -2) def save(self, state_dict, path: str): - log_dist(f"[Nebula] Create dummy files for loading.") + log_dist("[Nebula] Create dummy files for loading.") torch.save("", path) tag = _get_tag_from_path(path) @@ -50,7 +51,6 @@ def save(self, state_dict, path: str): logger.info(f"[Nebula] Saving {partititon_name} under tag {tag}...") self.checkpoint.save(partititon_name, state_dict) logger.info(f"[Nebula] Saved {partititon_name} under tag {tag}.") - return None def load(self, path: str, map_location=None): tag = _get_tag_from_path(path) @@ -58,11 +58,11 @@ def load(self, path: str, map_location=None): if not self.enable_nebula_load and first_load_flag: self.tag_flag = tag logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...") - partition = torch.load(path, map_location=map_location) + partition = torch.load(path, map_location=map_location, weights_only=False) logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .") return partition - partititon_name = os.path.basename(path) + partition_name = os.path.basename(path) logger.info(f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}...") checkpoint = None @@ -84,7 +84,7 @@ def load(self, path: str, map_location=None): checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path) if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''): logger.info( - f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!" + "Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!" ) # nebula tier1 latest checkpoint = torch_nebula.get_latest_checkpoint() @@ -93,15 +93,16 @@ def load(self, path: str, map_location=None): tag = checkpoint.tag self.tag_flag = -1 - partition = checkpoint.load(partititon_name, map_location=map_location) + partition = checkpoint.load(partition_name, map_location=map_location) logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.") return partition - def commit(self, tag): + def commit(self, info: CheckpointCommitInfo): + tag = info.tag # nebula commit will be call when all files under give tag are ready to be persisted in the async way. logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting") commit_rls = self.checkpoint.commit() if not commit_rls: - logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.") + logger.error("[Nebula] failed to commit the checkpoint, please check the log.") return False return commit_rls diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py index 5cd44864bb2e..ce96e1f47ec6 100644 --- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py @@ -4,31 +4,40 @@ # DeepSpeed Team import torch -from deepspeed.utils import logger, log_dist +from deepspeed.utils import log_dist from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ - CheckpointEngine + CheckpointEngine, CheckpointCommitInfo +from deepspeed.runtime.model_checkpointing import CHECKPOINT_SERIALIZATION + +ENGINE_NAME = "TorchCheckpointEngine" class TorchCheckpointEngine(CheckpointEngine): def __init__(self, config_params=None): super().__init__(config_params) + self.name = ENGINE_NAME + if config_params is None: + self.zipfile_serialization = False + else: + self.zipfile_serialization = config_params.checkpoint_config[CHECKPOINT_SERIALIZATION] + log_dist(f'[{ENGINE_NAME}] Initialized with serialization = {self.zipfile_serialization}', ranks=[0]) - def create(self, tag): - log_dist(f"[Torch] Checkpoint {tag} is about to be saved!", ranks=[0]) + def create(self, info: CheckpointCommitInfo): + log_dist(f"[Torch] Checkpoint {info.tag} is about to be saved!", ranks=[0]) + pass def save(self, state_dict, path: str): - logger.info(f"[Torch] Saving {path}...") - torch.save(state_dict, path) - logger.info(f"[Torch] Saved {path}.") - return None + # log_dist(f"[Torch] Saving [begin] {path}... {self.zipfile_serialization=}", ranks=[0]) + torch.save(state_dict, path, _use_new_zipfile_serialization=self.zipfile_serialization) + # log_dist(f"[Torch] Saving [end] {path}... {self.zipfile_serialization=}", ranks=[0]) def load(self, path: str, map_location=None): - logger.info(f"[Torch] Loading checkpoint from {path}...") - partition = torch.load(path, map_location=map_location) - logger.info(f"[Torch] Loaded checkpoint from {path}.") + log_dist(f"[Torch] Begin Load checkpoint from {path}...", ranks=[0]) + partition = torch.load(path, map_location=map_location, weights_only=False) + log_dist(f"[Torch] End Load checkpoint from {path}...", ranks=[0]) return partition - def commit(self, tag): - logger.info(f"[Torch] Checkpoint {tag} is ready now!") + def commit(self, info: CheckpointCommitInfo): + #logger.info(f"[Torch] Checkpoint {tag} is ready now!") return True diff --git a/deepspeed/runtime/checkpoint_engine/utils.py b/deepspeed/runtime/checkpoint_engine/utils.py new file mode 100644 index 000000000000..811cdb9b3d58 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/utils.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.model_checkpointing.constants import * +from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config +from deepspeed.utils import logger +from deepspeed import comm as dist +from .decoupled_checkpoint_engine import DecoupledCheckpointEngine +from .fast_checkpoint_engine import FastCheckpointEngine +from .torch_checkpoint_engine import TorchCheckpointEngine + + +def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers, optimize_dp_state): + if config_params is not None: + if config_params.checkpoint_config[CHECKPOINT_WRITER] is not None: + writer_config = config_params.checkpoint_config[CHECKPOINT_WRITER] + dp_writer_config = create_data_parallel_writer_config( + groups=groups, + parallel_unit=writer_config[CHECKPOINT_DATA_PARALLEL], + zero_stage=zero_stage, + has_moe_layers=has_moe_layers) + if writer_config[CHECKPOINT_WRITER_DECOUPLED]: + return DecoupledCheckpointEngine(config_params, dp_writer_config, optimize_dp_state) + else: + return FastCheckpointEngine(config_params, dp_writer_config, optimize_dp_state) + + if config_params is not None and config_params.nebula_config.enabled: + try: + from .nebula_checkpoint_engine import NebulaCheckpointEngine + except ImportError as err: + logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}") + return TorchCheckpointEngine(config_params) + else: + return NebulaCheckpointEngine(config_params=config_params.nebula_config) + + if config_params.datastates_config.enabled: + try: + from .datastates_checkpoint_engine import DataStatesCheckpointEngine + return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank()) + except ImportError as err: + logger.error( + f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}" + ) + return TorchCheckpointEngine(config_params) + + return TorchCheckpointEngine(config_params) diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index 36f0cb80781f..2fadce52222c 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -8,20 +8,149 @@ """ import math -from typing import List - +from typing import List, Any import torch from torch import Tensor from deepspeed import comm as dist -# NOTE: Use torch.distributed's ProcessGroup class until we have our own. -from torch.distributed import ProcessGroup -import torch.nn.functional - +from deepspeed.comm import ProcessGroup, all_to_all_single +from deepspeed.accelerator import get_accelerator from deepspeed.utils import instrument_w_nvtx +from deepspeed.ops import op_builder +from deepspeed.utils import logger def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group=None, async_op=False, prof=False): - return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=async_op) + return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=False) + + +quantizer_module = None + + +@instrument_w_nvtx +@torch.no_grad() +def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]: + global quantizer_module + if quantizer_module is None: + quantizer_module = op_builder.QuantizerBuilder().load() + local_world_size = get_accelerator().device_count() + global_world_size = dist.get_world_size() + num_nodes = global_world_size // local_world_size + this_rank = dist.get_rank() + intra_idx = int(this_rank / local_world_size) + inter_idx = this_rank % local_world_size + output_lst: List[Tensor] = [None] * len(tensors) + for idx, tensor in enumerate(tensors): + if tensor.dim() == 1: + output_lst[idx] = reduce_scatter_coalesced([tensor])[0] + elif tensor.numel() % (2 * global_world_size) != 0: + # Due to the constraint of 2-stage all-to-all, the input tensor must be divisible by 2 * global_world_size + # Otherwise, all-to-all cannot be performed because of shape mismatch. + # See more at https://github.com/deepspeedai/DeepSpeed/pull/5056 + logger.warning( + f"qgZ falls back to reduce_scatter because tensor size = {tensor.numel()} is not divisible by (2 * global_world_size) = {2 * global_world_size}. Please consider allocating a new world to enable qgZ" + ) + output_lst[idx] = reduce_scatter_coalesced([tensor])[0] + else: + intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size) + + inter_quant_group = intra_quant_group // local_world_size + intra_quant_int4, intra_q_scales = quantizer_module.swizzle_quant(tensor, intra_quant_group, 4, + quantizer_module.Symmetric, 1, num_nodes, + local_world_size) + local_output = torch.empty_like(intra_quant_int4) + scale_output = torch.empty_like(intra_q_scales) + all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}']) + all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}']) + global_input_tensor, global_scales = quantizer_module.quantized_reduction( + local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric, + local_world_size) + global_output = torch.empty_like(global_input_tensor) + global_scale_output = torch.empty_like(global_scales) + all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}']) + all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}']) + final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(), + 4, quantizer_module.Symmetric) + assert final_output.numel( + ) % num_nodes == 0, f"final_output.numel()={final_output.numel()} is not divisible by num_nodes={num_nodes}" + output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1) + return output_lst + + +@instrument_w_nvtx +@torch.no_grad() +def all_to_all_loco_quant_reduce( + params: List[Tensor], + groups: {}, + loco_param: Any = None, +) -> List[Tensor]: + global quantizer_module + global loco_idx + if quantizer_module is None: + quantizer_module = op_builder.QuantizerBuilder().load() + local_world_size = get_accelerator().device_count() + global_world_size = dist.get_world_size() + num_nodes = global_world_size // local_world_size + this_rank = dist.get_rank() + intra_idx = int(this_rank / local_world_size) + inter_idx = this_rank % local_world_size + output_lst: List[Tensor] = [None] * len(params) + for idx, p in enumerate(params): + tensor = p.grad + if tensor.dim() == 1: + output_lst[idx] = reduce_scatter_coalesced([tensor])[0] + elif tensor.numel() % (2 * global_world_size) != 0: + # Due to the constraint of 2-stage all-to-all, the input tensor must be divisible by 2 * global_world_size + # Otherwise, all-to-all cannot be performed because of shape mismatch. + # See more at https://github.com/deepspeedai/DeepSpeed/pull/5056 + logger.warning( + f"qgZ falls back to reduce_scatter because tensor size = {tensor.numel()} is not divisible by (2 * global_world_size) = {2 * global_world_size}. Please consider allocating a new world to enable qgZ" + ) + output_lst[idx] = reduce_scatter_coalesced([tensor])[0] + else: + err_beta = loco_param['err_beta'] + reset_T = loco_param['reset_T'] + if not hasattr(p, 'intra_ef_buf') or loco_idx > reset_T: + loco_idx = 0 + intra_err = torch.zeros_like(p.grad) + inter_err = torch.zeros(tensor.numel() // local_world_size, device=tensor.device, dtype=tensor.dtype) + else: + intra_err = quantizer_module.dequantize(p.intra_ef_buf[0], p.intra_ef_buf[1], + p.intra_ef_buf[1].numel(), 8, quantizer_module.Symmetric) + inter_err = quantizer_module.dequantize(p.inter_ef_buf[0], p.inter_ef_buf[1], + p.inter_ef_buf[1].numel(), 8, quantizer_module.Symmetric) + + intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size) + inter_quant_group = intra_quant_group // local_world_size + intra_quant_int4, intra_q_scales = quantizer_module.loco_swizzle_quant(tensor, intra_err, err_beta, + intra_quant_group, 4, + quantizer_module.Symmetric, 1, + num_nodes, local_world_size) + local_output = torch.empty_like(intra_quant_int4) + scale_output = torch.empty_like(intra_q_scales) + all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}']) + all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}']) + + p.intra_ef_buf = quantizer_module.quantize(intra_err, intra_quant_group, 8, quantizer_module.Symmetric) + + global_input_tensor, global_scales = quantizer_module.loco_quantized_reduction( + local_output, scale_output, inter_err, err_beta, intra_quant_group, inter_quant_group, 4, + quantizer_module.Symmetric, local_world_size) + + global_output = torch.empty_like(global_input_tensor) + global_scale_output = torch.empty_like(global_scales) + all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}']) + all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}']) + + p.inter_ef_buf = quantizer_module.quantize(inter_err, inter_quant_group, 8, quantizer_module.Symmetric) + + final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(), + 4, quantizer_module.Symmetric) + assert final_output.numel( + ) % num_nodes == 0, f"final_output.numel()={final_output.numel()} is not divisible by num_nodes={num_nodes}" + output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1) + loco_idx += 1 + + return output_lst @instrument_w_nvtx @@ -32,7 +161,6 @@ def reduce_scatter_coalesced( ) -> List[Tensor]: """simultaneously reduce-scatter a list of tensors - this can be done more efficiently than individual reduce scatter calls - TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL """ this_rank = dist.get_rank(group) @@ -87,5 +215,4 @@ def reduce_scatter_coalesced( 0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel()) offset += padded_partition_sz_for_each_tensor[tensor_idx] - return output_lst diff --git a/deepspeed/runtime/comm/compressed.py b/deepspeed/runtime/comm/compressed.py new file mode 100644 index 000000000000..06ddfc3c71c9 --- /dev/null +++ b/deepspeed/runtime/comm/compressed.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy as np +import torch +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import PackbitsBuilder +from deepspeed.runtime.comm.utils import check_and_handle_empty_buffer + + +class CompressedBackend(object): + + def __init__(self, mpu=None): + if mpu is None: + self.world_group = dist.new_group(ranks=range(dist.get_world_size())) + else: + self.mpu = mpu + self.world_group = self.mpu.get_data_parallel_group() + self.size = dist.get_world_size(group=self.world_group) + self.rank = dist.get_rank(group=self.world_group) + self.packer = PackbitsBuilder().load() + + def my_igather(self, rank, size, group, sendbuf, recvbuf, root): + req = [] + if rank == root: + for idx in range(size): + if idx != rank: + req.append(dist.irecv(recvbuf[idx], src=idx, group=group)) + else: + recvbuf[rank] = sendbuf + else: + req.append(dist.isend(sendbuf, group=group, dst=root)) + return req + + def my_gather(self, rank, size, group, sendbuf, recvbuf, root): + if rank == root: + for idx in range(size): + if idx != rank: + dist.recv(recvbuf[idx], src=idx, group=group) + else: + recvbuf[rank] = sendbuf + else: + dist.send(sendbuf, group=group, dst=root) + + def pack(self, buffer, size): + # pack float tensor into uint8 tensor + packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank) + return packed.reshape(size, -1) + + def unpack(self, buffer, size, dtype): + # unpack uint8 to float tensor + unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank) + return unpacked.reshape(size, -1).to(dtype) + + def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank): + original_shape = buffer_m.size() + if len(original_shape) > 1: + buffer_m = torch.flatten(buffer_m) + + # align size of original_buffer and error + original_size = buffer_m.numel() + worker_error_size = worker_error.numel() + result = check_and_handle_empty_buffer(buffer_m, original_shape, original_size, worker_error, server_error) + if result is not None: + return result + if original_size != worker_error_size: + empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device) + buffer_m = torch.cat([buffer_m, empty_tensor]) + + buffer_m.add_(worker_error) + worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) + + worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8) + + recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])], + dtype=sign_list_packed_tmp[0].dtype, + device=sign_list_packed_tmp.device) + + sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)] + + recvbuf_scale = [ + torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name()) + for _ in range(self.size) + ] + + # communication phase 1 + # all to all for sign + dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group) + # all gather for scale + dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group) + + flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten() + compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \ + .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) + + compensated_server_m.add_(server_error) + + server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) + + server_error.set_(compensated_server_m - + server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8) + + # recvbuf_sign_server + recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])], + dtype=recvbuf_sign.dtype, + device=server_sign_packed.device) + + recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)] + + # recvbuf_scale_server + recvbuf_scale_server_tmp = torch.zeros([self.size, 1], + dtype=worker_scale.dtype, + device=server_sign_packed.device) + + recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)] + + # communication Phase 2 + dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group) + dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) + + recvbuf_sign_server = torch.stack(recvbuf_sign_server) + + flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten() + + buffer_m.data.copy_( + self.unpack(flattened_recvbuf_sign_server, self.size, + torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data) + + if original_size != worker_error_size: + buffer_m = buffer_m[0:original_size] + if len(original_shape) > 1: + buffer_m = buffer_m.reshape(original_shape) + + return buffer_m diff --git a/deepspeed/runtime/comm/hccl.py b/deepspeed/runtime/comm/hccl.py new file mode 100644 index 000000000000..6dfe610b4ba1 --- /dev/null +++ b/deepspeed/runtime/comm/hccl.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy as np +import torch +import torch_npu + +import deepspeed.comm as dist +from deepspeed.runtime.comm.utils import check_and_handle_empty_buffer + + +class HcclBackend(object): + + def __init__(self, mpu=None): + if mpu is None: + self.world_group = dist.new_group(ranks=range(dist.get_world_size())) + else: + self.mpu = mpu + self.world_group = self.mpu.get_data_parallel_group() + self.size = dist.get_world_size(group=self.world_group) + self.rank = dist.get_rank(group=self.world_group) + + def my_igather(self, rank, size, group, sendbuf, recvbuf, root): + req = [] + if rank == root: + for idx in range(size): + if idx != rank: + req.append(dist.irecv(recvbuf[idx], src=idx, group=group)) + else: + recvbuf[rank] = sendbuf + else: + req.append(dist.isend(sendbuf, group=group, dst=root)) + return req + + def my_gather(self, rank, size, group, sendbuf, recvbuf, root): + if rank == root: + for idx in range(size): + if idx != rank: + dist.recv(recvbuf[idx], src=idx, group=group) + else: + recvbuf[rank] = sendbuf + else: + dist.send(sendbuf, group=group, dst=root) + + def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank): + original_shape = buffer_m.size() + if len(original_shape) > 1: + buffer_m = torch.flatten(buffer_m) + + # align size of original_buffer and error + original_size = buffer_m.numel() + worker_error_size = worker_error.numel() + result = check_and_handle_empty_buffer(buffer_m, original_shape, original_size, worker_error, server_error) + if result is not None: + return result + if original_size != worker_error_size: + empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device) + buffer_m = torch.cat([buffer_m, empty_tensor]) + + buffer_m.add_(worker_error) + worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) + + worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + sign_list_packed_tmp = torch_npu.npu_sign_bits_pack(buffer_m, self.size).type(torch.int8) + + recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])], + dtype=sign_list_packed_tmp[0].dtype, + device=sign_list_packed_tmp.device) + + sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)] + + recvbuf_scale = [ + torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(local_rank)) for _ in range(self.size) + ] + + # communication phase 1 + # all to all for sign + dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group) + # all gather for scale + dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group) + + flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten() + compensated_server_m = torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign, self.size, torch.float32) \ + .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) + + compensated_server_m.add_(server_error) + + server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) + + server_error.set_(compensated_server_m - + server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + server_sign_packed = torch_npu.npu_sign_bits_pack(compensated_server_m, 1).type(torch.int8) + + # recvbuf_sign_server + recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])], + dtype=recvbuf_sign.dtype, + device=server_sign_packed.device) + + recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)] + + # recvbuf_scale_server + recvbuf_scale_server_tmp = torch.zeros([self.size, 1], + dtype=worker_scale.dtype, + device=server_sign_packed.device) + + recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)] + + # communication Phase 2 + dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group) + dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) + + recvbuf_sign_server = torch.stack(recvbuf_sign_server) + + flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten() + + buffer_m.data.copy_( + torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign_server, self.size, + torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data) + + if original_size != worker_error_size: + buffer_m = buffer_m[0:original_size] + if len(original_shape) > 1: + buffer_m = buffer_m.reshape(original_shape) + + return buffer_m diff --git a/deepspeed/runtime/comm/mpi.py b/deepspeed/runtime/comm/mpi.py index c2598f1e5986..94161c718400 100644 --- a/deepspeed/runtime/comm/mpi.py +++ b/deepspeed/runtime/comm/mpi.py @@ -9,6 +9,7 @@ import numpy as np from mpi4py import MPI +from deepspeed.runtime.comm.utils import check_and_handle_empty_buffer from deepspeed.runtime.compression.cupy import CupyBackend @@ -137,6 +138,9 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro buffer_m = torch.flatten(buffer_m) original_size = buffer_m.numel() worker_error_size = worker_error.numel() + result = check_and_handle_empty_buffer(buffer_m, original_shape, original_size, worker_error, server_error) + if result is not None: + return result cupy.cuda.Device(local_rank).use() if original_size != worker_error_size: @@ -144,7 +148,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro buffer_m = torch.cat([buffer_m, empty_tensor]) buffer_m.add_(worker_error) - worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) + worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) cupy_sign_list_packed = self.compression_backend.compress_by_chunk( @@ -173,7 +177,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_( self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(1 / self.size)).sum(0) compensated_server_m.add_(server_error) - server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) + server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) server_error.set_(compensated_server_m - server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) diff --git a/deepspeed/runtime/comm/nccl.py b/deepspeed/runtime/comm/nccl.py index 0bd0d1361973..67099026123c 100644 --- a/deepspeed/runtime/comm/nccl.py +++ b/deepspeed/runtime/comm/nccl.py @@ -4,12 +4,14 @@ # DeepSpeed Team import torch -from deepspeed import comm as dist import cupy import numpy as np -from deepspeed.runtime.compression.cupy import CupyBackend +import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.comm.utils import check_and_handle_empty_buffer +from deepspeed.runtime.compression.cupy import CupyBackend +from deepspeed.utils.torch import required_torch_version class NcclBackend(object): @@ -23,11 +25,7 @@ def __init__(self, mpu=None): self.rank = dist.get_rank(group=self.world_group) self.size = dist.get_world_size(group=self.world_group) self.compression_backend = CupyBackend() - self.bool_not_supported = False - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - if (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) or TORCH_MAJOR == 2: - self.bool_not_supported = True + self.bool_not_supported = required_torch_version(min_version=1.10) def my_igather(self, rank, size, group, sendbuf, recvbuf, root): req = [] @@ -59,6 +57,9 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro buffer_m = torch.flatten(buffer_m) original_size = buffer_m.numel() worker_error_size = worker_error.numel() + result = check_and_handle_empty_buffer(buffer_m, original_shape, original_size, worker_error, server_error) + if result is not None: + return result cupy.cuda.Device(local_rank).use() if original_size != worker_error_size: @@ -66,7 +67,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro buffer_m = torch.cat([buffer_m, empty_tensor]) buffer_m.add_(worker_error) - worker_scale = torch.norm(buffer_m) / np.sqrt(buffer_m.numel()) + worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(buffer_m.numel()) worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) if self.bool_not_supported: @@ -112,7 +113,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_( torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) compensated_server_m.add_(server_error) - server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) + server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) server_error.set_(compensated_server_m - server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) diff --git a/deepspeed/runtime/comm/utils.py b/deepspeed/runtime/comm/utils.py new file mode 100644 index 000000000000..731a7be13ca1 --- /dev/null +++ b/deepspeed/runtime/comm/utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + + +def check_and_handle_empty_buffer( + buffer_m: torch.Tensor, + original_shape: torch.Size, + original_size: int, + worker_error: torch.Tensor, + server_error: torch.Tensor, +) -> Optional[torch.Tensor]: + if original_size == 0: + if worker_error.numel(): + worker_error.zero_() + if server_error.numel(): + server_error.zero_() + if len(original_shape) > 1: + return buffer_m.reshape(original_shape) + return buffer_m + return None diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py new file mode 100644 index 000000000000..6c605658e2a3 --- /dev/null +++ b/deepspeed/runtime/compiler.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import contextlib +import functools +from deepspeed.utils.torch import required_torch_version +from deepspeed.accelerator import get_accelerator + +try: + from torch.compiler import is_compiling as torch_is_compiling +except ImportError: + try: + from torch._dynamo.external_utils import is_compiling as torch_is_compiling + except ImportError: + # Torch does not have compiler support + torch_is_compiling = lambda: False + +try: + if required_torch_version(min_version="2.6.0a"): + from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable + else: + from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable + + _COMPILED_AUTOGRAD_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + _COMPILED_AUTOGRAD_AVAILABLE = False + + +def is_compile_supported(): + return required_torch_version(min_version=2.1) + + +def disable(func): + if is_compile_supported(): + return torch.compiler.disable(func) + return func + + +def enable(min_version=None): + """ + Decorator factory to enable compiling of a function if the minimum PyTorch version requirement is met. + + Args: + min_version (str, optional): Minimum PyTorch version required (e.g., "2.7.0"). + If None, the function is always enabled. + + Returns: + Callable: A decorator that wraps the function. + + Examples: + @enable("2.7.0") + def my_function(): + pass + + @enable + def another_function(): + pass + """ + + def decorator(func): + if not is_compiling(): + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if min_version is None or required_torch_version(min_version=min_version): + return func(*args, **kwargs) + return disable(func)(*args, **kwargs) + + return wrapper + + # Called with no arguments + if callable(min_version): + func = min_version + min_version = None + return decorator(func) + + return decorator + + +def is_compiling(): + return torch_is_compiling() + + +@contextlib.contextmanager +def compiled_autograd(enabled: bool, kwargs: dict): + if not enabled or not _COMPILED_AUTOGRAD_AVAILABLE: + yield + return + + if torch_is_compiling(): + yield + return + + compiler_fn = torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs) + + with compiled_autograd_enable(compiler_fn): + yield + + +def dummy_decorator(func): + return func + + +# robust version of @torch.compile +def compile(): + if hasattr(torch, "compile"): + return torch.compile + else: + return dummy_decorator diff --git a/deepspeed/runtime/compression/cupy.py b/deepspeed/runtime/compression/cupy.py index b959a9c20372..7133ac04ed2b 100644 --- a/deepspeed/runtime/compression/cupy.py +++ b/deepspeed/runtime/compression/cupy.py @@ -14,10 +14,10 @@ def __init__(self): pass def torch2cupy(self, tensor): - return cupy.fromDlpack(to_dlpack(tensor)) + return cupy.from_dlpack(to_dlpack(tensor)) def cupy2torch(self, cupy_tensor): - return from_dlpack(cupy_tensor.toDlpack()) + return from_dlpack(cupy_tensor) def compress_by_chunk(self, cupy_bool_tensor, num_chunks): packed_sign = cupy.packbits(cupy_bool_tensor) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 3c202a9acd07..ec3833cbdcc6 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -14,12 +14,6 @@ import base64 from .constants import * -from .fp16.loss_scaler import ( - INITIAL_LOSS_SCALE, - SCALE_WINDOW, - DELAYED_SHIFT, - MIN_LOSS_SCALE, -) from .config_utils import ( get_scalar_param, dict_raise_error_on_duplicate_keys, @@ -29,6 +23,9 @@ from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig from ..comm.config import DeepSpeedCommsConfig from ..monitor.config import get_monitor_config +from ..inference.config import WeightQuantConfig +from .precision_config import get_bfloat16_config, get_float16_config +from ..compile.config import CompileConfig from deepspeed import comm as dist from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -46,8 +43,8 @@ ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, - MODEL_PARLLEL_SIZE, - MODEL_PARLLEL_SIZE_DEFAULT, + MODEL_PARALLEL_SIZE, + MODEL_PARALLEL_SIZE_DEFAULT, NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT, ) @@ -55,14 +52,19 @@ from ..profiling.config import DeepSpeedFlopsProfilerConfig from ..autotuning.config import DeepSpeedAutotuningConfig from ..nebula.config import DeepSpeedNebulaConfig +from ..datastates.config import DeepSpeedDataStatesConfig from ..compression.config import get_compression_config, get_quantize_enabled from ..compression.constants import * from .swap_tensor.aio_config import get_aio_config +from .model_checkpointing.config import get_checkpoint_config +from .tensor_parallel import get_tensor_parallel_config from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy from .data_pipeline.constants import * +from ..utils.config import get_timers_config + TENSOR_CORE_ALIGN_SIZE = 8 ADAGRAD_OPTIMIZER = 'adagrad' @@ -72,9 +74,15 @@ ONEBIT_ADAM_OPTIMIZER = 'onebitadam' ZERO_ONE_ADAM_OPTIMIZER = 'zerooneadam' ONEBIT_LAMB_OPTIMIZER = 'onebitlamb' +MUADAM_OPTIMIZER = 'muadam' +MUADAMW_OPTIMIZER = 'muadamw' +MUSGD_OPTIMIZER = 'musgd' +LION_OPTIMIZER = 'lion' +MUON_OPTIMIZER = 'muon' + DEEPSPEED_OPTIMIZERS = [ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, - ZERO_ONE_ADAM_OPTIMIZER + ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER ] # extra optimizer parameters for adam/adamw @@ -148,76 +156,33 @@ def get_amp_params(param_dict): return False -def get_fp16_enabled(param_dict): - if FP16 in param_dict.keys(): - return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT) - else: - return False - - -def get_bfloat16_enabled(param_dict): - for key in [BFLOAT16, BFLOAT16_OLD]: - if key in param_dict.keys(): - return get_scalar_param(param_dict[key], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT) - return False - - -def get_fp16_master_weights_and_grads_enabled(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) +def get_torch_autocast_enabled(param_dict): + if TORCH_AUTOCAST in param_dict.keys(): + return get_scalar_param(param_dict[TORCH_AUTOCAST], TORCH_AUTOCAST_ENABLED, TORCH_AUTOCAST_ENABLED_DEFAULT) else: return False -def get_fp16_auto_cast(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT) - - -def get_loss_scale(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT) - elif get_bfloat16_enabled(param_dict): - return 1.0 - else: - return FP16_LOSS_SCALE_DEFAULT +def get_torch_autocast_dtype(param_dict): + if TORCH_AUTOCAST in param_dict: + if TORCH_AUTOCAST_DTYPE in param_dict[TORCH_AUTOCAST]: + try: + return DtypeEnum(param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]).value + except KeyError: + raise ValueError( + f"Invalid dtype for torch autocast: {param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]}") + return None -def get_initial_dynamic_scale(param_dict): - if get_fp16_enabled(param_dict): - initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER, - FP16_INITIAL_SCALE_POWER_DEFAULT) - elif get_bfloat16_enabled(param_dict): - initial_scale_power = 0 - else: - initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT - - return 2**initial_scale_power - - -def get_dynamic_loss_scale_args(param_dict): - loss_scale_args = None - if get_fp16_enabled(param_dict): - fp16_dict = param_dict[FP16] - dynamic_loss_args = [ - FP16_INITIAL_SCALE_POWER, - FP16_LOSS_SCALE_WINDOW, - FP16_MIN_LOSS_SCALE, - FP16_HYSTERESIS, - ] - if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args): - init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT) - scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT) - delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT) - min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT) - loss_scale_args = { - INITIAL_LOSS_SCALE: 2**init_scale, - SCALE_WINDOW: scale_window, - DELAYED_SHIFT: delayed_shift, - MIN_LOSS_SCALE: min_loss_scale, - } - - return loss_scale_args +def get_lower_precision_safe_modules(param_dict): + if TORCH_AUTOCAST in param_dict: + if TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES in param_dict[TORCH_AUTOCAST]: + module_names_with_package = param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES] + if not all(isinstance(module_name, str) for module_name in module_names_with_package): + raise ValueError( + f"Invalid module names for torch autocast: {module_names_with_package}. Expected list of strings.") + return module_names_with_package + return None def get_gradient_accumulation_steps(param_dict): @@ -228,8 +193,10 @@ def get_sparse_gradients_enabled(param_dict): return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT) -def get_communication_data_type(param_dict): - val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT) +def get_communication_data_type(param_dict, + comm_type=COMMUNICATION_DATA_TYPE, + comm_data_type_default=COMMUNICATION_DATA_TYPE_DEFAULT): + val = get_scalar_param(param_dict, comm_type, comm_data_type_default) val = val.lower() if val is not None else val if val is None: return val # we must determine it by other parameters @@ -237,10 +204,10 @@ def get_communication_data_type(param_dict): return torch.float32 elif val == "fp16": return torch.float16 - elif val == "bfp16": + elif val == "bf16": return torch.bfloat16 - raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}") + raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bf16', 'fp32']. Got: {val}") def get_prescale_gradients(param_dict): @@ -267,6 +234,10 @@ def get_gradient_clipping(param_dict): return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT) +def get_graph_harvesting(param_dict): + return get_scalar_param(param_dict, GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT) + + def get_sparse_attention(param_dict): if SPARSE_ATTENTION in param_dict.keys(): sparsity = param_dict[SPARSE_ATTENTION] @@ -437,6 +408,8 @@ def get_pipeline_config(param_dict): "partition": "best", "seed_layers": False, "activation_checkpoint_interval": 0, + "pipe_partitioned": True, + "grad_partitioned": True, } config = default_pipeline for key, val in param_dict.get("pipeline", {}).items(): @@ -530,6 +503,10 @@ def get_hybrid_engine_config(param_dict): return hybrid_engine_config +def get_expert_data_topo_config(param_dict): + return get_scalar_param(param_dict, USE_DATA_BEFORE_EXPERT_PARALLEL, USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT) + + def get_eigenvalue_config(param_dict): if get_quantize_enabled(param_dict): param_dict = param_dict[QUANTIZE_TRAINING] @@ -673,7 +650,7 @@ def write_config(self, filename): class DeepSpeedConfig(object): - def __init__(self, config: Union[str, dict], mpu=None): + def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None): super(DeepSpeedConfig, self).__init__() if isinstance(config, dict): self._param_dict = config @@ -687,16 +664,27 @@ def __init__(self, config: Union[str, dict], mpu=None): raise ValueError( f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}" ) + try: self.global_rank = dist.get_rank() - if mpu is None: - self.world_size = dist.get_world_size() + if mpu is not None: + # Ulysses SP + if not hasattr(mpu, "get_data_parallel_world_size"): + self.world_size = dist.get_world_size() / mpu.get_sequence_parallel_world_size() + else: + self.world_size = mpu.get_data_parallel_world_size() + elif mesh_device is not None: + self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel")) else: - self.world_size = mpu.get_data_parallel_world_size() - except: + # HF zero.init case where there is no mpu + if "sequence_parallel_size" in config: + self.world_size = dist.get_world_size() / config["sequence_parallel_size"] + else: + self.world_size = dist.get_world_size() + except Exception: self.global_rank = 0 self.world_size = 1 - + logger.info(f"Config mesh_device {mesh_device} world_size = {self.world_size}") # If elastic-mode enabled, update compute + update _param_dict self.elasticity_enabled = elasticity_enabled(self._param_dict) if self.elasticity_enabled: @@ -712,7 +700,7 @@ def __init__(self, config: Union[str, dict], mpu=None): # Ensure the resource scheduler saw the same elastic config we are using at runtime ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict) - self.elastic_model_parallel_size = elastic_dict.get(MODEL_PARLLEL_SIZE, MODEL_PARLLEL_SIZE_DEFAULT) + self.elastic_model_parallel_size = elastic_dict.get(MODEL_PARALLEL_SIZE, MODEL_PARALLEL_SIZE_DEFAULT) if self.elastic_model_parallel_size < 1: raise ElasticityConfigError("Model-Parallel size cannot be less than 1, " f"given model-parallel size: {self.elastic_model_parallel_size}") @@ -766,7 +754,6 @@ def __init__(self, config: Union[str, dict], mpu=None): def _initialize_params(self, param_dict): self.train_batch_size = get_train_batch_size(param_dict) - #print(f"beginning get_train_batch_size = {get_train_batch_size}") self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict) self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict) self.steps_per_print = get_steps_per_print(param_dict) @@ -774,11 +761,15 @@ def _initialize_params(self, param_dict): self.disable_allgather = get_disable_allgather(param_dict) self.communication_data_type = get_communication_data_type(param_dict) + self.seq_parallel_communication_data_type = get_communication_data_type( + param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT) self.prescale_gradients = get_prescale_gradients(param_dict) self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict) self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict) self.zero_config = get_zero_config(param_dict) + self.mics_shard_size = self.zero_config.mics_shard_size + self.mics_hierarchial_params_gather = self.zero_config.mics_hierarchical_params_gather self.zero_optimization_stage = self.zero_config.stage self.zero_enabled = self.zero_optimization_stage > 0 @@ -788,19 +779,20 @@ def _initialize_params(self, param_dict): self.monitor_config = get_monitor_config(param_dict) self.gradient_clipping = get_gradient_clipping(param_dict) - self.fp16_enabled = get_fp16_enabled(param_dict) - self.fp16_auto_cast = get_fp16_auto_cast(param_dict) - self.bfloat16_enabled = get_bfloat16_enabled(param_dict) - assert not (self.fp16_enabled - and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict) + self.float16_config = get_float16_config(param_dict) + self.bfloat16_config = get_bfloat16_config(param_dict) + assert not (self.float16_config.enabled + and self.bfloat16_config.enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' + self.amp_enabled = get_amp_enabled(param_dict) self.amp_params = get_amp_params(param_dict) - self.loss_scale = get_loss_scale(param_dict) - self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict) - self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) + + self.torch_autocast_enabled = get_torch_autocast_enabled(param_dict) + self.torch_autocast_dtype = get_torch_autocast_dtype(param_dict) + self.torch_autocast_lower_precision_safe_modules = get_lower_precision_safe_modules(param_dict) self.compression_config = get_compression_config(param_dict) + self.graph_harvesting = get_graph_harvesting(param_dict) self.optimizer_name = get_optimizer_name(param_dict) if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS): @@ -832,6 +824,7 @@ def _initialize_params(self, param_dict): self.eigenvalue_layer_num, ) = get_eigenvalue_config(param_dict) + self.use_data_before_expert_parallel_ = get_expert_data_topo_config(param_dict) self.hybrid_engine = get_hybrid_engine_config(param_dict) self.sparse_attention = get_sparse_attention(param_dict) @@ -867,6 +860,16 @@ def _initialize_params(self, param_dict): self.dataloader_drop_last = get_dataloader_drop_last(param_dict) self.nebula_config = DeepSpeedNebulaConfig(param_dict) + self.datastates_config = DeepSpeedDataStatesConfig(param_dict) + self.checkpoint_config = get_checkpoint_config(param_dict) + + self.weight_quantization_config = WeightQuantConfig( + **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None + + self.compile_config = CompileConfig(**param_dict.get('compile', {})) + + self.timers_config = get_timers_config(param_dict) + self.tensor_parallel_config = get_tensor_parallel_config(param_dict) def _batch_assertion(self): @@ -891,7 +894,7 @@ def _set_batch_related_parameters(self): micro_batch = self.train_micro_batch_size_per_gpu grad_acc = self.gradient_accumulation_steps - #print(f"train_batch = {train_batch}, micro_batch={micro_batch}") + #print(f"in: train_batch = {train_batch}, micro_batch={micro_batch}") # all values are provided nothing needs to be set if train_batch is not None and micro_batch is not None and grad_acc is not None: @@ -930,6 +933,8 @@ def _set_batch_related_parameters(self): assert False, \ 'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided' + #print(f"final: {self.train_batch_size=} {self.train_micro_batch_size_per_gpu=} {self.gradient_accumulation_steps=}") + def _configure_train_batch_size(self): self._set_batch_related_parameters() self._batch_assertion() @@ -966,15 +971,26 @@ def _do_error_check(self): self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: - assert (self.zero_optimization_stage <= - ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( + assert (self.zero_optimization_stage + <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( ZeroStageEnum.max_stage) - if self.fp16_master_weights_and_gradients: - assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." + if self.float16_config.fp16_master_weights_and_grads: + assert self.zero_enabled and self.zero_optimization_stage in ( + ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients, + ZeroStageEnum.weights), "Fp16_master_weights_and_grads is only supported with ZeRO Stage 1, 2, or 3." + if self.bfloat16_config.bf16_master_weights_and_grads: + assert self.zero_enabled and self.zero_optimization_stage in ( + ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients, + ZeroStageEnum.weights), "Bf16_master_weights_and_grads is only supported with ZeRO Stage 1, 2, or 3." + if self.bfloat16_config.bf16_optimizer_states: + assert self.zero_enabled and self.zero_optimization_stage in ( + ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients, + ZeroStageEnum.weights), "bf16_optimizer_states is only supported with ZeRO Stage 1, 2, or 3." + assert self.bfloat16_config.bf16_master_weights_and_grads, "bf16_optimizer_states requires bf16_master_weights_and_grads to be enabled." def _do_warning_check(self): - fp16_enabled = self.fp16_enabled + fp16_enabled = self.float16_config.enabled vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 0fb1372deac8..cf8a593cfba9 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -5,11 +5,12 @@ """ Collection of DeepSpeed configuration utilities """ -import json import collections -import collections.abc +import json +import torch from functools import reduce -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, field_serializer + from deepspeed.utils import logger @@ -54,67 +55,73 @@ def __init__(self, strict=False, **data): if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")} super().__init__(**data) - self._deprecated_fields_check(self) + self._deprecated_fields_check() - def _process_deprecated_field(self, pydantic_config, field): + def _process_deprecated_field(self, dep_field): # Get information about the deprecated field - fields_set = pydantic_config.__fields_set__ - dep_param = field.name - kwargs = field.field_info.extra + pydantic_config = self + fields_set = pydantic_config.model_fields_set + kwargs = type(pydantic_config).model_fields[dep_field].json_schema_extra new_param_fn = kwargs.get("new_param_fn", lambda x: x) - param_value = new_param_fn(getattr(pydantic_config, dep_param)) - new_param = kwargs.get("new_param", "") + param_value = new_param_fn(getattr(pydantic_config, dep_field)) + new_field = kwargs.get("new_param", "") dep_msg = kwargs.get("deprecated_msg", "") - if dep_param in fields_set: - logger.warning(f"Config parameter {dep_param} is deprecated" + - (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else "")) + if dep_field in fields_set: + logger.warning(f"Config parameter {dep_field} is deprecated" + + (f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else "")) # Check if there is a new param and if it should be set with a value - if new_param and kwargs.get("set_new_param", True): + if new_field and kwargs.get("set_new_param", True): # Remove the deprecate field if there is a replacing field try: - delattr(pydantic_config, dep_param) + delattr(pydantic_config, dep_field) except Exception as e: - logger.error(f"Tried removing deprecated '{dep_param}' from config") + logger.error(f"Tried removing deprecated '{dep_field}' from config") raise e # Set new param value - new_param_nested = new_param.split(".") + new_param_nested = new_field.split(".") if len(new_param_nested) > 1: # If the new param exists in a subconfig, we need to get # the fields set for that subconfig pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config) - fields_set = pydantic_config.__fields_set__ + fields_set = pydantic_config.model_fields_set new_param_name = new_param_nested[-1] assert ( new_param_name not in fields_set - ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together" + ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together" # A custom function for converting the old param value to new param value can be provided try: setattr(pydantic_config, new_param_name, param_value) except Exception as e: - logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'") + logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'") raise e - def _deprecated_fields_check(self, pydantic_config): - fields = pydantic_config.__fields__ - for field in fields.values(): - if field.field_info.extra.get("deprecated", False): - self._process_deprecated_field(pydantic_config, field) + def _deprecated_fields_check(self): + fields = type(self).model_fields + for field_name, field_info in fields.items(): + if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False): + self._process_deprecated_field(field_name) + + model_config = ConfigDict( + validate_default=True, + validate_assignment=True, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + arbitrary_types_allowed=True, + protected_namespaces=(), + ) - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - allow_population_by_field_name = True - extra = "forbid" - arbitrary_types_allowed = True + @field_serializer("dtype", check_fields=False) + def serialize_torch_dtype(dtype: torch.dtype) -> str: + return str(dtype) def get_config_default(config, field_name): - assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}" - assert not config.__fields__.get( - field_name).required, f"'{field_name}' is a required field and does not have a default value" - return config.__fields__.get(field_name).default + assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}" + assert not config.model_fields.get( + field_name).is_required(), f"'{field_name}' is a required field and does not have a default value" + return config.model_fields.get(field_name).get_default() class pp_int(int): @@ -130,7 +137,7 @@ def __new__(cls, val, custom_print_str=None): return inst def __repr__(self): - if self.custom_print_str: + if hasattr(self, "custom_print_str") and self.custom_print_str: return self.custom_print_str return f"{self.real:,}" @@ -162,7 +169,7 @@ def iterencode(self, o, _one_shot=False, level=0): x = [f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, v in o.items()] return "{" + ", ".join(x) + f"\n{prefix_close}" + "}" elif isinstance(o, collections.abc.Sequence) and not isinstance(o, str): - return f"[{ f', '.join(map(self.iterencode, o)) }]" + return f"[{ ', '.join(map(self.iterencode, o)) }]" return "\n, ".join(super().iterencode(o, _one_shot)) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 62b46e2a6ce9..9e73bad73376 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from deepspeed.accelerator import get_accelerator + ############################################# # Routes ############################################# @@ -77,7 +79,7 @@ # Steps STEPS_PER_PRINT = "steps_per_print" -STEPS_PER_PRINT_DEFAULT = 10 +STEPS_PER_PRINT_DEFAULT = None ######################################### # Training micro batch size per GPU @@ -117,7 +119,9 @@ BFLOAT16_FORMAT = ''' BFLOAT16 parameters should be of the format: "bf16": { - "enabled": true + "enabled": true, + "immediate_grad_update": false, + "check_overflow": false } ''' BFLOAT16 = "bf16" @@ -126,6 +130,24 @@ BFLOAT16_ENABLED = "enabled" BFLOAT16_ENABLED_DEFAULT = False +CHECK_OVERFLOW = "check_overflow" +BFLOAT16_CHECK_OVERFLOW_DEFAULT = False + +# BFLOAT16 optimizer immediate gradient update +BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update" +BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True + +# BFLOAT16 master weights and optimizer states options +BFLOAT16_MASTER_WEIGHTS_AND_GRADS = "bf16_master_weights_and_grads" +BFLOAT16_MASTER_WEIGHTS_AND_GRADS_DEFAULT = False +BFLOAT16_OPTIMIZER_STATES = "bf16_optimizer_states" +BFLOAT16_OPTIMIZER_STATES_DEFAULT = False + +# DDP variant of BFLOAT16 +# DDP variant: bf16 model with bf16 grad accumulation (uses FP16_Optimizer in bf16 mode) +# Must be different from BFLOAT16 to allow proper optimizer selection +DDP_BFLOAT16 = "ddp_bf16" + ######################################### # FP16 support ######################################### @@ -140,6 +162,7 @@ "initial_scale_power": 16, "loss_scale_window": 1000, "hysteresis": 2, + "consecutive_hysteresis": false, "min_loss_scale": 1 } ''' @@ -167,6 +190,10 @@ FP16_HYSTERESIS = "hysteresis" FP16_HYSTERESIS_DEFAULT = 2 +# FP16 consecutive hysteresis +FP16_CONSECUTIVE_HYSTERESIS = "consecutive_hysteresis" +FP16_CONSECUTIVE_HYSTERESIS_DEFAULT = False + # FP16 min loss scale FP16_MIN_LOSS_SCALE = "min_loss_scale" FP16_MIN_LOSS_SCALE_DEFAULT = 1 @@ -193,6 +220,27 @@ AMP_ENABLED = "enabled" AMP_ENABLED_DEFAULT = False +######################################### +# Torch AMP support +######################################### +TORCH_AUTOCAST_FORMAT = ''' +PyTorch autocast config should be of the format: +"torch_autocast": { + "enabled": true, + "dtype": "bfloat16", + "lower_precision_safe_modules": [ + "torch.nn.modules.linear.Linear", + "torch.nn.modules.conv.Conv2d" + ] +} +''' +TORCH_AUTOCAST = "torch_autocast" + +TORCH_AUTOCAST_ENABLED = "enabled" +TORCH_AUTOCAST_ENABLED_DEFAULT = False +TORCH_AUTOCAST_DTYPE = "dtype" +TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES = "lower_precision_safe_modules" + ######################################### # Gradient clipping ######################################### @@ -205,6 +253,18 @@ GRADIENT_CLIPPING = 'gradient_clipping' GRADIENT_CLIPPING_DEFAULT = 0. +######################################### +# Capture graph for short kernels sequences +######################################### +# Graph harvesting. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +GRAPH_HARVESTING_FORMAT = ''' +Graph harvesting should be enabled as: +"graph_harvesting": true +''' +GRAPH_HARVESTING = 'graph_harvesting' +GRAPH_HARVESTING_DEFAULT = False + ######################################### # Communication data type ######################################### @@ -218,6 +278,26 @@ COMMUNICATION_DATA_TYPE = "communication_data_type" COMMUNICATION_DATA_TYPE_DEFAULT = None +########################################################### +# Gradient communication data type for sequence parallelism +########################################################### +# Supported types: ['fp16', 'bf16','fp32'] +# Default value is fp32 +# Users can configure in ds_config.json as below example: +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_FORMAT = ''' +Optional comm data type for seq paralleism should be set as: +"seq_parallel_communication_data_type": "fp32" +''' +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_communication_data_type" + +if get_accelerator().device_name == 'cuda' and get_accelerator().communication_backend_version() >= (2, 27, 3): + # nccl>=2.27.3 uses fp32 accumulation for half precision inputs, so there is no need to waste compute and memory to manually upcast to fp32 unless the user wants it and then override + SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = None +else: + SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32" + +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32" + ######################################### # Scale/predivide gradients before allreduce ######################################### @@ -415,3 +495,9 @@ class ValidationMode: ######################################### DATA_PARALLEL_GROUP = "data_parallel_group" GLOBAL_RANK = "global_rank" + +######################################### +# EXPERT-DATA PARALLELISM TOPO Config +######################################### +USE_DATA_BEFORE_EXPERT_PARALLEL = "use_data_before_expert_parallelism" +USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT = False diff --git a/deepspeed/runtime/data_pipeline/config.py b/deepspeed/runtime/data_pipeline/config.py index 623480518925..690ce97034e4 100644 --- a/deepspeed/runtime/data_pipeline/config.py +++ b/deepspeed/runtime/data_pipeline/config.py @@ -20,7 +20,6 @@ def get_data_efficiency_config(param_dict): sub_param_dict = param_dict[DATA_EFFICIENCY] output[DATA_SAMPLING] = get_data_sampling(sub_param_dict) output[DATA_ROUTING] = get_data_routing(sub_param_dict) - return output @@ -39,15 +38,14 @@ def get_data_efficiency_seed(param_dict): def get_data_sampling(param_dict): - output = {} + sub_param_dict = param_dict.get(DATA_SAMPLING, {}) + output = copy.copy(sub_param_dict) output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict) output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict) output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict) - if DATA_SAMPLING not in param_dict.keys(): - param_dict[DATA_SAMPLING] = {} - sub_param_dict = param_dict[DATA_SAMPLING] + output[DATA_SAMPLING_PIN_MEMORY] = get_data_sampling_pin_memory(param_dict) output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict) - + output[DYNAMIC_BATCHING] = get_dynamic_batching(sub_param_dict) return output @@ -73,6 +71,13 @@ def get_data_sampling_num_workers(param_dict): return DATA_SAMPLING_NUM_WORKERS_DEFAULT +def get_data_sampling_pin_memory(param_dict): + if DATA_SAMPLING in param_dict.keys(): + return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_PIN_MEMORY, DATA_SAMPLING_PIN_MEMORY_DEFAULT) + else: + return DATA_SAMPLING_PIN_MEMORY_DEFAULT + + def get_curriculum_learning(param_dict): output = {} output[CURRICULUM_LEARNING_ENABLED] = get_curriculum_learning_enabled(param_dict) @@ -87,6 +92,26 @@ def get_curriculum_learning(param_dict): return output +def get_dynamic_batching(param_dict): + output = copy.copy(param_dict.get(DYNAMIC_BATCHING, {})) + output[DYNAMIC_BATCHING_ENABLED] = bool(output.get(DYNAMIC_BATCHING_ENABLED, DYNAMIC_BATCHING_ENABLED_DEFAULT)) + output[DYNAMIC_BATCHING_LR_SCALING_METHOD] = str( + output.get(DYNAMIC_BATCHING_LR_SCALING_METHOD, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT)) + output[DYNAMIC_BATCHING_MIN_BATCH_SIZE] = int( + output.get(DYNAMIC_BATCHING_MIN_BATCH_SIZE, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT)) + output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(output[DYNAMIC_BATCHING_MAX_BATCH_SIZE]) \ + if DYNAMIC_BATCHING_MAX_BATCH_SIZE in output.keys() \ + else DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT + output[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER] = str( + output.get(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT)) + if output[DYNAMIC_BATCHING_ENABLED]: + assert DYNAMIC_BATCHING_MAX_TOKENS in output.keys( + ), f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified" + output[DYNAMIC_BATCHING_MAX_TOKENS] = int(output[DYNAMIC_BATCHING_MAX_TOKENS]) + output[DYNAMIC_BATCHING_VERBOSE] = bool(output.get(DYNAMIC_BATCHING_VERBOSE, False)) + return output + + def get_curriculum_learning_enabled(param_dict): if CURRICULUM_LEARNING in param_dict.keys(): return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED, diff --git a/deepspeed/runtime/data_pipeline/constants.py b/deepspeed/runtime/data_pipeline/constants.py index 1ade640e38d9..73cc69c1f606 100644 --- a/deepspeed/runtime/data_pipeline/constants.py +++ b/deepspeed/runtime/data_pipeline/constants.py @@ -22,6 +22,8 @@ DATA_SAMPLING_NUM_EPOCHS_DEFAULT = 1000 DATA_SAMPLING_NUM_WORKERS = "num_workers" DATA_SAMPLING_NUM_WORKERS_DEFAULT = 0 +DATA_SAMPLING_PIN_MEMORY = "pin_memory" +DATA_SAMPLING_PIN_MEMORY_DEFAULT = False ######################################### # Data efficiency - Data Sampling - Curriculum Learning @@ -62,6 +64,24 @@ CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION = "data_cluster_current_position" CURRICULUM_LEARNING_NP_RNG_STATE = "np_rng_state" +######################################### +# Data efficiency - Dynamic batching and LR scaling +######################################### +DYNAMIC_BATCHING = "dynamic_batching" +DYNAMIC_BATCHING_ENABLED = "enabled" +DYNAMIC_BATCHING_ENABLED_DEFAULT = False +DYNAMIC_BATCHING_METRICS_PATH = "metrics_path" +DYNAMIC_BATCHING_LR_SCALING_METHOD = "lr_scaling_method" # "linear" / "sqrt" / "none" +DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT = "linear" +DYNAMIC_BATCHING_MIN_BATCH_SIZE = "min_batch_size" +DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT = 1 +DYNAMIC_BATCHING_MAX_BATCH_SIZE = "max_batch_size" +DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT = None +DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER = "sequence_picking_order" # "random" / "seqlen" / "dataloader" +DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT = "dataloader" # "random" / "seqlen" / "dataloader" +DYNAMIC_BATCHING_MAX_TOKENS = "max_tokens" +DYNAMIC_BATCHING_VERBOSE = "verbose" + ######################################### # Curriculum Learning legacy implementation ######################################### diff --git a/deepspeed/runtime/data_pipeline/curriculum_scheduler.py b/deepspeed/runtime/data_pipeline/curriculum_scheduler.py index 23d747957dc4..296cc7fcd32b 100644 --- a/deepspeed/runtime/data_pipeline/curriculum_scheduler.py +++ b/deepspeed/runtime/data_pipeline/curriculum_scheduler.py @@ -73,7 +73,7 @@ def __init__(self, config): f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE}'" if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0: logger.warning( - f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.' + 'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.' ) self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG] elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR: @@ -91,7 +91,7 @@ def __init__(self, config): f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'" if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0: logger.warning( - f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.' + 'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.' ) self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG] elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM: diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index 556a6fd1ddca..353bacf69003 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -3,7 +3,9 @@ # DeepSpeed Team +import glob import os +import sys from collections import defaultdict import csv import time @@ -12,9 +14,10 @@ import torch from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset +import deepspeed.comm as dist from deepspeed.utils import logger -from .indexed_dataset import MMapIndexedDataset -from .utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype +from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset import MMapIndexedDataset, valid_dtypes +from deepspeed.runtime.data_pipeline.data_sampling.utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype class DataAnalyzer(object): @@ -36,7 +39,8 @@ def __init__(self, custom_map_init=None, custom_map_update=None, custom_map_finalize=None, - custom_reduce=None): + custom_reduce=None, + sample_indices=None): super().__init__() self.dataset = dataset self.num_workers = num_workers @@ -55,22 +59,22 @@ def __init__(self, self.custom_map_update = custom_map_update self.custom_map_finalize = custom_map_finalize self.custom_reduce = custom_reduce + self.sample_indices = sample_indices def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id): metric_results = [] for m_idx in range(len(metric_names)): metric_name, metric_type, metric_dtype = metric_names[m_idx], \ metric_types[m_idx], metric_dtypes[m_idx] - assert metric_dtype not in [ - np.float64, np.double - ], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)." + assert metric_dtype in valid_dtypes, f"metric_dtype {metric_dtype} not supported. Supported dtypes {valid_dtypes}" metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/" os.makedirs(metric_save_path, exist_ok=True) if metric_type == 'single_value_per_sample': sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric" sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_dtype) metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample" - os.system(f"rm -rf {metric_to_sample_fname}*") + for _f in glob.glob(f"{glob.escape(metric_to_sample_fname)}*"): + os.remove(_f) metric_to_sample_dict = defaultdict(list) metric_results.append({ "sample_to_metric_fname": sample_to_metric_fname, @@ -84,16 +88,34 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname}) return metric_results - def update_metric_results(self, data, metric_types, metric_functions, metric_results): + def update_metric_results(self, + data, + metric_types, + metric_dtypes, + metric_functions, + metric_results, + batch_start_idx=0): for m_idx in range(len(metric_types)): - metric_type, metric_function, metric_result = metric_types[m_idx], \ - metric_functions[m_idx], metric_results[m_idx] + metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \ + metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx] + metric_values = metric_function(data) + + assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \ + "metric_function must return a tensor or array" + assert metric_values.dtype == metric_dtype, \ + f"metric_function result dtype {metric_values.dtype} does not match metric_dtype {metric_dtype}" + if isinstance(metric_values, np.ndarray): + metric_values = torch.from_numpy(metric_values) + if metric_type == 'single_value_per_sample': - metric_values = metric_function(data) for row in range(metric_values.size()[0]): + sample_idx = batch_start_idx + row # sample idx following dataset iteration order + if isinstance(data, dict) and 'index' in data: # Megatron use case, idx provided in 'index' field + sample_idx = data['index'][row][0].item() + elif self.sample_indices is not None: # user defined shuffling of indices + sample_idx = self.sample_indices[sample_idx] metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) - metric_result["metric_to_sample_dict"][metric_values[row].item()].append( - data['index'][row][0].item()) + metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx) for m_value in metric_result["metric_to_sample_dict"]: if len(metric_result["metric_to_sample_dict"][m_value]) > 100: metric_fname = metric_result["metric_to_sample_fname"] @@ -102,7 +124,6 @@ def update_metric_results(self, data, metric_types, metric_functions, metric_res writer.writerows([metric_result["metric_to_sample_dict"][m_value]]) metric_result["metric_to_sample_dict"][m_value] = [] elif metric_type == 'accumulate_value_over_samples': - metric_values = metric_function(data) if metric_result["metric_value"] is None: metric_result["metric_value"] = metric_values else: @@ -136,15 +157,12 @@ def run_map_helper(self, thread_id): f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) - if self.collate_fn is None: - iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False)) - else: - iterator = iter( - DataLoader(thread_dataset, - batch_sampler=sampler, - num_workers=0, - collate_fn=self.collate_fn, - pin_memory=False)) + iterator = iter( + DataLoader(thread_dataset, + batch_sampler=sampler, + num_workers=0, + collate_fn=self.collate_fn, + pin_memory=False)) if self.custom_map_init is None: metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types, self.metric_dtypes, self.save_path, self.worker_id) @@ -157,11 +175,14 @@ def run_map_helper(self, thread_id): while True: try: data = next(iterator) + batch_start_idx = start_idx + processed_sample if self.custom_map_update is None: - self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results) + self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, + metric_results, batch_start_idx) else: - self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results) - processed_sample += self.batch_size + self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions, + metric_results, batch_start_idx) + processed_sample += len(data) duration = (time.time() - start) / 3600.0 remain_duration = duration * total_sample / processed_sample - duration logger.info( @@ -367,26 +388,10 @@ def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_ index_to_metric_builder.merge_file_(chunk_im_fname) close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname) close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname) - num_sample_per_value = {} - index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True) - index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True) - index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged" - index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname, - sample_idx_dtype) - for v_idx in range(len(index_to_sample)): - if v_idx > 0: - assert index_to_metric[v_idx] > index_to_metric[v_idx - 1] - num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx]) - assert sum(num_sample_per_value.values()) == total_num_samples - merge_step = len(index_to_sample) // 100 - for v_idx in range(0, len(index_to_sample), merge_step): - merged_samples = np.copy( - np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))], - axis=None)) - index_to_sample_merged_builder.add_item( - torch.tensor(merged_samples.astype(np.int64), dtype=torch.long)) - logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.") - close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname) + + num_sample_per_value = DataAnalyzer.output_index_to_sample_percentile( + index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path, total_num_samples, + sample_idx_dtype) self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples) elif metric_type == 'accumulate_value_over_samples': metric_save_path = f"{save_path}/{metric_name}/" @@ -408,6 +413,29 @@ def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_ metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long)) close_mmap_dataset_builder(metric_value_builder, metric_value_fname) + @staticmethod + def output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path, + total_num_samples, sample_idx_dtype): + """ read index_to_metric and index_to_sample files and write distribution to index_to_sample_percentage_merged """ + num_sample_per_value = {} + index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True) + index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True) + index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged" + index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname, sample_idx_dtype) + for v_idx in range(len(index_to_sample)): + if v_idx > 0: + assert index_to_metric[v_idx] > index_to_metric[v_idx - 1] + num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx]) + assert sum(list(num_sample_per_value.values())) == total_num_samples + merge_step = max(1, len(index_to_sample) // 100) + for v_idx in range(0, len(index_to_sample), merge_step): + merged_samples = np.copy( + np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))], axis=None)) + index_to_sample_merged_builder.add_item(torch.tensor(merged_samples.astype(np.int64), dtype=torch.long)) + logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.") + close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname) + return num_sample_per_value + def run_reduce(self): if self.custom_reduce is None: self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path, @@ -415,3 +443,445 @@ def run_reduce(self): else: self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers, self.num_threads, self.num_threads_reduce) + + def run_map_reduce(self, comm_group=None): + self.run_map() + # wait for the mapping operation, where all nodes outputs their own (partial) result files + dist.barrier(group=comm_group) + if self.worker_id == 0: + self.run_reduce() + # wait for the reduce, where rank 0 merges all (partial) files. Dataset can then be used by all nodes. + dist.barrier(group=comm_group) + + +class DistributedDataAnalyzer(object): + + def __init__( + self, + dataset, + num_workers=1, + num_threads=1, + worker_id=0, + batch_size=1, + metric_names=[], + metric_functions=[], + metric_types=[], + save_path="./", + collate_fn=None, + device='cuda', + comm_group=None, + sample_indices=None, + ) -> None: + self.dataset = dataset + self.batch_size = batch_size + self.metric_names = metric_names + self.metric_functions = metric_functions + self.metric_types = metric_types + self.save_path = save_path + self.collate_fn = collate_fn + self.device = device + self.sample_indices = sample_indices + self.num_threads = num_threads + self.worker_id = worker_id + + if not dist.is_initialized(): + dist.init_distributed() + + # comm_group and worker_id+num_workers are mutually exclusive + self.comm_group = comm_group + if self.comm_group is None: + # self.comm_group = deepspeed.utils.groups._clone_world_group() + self.num_workers = num_workers + self.worker_id = worker_id + else: + self.num_workers = self.comm_group.size() + self.worker_id = self.comm_group.rank() + + if self.worker_id == 0: + logger.info(f"Distributed data analyzer initialized with {self.num_workers} workers.") + + def run_map_helper(self, thread_id=0, metric_queues=None): + thread_start_idx, thread_end_idx = self.thread_splits[thread_id][0], self.thread_splits[thread_id][1] + worker_dataset = Subset(self.dataset, list(range(thread_start_idx, thread_end_idx))) + sampler = BatchSampler(SequentialSampler(worker_dataset), batch_size=self.batch_size, drop_last=False) + dataloader = DataLoader(dataset=worker_dataset, + batch_sampler=sampler, + num_workers=0, + collate_fn=self.collate_fn, + pin_memory=False) + + # set initial results list + metric_results = [] + for metric_type in self.metric_types: + assert metric_type in ['single_value_per_sample', 'accumulate_value_over_samples'], \ + f"metric_type {metric_type} not implemented." + metric_results.append([] if metric_type == 'single_value_per_sample' else None) + + # iterate dataloader and store metric results + batch_start_idx = thread_start_idx + for data in dataloader: + for m_idx in range(len(self.metric_names)): + metric_type, metric_function = self.metric_types[m_idx], self.metric_functions[m_idx] + metric_values = metric_function(data) + assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \ + "metric_function must return a tensor or array" + if isinstance(metric_values, np.ndarray): + metric_values = torch.from_numpy(metric_values) + assert metric_values.dtype in valid_dtypes, \ + f"metric_function result dtype {metric_values.dtype} not supported. Supported dtypes {valid_dtypes}" + + if metric_type == 'single_value_per_sample': + for row in range(metric_values.size()[0]): + value = metric_values[row].item() + sample_idx = batch_start_idx + row # sample idx following dataset iteration order + if isinstance(data, dict) and 'index' in data: # Megatron use case + sample_idx = data['index'][row][0].item() + elif self.sample_indices is not None: # user defined shuffling of indices + sample_idx = self.sample_indices[sample_idx] + metric_results[m_idx].append((value, sample_idx)) + elif metric_type == 'accumulate_value_over_samples': + if metric_results[m_idx] is None: + metric_results[m_idx] = metric_values + else: + metric_results[m_idx].add_(metric_values) + batch_start_idx += len(data) + + if self.num_threads == 1: + return metric_results + + # copy metric_results to the shared queue + assert metric_queues + for m_idx in range(len(self.metric_names)): + results = metric_results[m_idx] + if torch.is_tensor(results): + results = results.item() if results.dim() == 0 else results.tolist() + try: + metric_queues[m_idx].put((thread_id, results)) + except Exception as e: + logger.error(f"Error putting metric results to queue: {e}") + sys.exit(1) + + def run_map_reduce(self): + + # setup individual dataloaders + self.worker_splits, self.thread_splits = split_dataset(self.dataset, + self.num_workers, + self.worker_id, + num_threads=self.num_threads) + node_start_idx, node_end_idx = self.worker_splits[self.worker_id] + logger.info(f"worker {self.worker_id} working on data subset {node_start_idx} to {node_end_idx}.") + + if self.num_threads in [0, 1, None]: + metric_results = self.run_map_helper() + metric_results = [torch.tensor(m).to(self.device) for m in metric_results] + else: + + # create a shared queue of results per metric to be populated by individual threads + with Manager() as manager: + metric_queues = [manager.Queue() for _ in self.metric_names] + threads = [ + Process(target=self.run_map_helper, args=(t, metric_queues)) for t in range(self.num_threads) + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # gather results from shared queues into metric_results + metric_results = [None for _ in self.metric_names] + for m_idx, (queue, metric_type) in enumerate(zip(metric_queues, self.metric_types)): + while not queue.empty(): + t_idx, t_results = queue.get() + t_start_idx, t_end_idx = self.thread_splits[t_idx] + if t_start_idx >= t_end_idx: # no results from this thread + continue #corner case for small datasets and high thread count + t_results = torch.tensor(t_results) + if metric_type == 'single_value_per_sample': + # add thread results to the metric_results list, ordered by thread idx + if metric_results[m_idx] is None: # initialize if needed + metric_results[m_idx] = torch.zeros(node_end_idx - node_start_idx, + t_results.size(1)).to(self.device) + metric_results[m_idx][t_start_idx - node_start_idx:t_end_idx - node_start_idx] = t_results + else: + if metric_results[m_idx] is None: # initialize if needed + metric_results[m_idx] = torch.zeros(t_results.size()).to(self.device) + metric_results[m_idx].add_(t_results) + + # compute dtype for sample ids + total_num_samples = len(self.dataset) + sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1) + logger.info(f"Total number of data samples: {total_num_samples}.") + logger.info(f"Will use {sample_idx_dtype} to store the sample indexes.") + + for m_idx in range(len(self.metric_names)): + metric_values, metric_name, metric_type = \ + metric_results[m_idx], self.metric_names[m_idx], self.metric_types[m_idx] + metric_save_path = f"{self.save_path}/{metric_name}/" + os.makedirs(metric_save_path, exist_ok=True) + + if metric_type == 'single_value_per_sample': + + # Compute sample and metric value dtypes based on range + values, samples = metric_values[:, 0], metric_values[:, 1] + value_min, value_max = Dist.min_max(values, self.comm_group) + sample_min, sample_max = Dist.min_max(samples, self.comm_group) + metric_value_dtype = find_fit_int_dtype(value_min, value_max) + sample_value_dtype = find_fit_int_dtype(sample_min, sample_max) + + # sample_to_metric maps sample ids to metric values, as a list of metric values + sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric" + values = [torch.tensor([x]) for x in metric_values[:, 0]] + self.file_write_ordered(values, sample_to_metric_fname, metric_value_dtype) + + # distributed sorting by values, gives an ordered disjoint subset of keys on nodes + metric_values = Dist.sample_sort(metric_values, self.comm_group, self.num_workers) + metric_to_samples_dict = {} + if len(metric_values) > 0: + for value, sample in metric_values: + if value.item() not in metric_to_samples_dict: + metric_to_samples_dict[value.item()] = [] + metric_to_samples_dict[value.item()].append(sample.item()) + + # index_to_metric and index_to_sample serialize a dicitonary from metric to samples + # index_to_metric stores a key per row, index_to_sample stores the values per row + values = [torch.tensor([x]) for x in metric_to_samples_dict.keys()] + samples = [torch.tensor(metric_to_samples_dict[x]) for x in metric_to_samples_dict.keys()] + index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric" #dict keys + index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample" #dict values + self.file_write_ordered(values, index_to_metric_fname, metric_value_dtype) + self.file_write_ordered(samples, index_to_sample_fname, sample_value_dtype) + + if self.worker_id == 0: + DataAnalyzer.output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname, + metric_name, metric_save_path, total_num_samples, + sample_idx_dtype) + dist.barrier(self.comm_group) + + elif metric_type == 'accumulate_value_over_samples': + metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value" + dist.reduce(metric_values, dst=0, op=dist.ReduceOp.SUM, group=self.comm_group) + metric_value_dtype = find_fit_int_dtype(metric_values.min(), metric_values.max()) + + if self.worker_id == 0: + builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype) + builder.add_item(metric_values.cpu()) + close_mmap_dataset_builder(builder, metric_value_fname) + dist.barrier(self.comm_group) + + def file_write_ordered(self, tensor_list, fname, numpy_dtype): + """ MPI_file_write_ordered extended to write a list of tensors, by one rank, iteratively """ + + # each node has a list of rows (tensors) to be written to the file. + # we will serialize it in order to communicate it in one comm step. + + tkwargs = dict(dtype=torch.int64, device=self.device) + + # 1. gather on rank 0 the number of rows to be sent/recv + row_count = torch.tensor([len(tensor_list)], **tkwargs) + row_counts = torch.zeros(self.num_workers, **tkwargs) + dist.all_gather_into_tensor(row_counts, row_count, group=self.comm_group) + assert row_counts[self.worker_id] == row_count == len(tensor_list), "all_gather failed" + + # 2. gather on rank 0 the sizes of the rows to be sent/recv + row_len = torch.tensor([len(l) for l in tensor_list], **tkwargs) + row_lens = Dist.gather_v(row_len, 0, self.comm_group, self.num_workers, self.worker_id) + + # 4. gather on rank 0 of the total size (sum of all row lengths) to be received + size = torch.tensor([sum(row_len).item()], **tkwargs) + sizes = torch.zeros(self.num_workers, **tkwargs) + dist.all_gather_into_tensor(sizes, size, group=self.comm_group) + assert sizes[self.worker_id] == size.item(), "all_gather did not return the same sizes" #sanity check + + # method to deserializes a buffer into rows of different lengths and write them to file + def write_buffer_to_file(buff, src, builder): + assert self.worker_id == 0, "only rank 0 can write to file" + + # collect all buffers and write them at once + buff = buff.cpu().detach().numpy() + row_offsets = np.cumsum([0] + row_lens[src].tolist()) + arr_list = [] + for i in range(len(row_lens[src])): + arr_list.append(buff[row_offsets[i]:row_offsets[i + 1]]) + builder.add_items(arr_list) + + # 5. rank 0 prepares output folder and file + if self.worker_id == 0: + os.makedirs(os.path.dirname(fname), exist_ok=True) + builder = create_mmap_dataset_builder(fname, numpy_dtype) + + # iterate through ranks that have data to be sent/recv/written + for src in [rank for rank, count in enumerate(row_counts) if count > 0]: + + dist.barrier(group=self.comm_group) + if self.worker_id == 0 and src == 0: # rank 0's write its own data + buffer = torch.cat(tensor_list, dim=0).to(self.device) + write_buffer_to_file(buffer, 0, builder) + elif self.worker_id == 0 and src > 0: # rank 0 receives other rank's data and writes it + buffer = torch.empty(sizes[src].item(), dtype=numpy_dtype, device=self.device) + err = dist.recv(buffer, src=src, group=self.comm_group, tag=src) + assert err == src and len(buffer) > 0, "recv failed" + write_buffer_to_file(buffer, src, builder) + elif self.worker_id == src: # current rank sends data to rank 0 + buffer = torch.cat(tensor_list, dim=0).to(self.device) + dist.send(buffer, 0, group=self.comm_group, tag=src) + + # rank 0 closes the file + if self.worker_id == 0: + close_mmap_dataset_builder(builder, fname) # close file + dist.barrier(self.comm_group) + + +class Dist: + """ auxiliary class to perform distributed operations on tensors""" + + @staticmethod + def min_max(tensor, comm_group): + """ given a distributed tensor, return the min/max values across all ranks""" + + value_min, value_max = tensor.min(), tensor.max() + dist.reduce(value_min, 0, op=dist.ReduceOp.MIN, group=comm_group) + dist.reduce(value_max, 0, op=dist.ReduceOp.MAX, group=comm_group) + return value_min.item(), value_max.item() + + @staticmethod + def gather_v(tensor, dst, comm_group, num_workers, worker_id): + """ MPI_Gatherv. gather tensors of variable sizes in a single rank """ + + # gather the number of rows to be sent/recv + size = torch.tensor([len(tensor)], dtype=torch.int64, device=tensor.device) + sizes = torch.zeros(num_workers, dtype=torch.int64, device=tensor.device) + dist.all_gather_into_tensor(sizes, size, group=comm_group) + assert sizes[worker_id] == size, "all_gather failed" + + # all_gather requires all tensors to be of same size so we need to pad them + max_size = max(sizes).item() + buffer = torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) + buffer[0:size] = tensor.data + buffer_list = None + if worker_id == 0: # create padded recv buffers + buffer_list = [torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) for _ in range(num_workers)] + dist.gather(buffer, buffer_list, dst=dst, group=comm_group) + + # revert padding and return value + if worker_id == 0: + buffer_list = [r[:s.item()] for r, s in zip(buffer_list, sizes)] + return buffer_list + + @staticmethod + def sample_sort(tensor, comm_group, num_workers, n_samples=100): + """ perform a distributed random sort of a tensor, and returns the sorted partial tensor""" + device, dims = tensor.device, tensor.size()[1] + + # 1 - sort rows by first column, then second column, then third, etc... + tensor = torch.tensor(sorted(tensor.tolist()), dtype=tensor.dtype, device=tensor.device) + + # 2 - collect few samples per rank + idx = torch.round(torch.linspace(0, len(tensor) - 1, n_samples)).to(int) + samples = tensor[idx][:, 0].contiguous().to(device) #only first column, all but last row + + # 2 - Allgather samples + all_samples = [torch.zeros(n_samples, dtype=samples.dtype, device=device) for _ in range(num_workers)] + dist.all_gather(all_samples, samples, group=comm_group) + all_samples = torch.cat(all_samples, dim=0).to(device) + + # 3 - Sort all samples and collect the ranges of each rank as equidistant + all_samples = all_samples.sort()[0] + idx = torch.round(torch.linspace(0, len(all_samples) - 1, num_workers + 1)).to(int) + ranges = all_samples[idx] # range of each rank r as ranges[r] <= x < ranges[r+1] + ranges[-1] += 1 # increase upper limit of last rank so that x < ranges[r+1]. + + # 4 - collect elements to send to each rank, based on the rank ranges + send = [] + for rank in range(num_workers): + mask = (tensor[:, 0] >= ranges[rank]) & (tensor[:, 0] < ranges[rank + 1]) + send.append(tensor[mask]) + + # 5. all to all to communicate the sizes to be sent/recv + send_count = [torch.tensor([len(s) * dims], dtype=torch.int64, device=device) for s in send] + recv_count = list(torch.empty([num_workers], dtype=torch.int64, device=device).chunk(num_workers)) + dist.all_to_all(recv_count, send_count, group=comm_group) + + # 6. all-to-all-v to communicate the elements to be sent/recv as a single tensor + send = torch.cat(send, dim=0).flatten().to(device) + recv = torch.zeros(sum(recv_count), dtype=send.dtype).to(device) + send_count = [s.item() for s in send_count] # convert to list of ints + recv_count = [r.item() for r in recv_count] + dist.all_to_all_single(recv, send, recv_count, send_count, group=comm_group) + del send + + # 7. the received tensor is the 1D disjoint subset of the distributed tensor. + # We will recover the original dimensionality and sort it by columns again. + recv = recv.view(-1, dims) + recv = torch.tensor(sorted(recv.tolist()), dtype=recv.dtype, device=recv.device) + return recv + + +def test_compare_both_data_analyzers(dataset): + """ given a dataset, compare file and memory based data analyser""" + + id = lambda t: t.to(torch.int64) # identity + batch_sum = lambda t: id(t).sum() #sum batch + num_threads = 4 + kwargs = dict( + dataset=dataset, + batch_size=2**10, + worker_id=int(os.environ['RANK']), + num_workers=int(os.environ['WORLD_SIZE']), + metric_names=["mod", "batch_sum"], + metric_functions=[id, batch_sum], + metric_types=['single_value_per_sample', 'accumulate_value_over_samples'], + num_threads=num_threads, + ) + + dda = DistributedDataAnalyzer( + save_path="./output_dist", + device=f"cuda:{int(os.environ['LOCAL_RANK'])}", + **kwargs, + ) + start_time = time.time() + dda.run_map_reduce() + if dda.worker_id == 0: + print("DistributedDataAnalyzer runtime: %s seconds " % (time.time() - start_time)) + + da = DataAnalyzer(num_threads_reduce=num_threads, + save_path="./output_disk", + metric_dtypes=[torch.int64, torch.int64], + **kwargs) + start_time = time.time() + da.run_map_reduce() + if da.worker_id == 0: + print("DataAnalyzer runtime: %s seconds " % (time.time() - start_time)) + + output_paths = [ + "batch_sum/batch_sum_metric_value.bin", "batch_sum/batch_sum_metric_value.idx", \ + "mod/mod_index_to_metric.bin", "mod/mod_index_to_metric.idx", \ + "mod/mod_index_to_sample.bin", "mod/mod_index_to_sample.idx", \ + "mod/mod_index_to_sample_percentile_merged.bin", "mod/mod_index_to_sample_percentile_merged.idx", \ + "mod/mod_sample_to_metric.bin", "mod/mod_sample_to_metric.idx" + ] + + if dda.worker_id == 0: + for path in output_paths: + with open(os.path.join(da.save_path, path), 'rb') as f1, \ + open(os.path.join(dda.save_path, path), 'rb') as f2: + # if files have suffix .bin, they should be identical + if path.endswith(".bin"): + assert f1.read() == f2.read(), f"files {path} are not identical." + elif f1.read() != f2.read(): + print(f"files {path} are not identical.") + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + + class TestDataset(torch.utils.data.Dataset): + + def __init__(self, size=10_000_000): + self.values = [(x + 7) % 10_000 for x in range(size)] + self.size = size + + __len__ = lambda self: self.size + __getitem__ = lambda self, idx: self.values[idx] + + test_compare_both_data_analyzers(TestDataset()) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py index ef845e4bc490..100bef3f7946 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py @@ -119,9 +119,15 @@ def set_custom_curriculum_learning_schedule(self, schedule_func_dict): if metric in schedule_func_dict: self.curriculum_schedulers[metric].set_custom_get_difficulty(schedule_func_dict[metric]) - def get_start_end_idx(self): - start_idx = self.data_parallel_rank * self.micro_batch_size - end_idx = start_idx + self.micro_batch_size + def get_start_end_idx(self, batch_len=None): + """ + given the length of a minibatch (defaults to micro-batch size * data_parallel_size), + return the start and end indices of the current data parallel rank + """ + batch_len = batch_len or self.micro_batch_times_data_parallel_size + start_idx_fn = lambda r: round(r * batch_len / self.data_parallel_group.size()) + start_idx = start_idx_fn(self.data_parallel_rank) + end_idx = start_idx_fn(self.data_parallel_rank + 1) return start_idx, end_idx def get_sample_based_on_metric_value(self, metric, value_start, value_end): @@ -281,12 +287,17 @@ def get_next_global_batch(self): for cidx in range(len(samples_per_cluster)): batch += self.get_sample_from_cluster(cidx, samples_per_cluster[cidx]) self.np_rng.shuffle(batch) + + # broadcast tensor must have same shape across participants. So we fill batch with -1s when not full + assert len(batch) <= self.global_batch_size + batch += [-1] * (self.global_batch_size - len(batch)) batch = torch.tensor(batch, device=get_accelerator().current_device_name(), dtype=torch.long).view(-1) else: batch = torch.empty(self.global_batch_size, device=get_accelerator().current_device_name(), dtype=torch.long) dist.broadcast(batch, 0, group=self.data_parallel_group) + batch = batch[batch != -1] # remove trailing -1s used to fill incomplete batch tensor self.batch = batch.tolist() def __iter__(self): @@ -297,7 +308,7 @@ def __iter__(self): self.batch = self.batch[self.micro_batch_times_data_parallel_size:] if len(current_batch) == self.micro_batch_times_data_parallel_size or \ (len(current_batch) > 0 and not self.drop_last): - start_idx, end_idx = self.get_start_end_idx() + start_idx, end_idx = self.get_start_end_idx(len(current_batch)) yield current_batch[start_idx:end_idx] self.consumed_samples += len(current_batch) current_batch = [] diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py index 60115fa6efef..872d05de0145 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py @@ -98,25 +98,26 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) +# valid metric_dtypes as numpy and torch types dtypes = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float64, - 7: np.double, - 8: np.uint16, - 9: np.uint32, - 10: np.uint64 + 1: (np.uint8, torch.uint8), + 2: (np.int8, torch.int8), + 3: (np.int16, torch.int16), + 4: (np.int32, torch.int32), + 5: (np.int64, torch.int64), + 6: (np.uint16, None), + 7: (np.uint32, None), + 8: (np.uint64, None), } +valid_dtypes = set([dt[0] for dt in dtypes.values()] + [dt[1] for dt in dtypes.values() if dt[1] is not None]) + def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: - return k - raise ValueError(dtype) + for c, (np_dt, torch_dt) in dtypes.items(): + if dtype in [np_dt, torch_dt]: + return c + raise ValueError(f"{dtype} not supported. Supported types: {valid_dtypes}") def index_file_path(prefix_path): @@ -153,7 +154,7 @@ def read_index(self, path): version = f.read(8) assert struct.unpack(' max_tokens] + if len(long_ids) > 0: + logger.warning(f"Data indices {long_ids} ignored as metrics exceed {max_tokens}.") + logger.info(f"Original dataset length: {len(metrics)}. New dataset length: {len(long_ids)}") + metrics = [m for m in metrics if m[1] not in long_ids] + + def is_microbatch_valid(metrics): + if min_batch_size and len(metrics) < min_batch_size: return False # insufficient sample count + if max_batch_size and len(metrics) > max_batch_size: return False # too many samples + if sum([m[0] for m in metrics]) > max_tokens: return False # exceeds max + return True + + # go through all samples and pack then in microbatches of metric sums below the threshold + # `required_microbatches_of_same_size` means all minibatches in a batch must be of equal size + equal_size_multiple = effective_batch_size if required_microbatches_of_same_size else 1 + microbatches = [] + batch_init = 0 + while batch_init < len(metrics): + + # we iterate over possible effective batch sizes (groups of microbatches of same size) + valid_batch_end = batch_init + for batch_end in range(batch_init + equal_size_multiple, len(metrics), equal_size_multiple): + + # attempt effective batch + batch = metrics[batch_init:batch_end] + + # pick interleaved samples for each microbatch to help with load balancing + # (in the ordered use case), and to replicate what the distributed sampler does. + mbs = [batch[b::equal_size_multiple] for b in range(equal_size_multiple)] + + # if they are all valid micro-batches, keep them until you find longer mbatches, if any + is_batch_valid = all([is_microbatch_valid(mb) for mb in mbs]) + if is_batch_valid: + valid_batch_end = batch_end + + if batch_init == valid_batch_end: break # last batch is not valid (size zero), so we are done + batch = metrics[batch_init:valid_batch_end] + mbs = [batch[b::equal_size_multiple] for b in range(equal_size_multiple)] + batch_init += sum([len(l) for l in mbs]) + microbatches += mbs + + # make sure we give the same number of (micro-)batches to each dataloader by trimming the dataset + assert len(microbatches) >= effective_batch_size, "not enough datapoints to create a single sample per dataloader" + microbatches = microbatches[:len(microbatches) - len(microbatches) % effective_batch_size] + + #compute the effective batch size for each microbatch. + batch_sizes, batch_max_seqlens, microbatch_ids = [], [], [] + for rank in range(0, len(microbatches), effective_batch_size): + batch_id = rank // effective_batch_size + mbs = microbatches[rank:rank + effective_batch_size] + # compute the number of samples (not tokens) in this batch (not microbatch) + n_sequences = sum([len(mb) for mb in mbs]) + # compute the longest sequence (as number of tokens) in this batch (not microbatch) + sequence_ids_per_mb = [[m[1] for m in metrics] for metrics in mbs] + sequence_lens_per_mb = [[m[0] for m in metrics] for metrics in mbs] + batch_max_seqlen = max([max(seqlens) for seqlens in sequence_lens_per_mb]) + batch_and_mb_ids = zip([batch_id] * effective_batch_size, sequence_ids_per_mb) + batch_sizes.append(n_sequences) + batch_max_seqlens.append(batch_max_seqlen) + microbatch_ids += batch_and_mb_ids + if verbose: + n_tokens_per_mb = [sum([m[0] for m in mb]) for mb in mbs] + n_sequences_per_mb = [len(mb) for mb in mbs] + assert all([n <= max_tokens for n in n_tokens_per_mb]), "size of microbatch exceeds max tokens" + logger.info( + f"Batch id {batch_id} contains in total {len(mbs)} microbatches or {n_sequences} sequences. "\ + f"n_sequences per microbatch {n_sequences_per_mb}. "\ + f"n_tokens per microbatch {n_tokens_per_mb}. "\ + f"sequence ids per microbatch: {sequence_ids_per_mb}. "\ + f"sequence lengths per microbatch: {sequence_lens_per_mb}.") + + # return the sample ids of each microbatch, and the batch sizes + assert len(batch_sizes) == len(microbatch_ids) // effective_batch_size + return microbatch_ids, batch_sizes, batch_max_seqlens + + +def scale_lr(base_batch_size, batch_size, base_lr=1, method="linear"): + """ given a reference lr and batch_size, compute the new LR for a given batch size """ + if method == "linear": + # Linear Scaling Rule: "When the minibatch size is multiplied by k, multiply the learning + # rate by k" (Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, Goyal et al) + return base_lr * batch_size / base_batch_size + if method == "sqrt": + # Square Root scaling: "when multiplying the batch size by k, multiply the learning rate + # by √k, to keep the variance in the gradient expectation constant" + # (A. Krizhevsky. One weird trick for parallelizing convolutional neural networks) + return base_lr * math.sqrt(batch_size / base_batch_size) + elif method == None or method.upper() == "NONE": + return base_lr + raise ValueError("Unknown scaling method: {}".format(method)) + + +def dataloader_for_variable_batch_size( + dataset, + microbatch_ids, + batch_max_seqlens, + dataloader_rank=0, + dataloader_batch_size=1, + dataloader_num_replicas=1, + dataloader_collate_fn=None, + dataloader_num_workers=2, + dataloader_pin_memory=False, + required_microbatches_of_same_seqlen=False, + sample_padding_fn=None, +): + + # equidistantly distribute the microbatches across the replicas in an interleaved fashion. + sampler = DistributedSampler( + dataset=microbatch_ids, + num_replicas=dataloader_num_replicas, + rank=dataloader_rank, + shuffle=False, + drop_last=False, + ) + + # collate function wraps user-defined collate function to the variable batch data + def collate_fn_wrapper(list_microbatch_ids): + # each batch is a list of sample ids that fill up to the max tokens per batch + # we return the collated batch of all dataset samples of all input batches. + batch = [] + for batch_id, microbatch_ids in list_microbatch_ids: + batch_data = [dataset[idx] for idx in microbatch_ids] + if required_microbatches_of_same_seqlen: + assert sample_padding_fn is not None, \ + "padding dataloader_padding_fn must be provided if required_microbatches_of_same_seqlen is True" + max_seqlen = batch_max_seqlens[batch_id] + assert all([len(sample) <= max_seqlen for sample in batch_data]), \ + "some samples are longer than the computed max seqlen for the batch those samples belong to" + batch_data = [sample_padding_fn(sample, max_seqlen) for sample in batch_data] + batch += batch_data + return dataloader_collate_fn(batch) if dataloader_collate_fn else batch + + dataloader = DataLoader( + dataset=microbatch_ids, + batch_size=dataloader_batch_size, + sampler=sampler, + num_workers=dataloader_num_workers, + collate_fn=collate_fn_wrapper, + pin_memory=dataloader_pin_memory, + ) + + deepspeed_io_kwargs = dict( + dataset=microbatch_ids, + batch_size=dataloader_batch_size, + pin_memory=dataloader_pin_memory, + data_sampler=sampler, + collate_fn=collate_fn_wrapper, + num_local_io_workers=dataloader_num_workers, + ) + + return dataloader, deepspeed_io_kwargs + + +class VariableBatchSizeLR(LRScheduler): + """ an LR scheduler that scales the LR of a given scheduler's LR """ + + @property + def optimizer(self): + return self.base_lr_scheduler.optimizer + + def __init__(self, + lr_scheduler, + base_batch_size, + batch_sizes, + dataloader, + lr_scaling_method="linear", + last_epoch=-1, + verbose=False): + self.batch_sizes = batch_sizes + self.base_batch_size = base_batch_size + self.lr_scaling_method = lr_scaling_method + self.dataloader = dataloader + self.base_lr_scheduler = lr_scheduler + # the following exist in LRScheduler but not in DeepSpeed's LRScheduler so we redefine them here + self.base_lrs = self.base_lr_scheduler.get_lr() + self.last_epoch = last_epoch + self.verbose = verbose + self.step(0) # scale LR for first sample in the dataloader + + def state_dict(self): + return { + 'base_lr_scheduler': self.base_lr_scheduler.state_dict() + } | { + 'base_batch_size': self.base_batch_size, + 'lr_scaling_method': self.lr_scaling_method, + 'batch_sizes': self.batch_sizes, + 'base_lrs': self.base_lrs, + 'last_epoch': self.last_epoch, + 'verbose': self.verbose, + } + + def load_state_dict(self, state_dict): + self.base_lr_scheduler.load_state_dict(state_dict['base_lr_scheduler']) + self.base_batch_size = state_dict['base_batch_size'] + self.lr_scaling_method = state_dict['lr_scaling_method'] + self.batch_sizes = state_dict['batch_sizes'] + self.base_lrs = state_dict['base_lrs'] + self.last_epoch = state_dict['last_epoch'] + self.verbose = state_dict['verbose'] + + def get_last_lr(self): + return self.base_lr_scheduler._last_lr + + def get_lr(self): + return [group['lr'] for group in self.base_lr_scheduler.optimizer.param_groups] + + def step(self, epoch=None): + # call the base scheduler's step method to get LR for next epoch + # Note: optimizer.step precedes lr_scheduler.step(), so the stepping workflow is: + # init: lr_scheduler.step(0) --> set LR for epoch 0 + # epoch 0: optimizer.step(); lr_scheduler.step(1) --> set LR for epoch 1 + # epoch 1: optimizer.step(); lr_scheduler.step(2) --> set LR for epoch 2 + + # reset unscaled LRs (to the original scheduler's one) to be able to step the base LR scheduler + # Note: epoch==0: reset LR scheduler; epoch==None: scale LR for next epoch; + unscaled_lrs = self.base_lrs if epoch == 0 else self.get_last_lr() + for group, lr in zip(self.base_lr_scheduler.optimizer.param_groups, unscaled_lrs): + group['lr'] = lr + + self.base_lr_scheduler.step(epoch) # set unscaled lr, _step_count, last_epoch, _last_lr for new epoch + + # scale the learning rate for the the next iteration for each parameter group. + self.last_epoch = self.last_epoch + 1 if epoch is None else epoch + # batch sizes are precomputed and stored in batch_sizes se we loop around to get the next one + batch_size = self.batch_sizes[self.last_epoch % len(self.batch_sizes)] + for group in self.base_lr_scheduler.optimizer.param_groups: + group['lr'] = scale_lr(self.base_batch_size, batch_size, group['lr'], self.lr_scaling_method) + + if self.verbose: + logger.info( + f"Next batch id {self.last_epoch}. "\ + f"Reference batch_size {self.base_batch_size} and lr {unscaled_lrs}. "\ + f"Scaled batch_size {batch_size} and lr {self.get_lr()}.") + + +def lr_scheduler_for_variable_batch_size(base_batch_size, + batch_sizes, + dataloader, + lr_scheduler_or_optimizer, + lr_scaling_method='linear', + verbose=False): + """ + returns a class that provides an LR scheduler that scales the learning rate at every + iteration taking into account the batch size of that iteration. + If learning rate is constant, ie no LR scheduler, then the base LR will be taken from the + constant LR values in the optimizer param groups. Otherwise from the scheduler's LR. + + Arguments: + - `base_batch_size`: the batch size that the base LR in the optimizer or scheduler refers to; + - `lr_scaling_method`: method to use to scale LR - see `scale_lr()`; + - `lr_scheduler_or_optimizer`: one instance of `LRScheduler` or `Optimizer` to be used as base; + - `batch_sizes`: the effective batch size of each batch in the dataloader; + + Returns the new LRScheduler + """ + + class StubLRScheduler(LRScheduler): + """ a stub LR scheduler that does not change the LR, keeps it constant """ + + def get_lr(self) -> float: + return self.base_lrs + + if isinstance(lr_scheduler_or_optimizer, Optimizer): + lr_scheduler = StubLRScheduler(lr_scheduler_or_optimizer) + elif hasattr(lr_scheduler_or_optimizer, 'optimizer'): #LRScheduler or DeepSpeed 'object' schedulers + assert isinstance(lr_scheduler_or_optimizer.optimizer, Optimizer) + lr_scheduler = lr_scheduler_or_optimizer + else: + raise ValueError("Unknown type for lr_scheduler_or_optimizer: {}".format(type(lr_scheduler_or_optimizer))) + + return VariableBatchSizeLR(lr_scheduler=lr_scheduler, + base_batch_size=base_batch_size, + batch_sizes=batch_sizes, + dataloader=dataloader, + lr_scaling_method=lr_scaling_method, + verbose=verbose) + + +def get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(dataset, + engine, + dataset_seqlens=None, + dataset_filter_ids=None, + dataloader_collate_fn=None, + sample_padding_fn=None, + batch_seqlens_fn=None): + """ + a simplified call to get_dataloader_and_lr_scheduler_for_variable_batch_size for the deepspeed runtime. + Needs the seqlens of every sample. It will try three alternatives: + - if `dataset_seqlens` is provided by user, use that. + - otherwise, looks for the seqlen metric path (in the connfig) that contains the output of the Data Analyzer + - otherwise, use the user-provided function `batch_seqlens_fn` and call Data Analyzer to output seqlen metric + See `batch_by_seqlens()` for arguments and more documentation. + """ + data_efficiency_config = engine._config.data_efficiency_config + data_sampling_config = data_efficiency_config[DATA_SAMPLING] + batching_config = data_sampling_config[DYNAMIC_BATCHING] + assert batching_config[DYNAMIC_BATCHING_ENABLED], "Dynamic batching is not enabled in the config" + + if dataset_seqlens is None: + # In seqlen provided by user, look for the seqlen metric that was output by the Data Analyzer + # (see the main in deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py for an example) + metrics_path = batching_config[DYNAMIC_BATCHING_METRICS_PATH] + sample_to_seqlen_path = os.path.join(metrics_path, "seqlen/seqlen_sample_to_metric") + if not (os.path.exists(f"{sample_to_seqlen_path}.bin") and os.path.exists(f"{sample_to_seqlen_path}.idx")): + # if the metric files are not found, we run the DataAnalyzer to write the metric files + msg = f"Cannot find metric files for sequence length in {sample_to_seqlen_path}.idx or *.bin." + msg += " We will run data analyzer to generated them..." + logger.warning(msg) + + if batch_seqlens_fn is None: + raise ValueError("sample_seqlen_fn must be provided if dataset_seqlens is not provided") + + DistributedDataAnalyzer( + dataset=dataset, + metric_functions=[batch_seqlens_fn], + collate_fn=dataloader_collate_fn, + batch_size=2**10, # batch size for map-reduce, not training + num_workers=engine.world_size, + worker_id=engine.global_rank, + save_path=pathlib.Path(metrics_path), + metric_types=['single_value_per_sample'], + metric_names=["seqlen"], + device=engine.device, + ).run_map_reduce() + + dataset_seqlens = MMapIndexedDataset(sample_to_seqlen_path, skip_warmup=True) + assert len(dataset_seqlens) == len(dataset), \ + "Seqlens size does not match the input dataset size. If you changed the dataset, delete the metrics_path folder." + + # TODO we are copying all seqlens into memory, we should adapt the code to use an iterative streamer + # and use the other files output by DataAnalyzer that returns an ordered dictionary of seqlen to sample ids + dataset_seqlens = np.array(list(dataset_seqlens), dtype=np.int64).flatten() # from Nx1 to N + + dataloader, lr_scheduler, deepspeed_io_kwargs = get_dataloader_and_lr_scheduler_for_variable_batch_size( + dataset=dataset, + dataset_filter_ids=dataset_filter_ids, + dataset_seqlens=dataset_seqlens, + effective_batch_size=engine.train_batch_size(), + max_tokens=batching_config[DYNAMIC_BATCHING_MAX_TOKENS], + lr_scaling_method=batching_config[DYNAMIC_BATCHING_LR_SCALING_METHOD], + sequence_picking_order=batching_config[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER], + min_batch_size=batching_config[DYNAMIC_BATCHING_MIN_BATCH_SIZE], + max_batch_size=batching_config[DYNAMIC_BATCHING_MAX_BATCH_SIZE], + dataloader_batch_size=engine.train_micro_batch_size_per_gpu(), + dataloader_rank=engine.data_parallel_group.rank(), + dataloader_num_replicas=engine.data_parallel_group.size(), + dataloader_num_workers=data_sampling_config[DATA_SAMPLING_NUM_WORKERS], + dataloader_collate_fn=dataloader_collate_fn, + dataloader_pin_memory=data_sampling_config[DATA_SAMPLING_PIN_MEMORY], + sample_padding_fn=sample_padding_fn, + lr_scheduler_or_optimizer=engine.lr_scheduler or engine.optimizer, + required_microbatches_of_same_size=isinstance(engine, PipelineEngine), + required_microbatches_of_same_seqlen=isinstance(engine, PipelineEngine), + verbose=batching_config[DYNAMIC_BATCHING_VERBOSE], + seed=data_efficiency_config[DATA_EFFICIENCY_SEED], + ) + return dataloader, lr_scheduler, deepspeed_io_kwargs + + +def get_dataloader_and_lr_scheduler_for_variable_batch_size( + dataset, + dataset_seqlens, + max_tokens, + effective_batch_size, + dataset_filter_ids=None, + lr_scaling_method="linear", + min_batch_size=1, + max_batch_size=None, + sequence_picking_order="dataloader", + dataloader_batch_size=1, + dataloader_rank=0, + dataloader_num_replicas=1, + dataloader_num_workers=0, + dataloader_collate_fn=None, + dataloader_pin_memory=False, + lr_scheduler_or_optimizer=None, + required_microbatches_of_same_size=False, + required_microbatches_of_same_seqlen=False, + sample_padding_fn=None, + verbose=False, + seed=None, +): + """ returns a dataloader and LR scheduler for the variable batch size. see `batch_by_seqlens()` for details. """ + + # effective_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * number of dataloaders + microbatch_ids, batch_sizes, batch_max_seqlens = batch_by_seqlens( + seqlens=dataset_seqlens, + max_tokens=max_tokens, + sequence_ids_per_mb=dataset_filter_ids, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + sequence_picking_order=sequence_picking_order, + effective_batch_size=effective_batch_size, + required_microbatches_of_same_size=required_microbatches_of_same_size, + verbose=verbose, + seed=seed, + ) + + dataloader, deepspeed_io_kwargs = dataloader_for_variable_batch_size( + dataset=dataset, + microbatch_ids=microbatch_ids, + batch_max_seqlens=batch_max_seqlens, + dataloader_rank=dataloader_rank, + dataloader_num_replicas=dataloader_num_replicas, + dataloader_batch_size=dataloader_batch_size, + dataloader_collate_fn=dataloader_collate_fn, + dataloader_num_workers=dataloader_num_workers, + dataloader_pin_memory=dataloader_pin_memory, + required_microbatches_of_same_seqlen=required_microbatches_of_same_seqlen, + sample_padding_fn=sample_padding_fn, + ) + + lr_scheduler = lr_scheduler_for_variable_batch_size(base_batch_size=effective_batch_size, + batch_sizes=batch_sizes, + lr_scaling_method=lr_scaling_method, + lr_scheduler_or_optimizer=lr_scheduler_or_optimizer, + dataloader=dataloader, + verbose=verbose) + + return dataloader, lr_scheduler, deepspeed_io_kwargs diff --git a/deepspeed/runtime/domino/__init__.py b/deepspeed/runtime/domino/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/runtime/domino/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/runtime/domino/async_linear.py b/deepspeed/runtime/domino/async_linear.py new file mode 100644 index 000000000000..8e01da500409 --- /dev/null +++ b/deepspeed/runtime/domino/async_linear.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/23.08/megatron/core/tensor_parallel/layers.py + +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist +from typing import Callable + +TP_group = None + + +class DominoAsyncColumnParallelLinearImpl(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp, weight, bias, handle_dic, h_id): # inp: (b, s, k), weight: (m, k), bias (m) + ctx.save_for_backward(inp, weight, bias) + ctx.handle_dic = handle_dic + ctx.h_id = h_id + output = torch.matmul(inp, weight.t()) # (b, s, k) @ (k, m) -> (b, s, m) + if bias is not None: # bias (m) + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + inp, weight, bias = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input = torch.matmul(grad_output, weight) # (b, s, m) @ (m, k) -> (b, s, k) + handle = dist.all_reduce(grad_input, group=TP_group, async_op=True) + ctx.handle_dic[ctx.h_id] = handle + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) # (b*s, m) + + inp = inp.view(inp.shape[0] * inp.shape[1], inp.shape[2]) # (b*s, k) + grad_weight = torch.matmul(grad_output.t(), inp) # (m, b*s) @ (b*s, k) -> (m, k) + + if bias is not None: + grad_bias = grad_output.sum(dim=0) # (b*s, m) -> (m) + return grad_input, grad_weight, grad_bias, None, None + + +class DominoAsyncColumnParallelLinear(torch.nn.Module): + + def __init__(self, + input_size, + output_size, + _tp_group, + config, + init_method: Callable, + bias=True, + skip_bias_add=False): + super(DominoAsyncColumnParallelLinear, self).__init__() + + self.skip_bias_add = skip_bias_add + + global TP_group + if TP_group == None: + TP_group = _tp_group + + self.weight = Parameter( + torch.empty( + output_size, + input_size, + device=get_accelerator().current_device_name(), + dtype=config.params_dtype, + )) + if config.perform_initialization: + init_method(self.weight) + + if bias: + self.bias = Parameter( + torch.empty(output_size, device=get_accelerator().current_device_name(), dtype=config.params_dtype)) + + if config.perform_initialization: + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_: torch.Tensor, handle_dic, h_id): + + bias = self.bias if not self.skip_bias_add else None + + output = DominoAsyncColumnParallelLinearImpl.apply(input_, self.weight, bias, handle_dic, h_id) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinearNoComm(torch.nn.Module): + + def __init__( + self, + input_size: int, + output_size: int, + config, + init_method: Callable, + bias: bool = True, + stride: int = 1, + skip_bias_add: bool = False, + ): + super(RowParallelLinearNoComm, self).__init__() + + self.skip_bias_add = skip_bias_add + + self.weight = Parameter( + torch.empty( + output_size, + input_size, + device=get_accelerator().current_device_name(), + dtype=config.params_dtype, + )) + if config.perform_initialization: + init_method(self.weight) + if bias: + self.bias = Parameter( + torch.empty( + output_size, + device=get_accelerator().current_device_name(), + dtype=config.params_dtype, + )) + + if config.perform_initialization: + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + output = F.linear(input_, self.weight, bias) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py new file mode 100644 index 000000000000..3dfb133373b5 --- /dev/null +++ b/deepspeed/runtime/domino/transformer.py @@ -0,0 +1,605 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +import enum +import deepspeed.comm as dist + +from .async_linear import DominoAsyncColumnParallelLinear, RowParallelLinearNoComm + + +class LayerType(enum.Enum): + encoder = 1 + decoder = 2 + + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + + +class ModelType(enum.Enum): + encoder_or_decoder = 1 + encoder_and_decoder = 2 + + +class DominoUtil: + + BATCH_0 = "BATCH0" + + BATCH_1 = "BATCH1" + + HANDLE_DIC = {"BATCH0": None, "BATCH1": None} + + +class DominoModule(torch.nn.Module): + """extensions of torch Module.""" + + def __init__(self, ): + super(DominoModule, self).__init__() + + +def _Wait_bwd_comm(input_, dic_, h_id): + return NoOper.apply(input_, dic_, h_id) + + +class NoOper(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_, handle_dic, h_id): + return input_ + + @staticmethod + def forward(ctx, input_, handle_dic, h_id): + ctx.handle_dic = handle_dic + ctx.h_id = h_id + return input_ + + @staticmethod + def backward(ctx, grad_output): + handle = ctx.handle_dic[ctx.h_id] + handle.wait() + return grad_output, None, None + + +class CoreAttention(DominoModule): + + def __init__(self, config, tp_world_size, attn_mask_type=AttnMaskType.causal): + super(CoreAttention, self).__init__() + + self.attn_mask_type = attn_mask_type + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + assert projection_size % tp_world_size == 0, f"projection size {projection_size} should be multiple of TP world size {tp_world_size}" + self.hidden_size_per_partition = projection_size // tp_world_size + self.attention_dropout_rate = config.attention_dropout + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attn_mask=None, + dropout_p=self.attention_dropout_rate, + is_causal=True, + scale=None) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class ShardedAttention(DominoModule): + """Sharded self-attention layer class. + Only support self attention and causal attention mask for now. + """ + + def __init__(self, + config, + mpu, + apply_rotary_pos_emb, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.causal): + super(ShardedAttention, self).__init__() + + assert attention_type == AttnType.self_attn, "Only support self_attn for now!" + + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + self.params_dtype = config.params_dtype + self.apply_rotary_pos_emb = apply_rotary_pos_emb + + query_projection_size = config.kv_channels * config.num_attention_heads + kv_projection_size = config.kv_channels * config.num_attention_heads + + tp_world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads // tp_world_size + + qkv_projection_per_partition = (query_projection_size + 2 * kv_projection_size) // tp_world_size + + self.query_key_value = DominoAsyncColumnParallelLinear(config.hidden_size, + qkv_projection_per_partition, + mpu.get_tensor_model_parallel_group(), + config=config, + init_method=config.init_method, + bias=config.add_bias_linear) + + self.core_attention = CoreAttention(config, tp_world_size, self.attn_mask_type) + + query_projection_size_per_partition = query_projection_size // tp_world_size + + # Output. + self.dense = RowParallelLinearNoComm(query_projection_size_per_partition, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=config.add_bias_linear, + skip_bias_add=True) + + def forward(self, hidden_states, attention_mask, micro_batch_num, rotary_pos_emb=None): + # hidden_states: [sq, b, h] + + mixed_x_layer, _ = self.query_key_value(hidden_states, DominoUtil.HANDLE_DIC, micro_batch_num) + + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous() + + (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [ + self.hidden_size_per_attention_head, self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + + query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, + self.hidden_size_per_attention_head) + + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb, ) * 2) + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb) + key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb) + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + output, bias = self.dense(context_layer) + return output, bias + + def domino_core_attention_forward(self, mixed_x_layer, attention_mask, rotary_pos_emb=None): + # hidden_states: [sq, b, h] + + # To illustrate the difference between intra-layer overlap and inter-layer overlap + # mixed_x_layer, _ = self.query_key_value(hidden_states, handle_dic, micro_batch_num) + + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous() + + (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [ + self.hidden_size_per_attention_head, self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + + query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, + self.hidden_size_per_attention_head) + + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb, ) * 2) + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb) + key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb) + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # output, bias = self.dense(context_layer) + # return output, bias + + return context_layer + + +class bias_dropout_add(torch.nn.Module): + + def __init__(self, prob: float): + super(bias_dropout_add, self).__init__() + self.dropout = torch.nn.Dropout(prob) + + def forward(self, x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + if bias is not None: + x = x + bias + out = self.dropout(x) + out = out + residual + return out + + +class DominoTransformerLayer(DominoModule): + """A domino single transformer layer. + [s, b, h] -> [s, b, h] + """ + + def __init__(self, + config, + mpu, + apply_rotary_pos_emb, + layer_number, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.causal, + drop_path_rate=0.): + + super(DominoTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.llama_model = False + + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + + # Self attention. + self.self_attention = ShardedAttention(config, + mpu, + apply_rotary_pos_emb, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + + self.hidden_dropout = config.hidden_dropout + + self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + ffn_hidden_size = config.ffn_hidden_size + if config.gated_linear_unit: + ffn_hidden_size *= 2 + + self.output_size_c = config.ffn_hidden_size + self.input_size_c = config.hidden_size + self.input_size_r = config.ffn_hidden_size + self.output_size_r = self.input_size_c + + tp_world_size = mpu.get_tensor_model_parallel_world_size() + self.TP_group = mpu.get_tensor_model_parallel_group() + self.output_size_per_partition = self.output_size_c // tp_world_size + self.input_size_per_partition = self.input_size_r // tp_world_size + + self.linear_fc1 = DominoAsyncColumnParallelLinear(self.input_size_c, + self.output_size_per_partition, + mpu.get_tensor_model_parallel_group(), + config=config, + init_method=config.init_method, + bias=config.add_bias_linear) + + self.mlp_activation_func = F.gelu + + self.linear_fc2 = RowParallelLinearNoComm(self.input_size_per_partition, + self.output_size_r, + config=config, + init_method=config.output_layer_init_method, + bias=config.add_bias_linear, + skip_bias_add=True) + + self.bias_dropout_add_func = bias_dropout_add(self.hidden_dropout) + + def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): + + hidden_states0, hidden_states1 = hidden_states + + layernorm_output0 = self.input_layernorm(hidden_states0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + # Micro batch 0: attention + attention_output0, attention_bias0 = self.self_attention(layernorm_output0, + attention_mask, + DominoUtil.BATCH_0, + rotary_pos_emb=rotary_pos_emb) + + fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True) + # End of Micro batch 0: attention + + # Micro batch 1: attention + layernorm_output1 = self.input_layernorm(hidden_states1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + + attention_output1, attention_bias1 = self.self_attention(layernorm_output1, + attention_mask, + DominoUtil.BATCH_1, + rotary_pos_emb=rotary_pos_emb) + fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True) + + # Micro batch 0: Residual connection. + fwd_handle0.wait() + if self.apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = hidden_states0 + + layernorm_input0 = self.bias_dropout_add_func(attention_output0, attention_bias0, residual0) + + layernorm_output0 = self.post_attention_layernorm(layernorm_input0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + if self.apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = layernorm_input0 + # End of Micro batch 0: Residual connection. + + # ------------ MLP ------------ + # Micro batch 0: MLP + output0, _ = self.linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + output0 = self.mlp_activation_func(output0) + + # Micro batch 1: Residual connection. + fwd_handle1.wait() + if self.apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = hidden_states1 + + layernorm_input1 = self.bias_dropout_add_func(attention_output1, attention_bias1, residual1) + + layernorm_output1 = self.post_attention_layernorm(layernorm_input1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + + if self.apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = layernorm_input1 + # End of Micro batch 1: Residual connection. + + hidden_states0, last_mlp_bias = self.linear_fc2(output0) + fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True) + # End of Micro batch 0: MLP + + # Micro batch 1: MLP + output1, _ = self.linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + output1 = self.mlp_activation_func(output1) + + hidden_states1, last_mlp_bias = self.linear_fc2(output1) + + fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True) + # End of Micro batch 1: MLP + + # ------------ End of MLP ------------ + + fwd_handle0.wait() + hidden_states0 = self.bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0) + + fwd_handle1.wait() + hidden_states1 = self.bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1) + + return hidden_states0, hidden_states1 + + +class DominoTransformer(DominoModule): + """Transformer class.""" + + def __init__(self, + config, + mpu, + apply_rotary_pos_emb, + model_type, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.causal, + post_layer_norm=True, + pre_process=True, + post_process=True, + drop_path_rate=0.0): + super(DominoTransformer, self).__init__() + + self.layer_type = layer_type + self.model_type = model_type + self.post_layer_norm = post_layer_norm + self.post_process = post_process + self.input_tensor = None + self.drop_path_rate = drop_path_rate + self.TP_group = mpu.get_tensor_model_parallel_group() + + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "deepspeed.comm failed to initialize!" + + self.num_layers = config.num_layers + + self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, config.num_layers)] + + def build_layer(layer_number): + + current_layer_type = layer_type + return DominoTransformerLayer(config, + mpu, + apply_rotary_pos_emb, + layer_number, + layer_type=current_layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=self.drop_path_rates[layer_number - 1]) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_process and self.post_layer_norm: + self.final_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + + self._forward_impl = self.inter_layer_overlap_forward + if config.domino_intra_layer_overlap: + self._forward_impl = self.intra_layer_overlap_forward + + def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): + + return self._forward_impl(hidden_states, attention_mask, rotary_pos_emb) + + def inter_layer_overlap_forward(self, hidden_states, attention_mask, rotary_pos_emb=None): + # hidden_states: [s, b, h] + + hidden_states0, hidden_states1 = torch.chunk(hidden_states, chunks=2, dim=1) + + last_mlp_bias = None + fwd_handle0, fwd_handle1 = None, None + residual0, residual1 = None, None + + layernorm_output0 = self.layers[0].input_layernorm(hidden_states0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + for index in range(self.num_layers): + + # Micro batch 0: attention + attention_output0, _ = self.layers[index].self_attention.query_key_value( + layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + attention_output0 = self.layers[index].self_attention.domino_core_attention_forward( + attention_output0, attention_mask, rotary_pos_emb=rotary_pos_emb) + + # Micro batch 1: Residual connection + if index > 0: + fwd_handle1.wait() + hidden_states1 = self.layers[index - 1].bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1) + + layernorm_output1 = self.layers[index].input_layernorm(hidden_states1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + # End of Micro batch 1: Residual connection + + attention_output0, attention_bias0 = self.layers[index].self_attention.dense(attention_output0) + + fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True) + # End of Micro batch 0: attention + + # Micro batch 1: attention + attention_output1, _ = self.layers[index].self_attention.query_key_value( + layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + attention_output1 = self.layers[index].self_attention.domino_core_attention_forward( + attention_output1, attention_mask, rotary_pos_emb=rotary_pos_emb) + + # Micro batch 0: Residual connection. + fwd_handle0.wait() + if self.layers[index].apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = hidden_states0 + + layernorm_input0 = self.layers[index].bias_dropout_add_func(attention_output0, attention_bias0, residual0) + + layernorm_output0 = self.layers[index].post_attention_layernorm(layernorm_input0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + if self.layers[index].apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = layernorm_input0 + # End of Micro batch 0: Residual connection. + + attention_output1, attention_bias1 = self.layers[index].self_attention.dense(attention_output1) + fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True) + # End of Micro batch 1: attention + + # ------------ MLP ------------ + # Micro batch 0: MLP + output0, _ = self.layers[index].linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + output0 = self.layers[index].mlp_activation_func(output0) + + # Micro batch 1: Residual connection. + fwd_handle1.wait() + if self.layers[index].apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = hidden_states1 + + layernorm_input1 = self.layers[index].bias_dropout_add_func(attention_output1, attention_bias1, residual1) + + layernorm_output1 = self.layers[index].post_attention_layernorm(layernorm_input1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + + if self.layers[index].apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = layernorm_input1 + # End of Micro batch 1: Residual connection. + + hidden_states0, last_mlp_bias = self.layers[index].linear_fc2(output0) + fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True) + # End of Micro batch 0: MLP + + # Micro batch 1: MLP + output1, _ = self.layers[index].linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + output1 = self.layers[index].mlp_activation_func(output1) + + # Micro batch 0: Residual connection. + fwd_handle0.wait() + hidden_states0 = self.layers[index].bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0) + + if index < self.num_layers - 1: + layernorm_output0 = self.layers[index + 1].input_layernorm(hidden_states0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + # End of Micro batch 0: Residual connection. + + hidden_states1, last_mlp_bias = self.layers[index].linear_fc2(output1) + + fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True) + # End of Micro batch 1: MLP + + # ------------ End of MLP ------------ + + if self.post_process and self.post_layer_norm: + hidden_states0 = self.final_layernorm(hidden_states0) + + index = self.num_layers - 1 + + fwd_handle1.wait() + hidden_states1 = self.layers[index].bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1) + + if self.post_process and self.post_layer_norm: + hidden_states1 = self.final_layernorm(hidden_states1) + + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + + return hidden_states + + def intra_layer_overlap_forward(self, hidden_states, attention_mask, rotary_pos_emb=None): + + hidden_states = torch.chunk(hidden_states, chunks=2, dim=1) + + for index in range(self.num_layers): + layer = self.layers[index] + hidden_states = layer(hidden_states, attention_mask, rotary_pos_emb) + + hidden_states0, hidden_states1 = hidden_states + if self.post_process and self.post_layer_norm: + hidden_states0 = self.final_layernorm(hidden_states0) + hidden_states1 = self.final_layernorm(hidden_states1) + + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + return hidden_states diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index df63854dd1ca..a4609c5eb6cd 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -7,6 +7,7 @@ from deepspeed.utils import log_dist import numpy as np import logging +from deepspeed.utils.torch import required_torch_version class Eigenvalue(object): @@ -36,12 +37,15 @@ def __init__(self, ranks=[0]) # Replace all nan/pos-inf/neg-inf to zero - # TODO: Pytorch new version may add this function, replace this one by then. def nan_to_num(self, x): - device = x.device - x = x.cpu().numpy() - x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) - return torch.from_numpy(x).to(device) + if required_torch_version(min_version=1.8): + return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + else: + # Fallback to numpy based implementation for backwards-compatibility with PyTorch 1.7 or older versions. + device = x.device + x = x.cpu().numpy() + x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + return torch.from_numpy(x).to(device) def normalize(self, v): norm_squared = self.inner_product(v, v) @@ -103,15 +107,15 @@ def compute_eigenvalue(self, module, device=None, scale=1.0): # Disable eigenvalue if the model doesn't support second order gradients computation, # e.g. when enabling DS transformer kernel. if len(grads) == 0 or len(params) == 0: - log_dist(f'The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING) + log_dist('The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING) return [] i = 0 eigenvalue_current, eigenvalue_previous = 1., 0. while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >= - self.tol): # test convergence criteria + (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) + >= self.tol): # test convergence criteria eigenvalue_previous = eigenvalue_current Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py old mode 100644 new mode 100755 index e953938c06a4..092c391ead6e --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -16,32 +16,47 @@ from torch.nn.parameter import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from contextlib import contextmanager -from typing import Callable, Dict, Union, Iterable +from typing import Callable, Dict, Union, Iterable, Container, List import deepspeed -from deepspeed.runtime.utils import see_memory_usage, DummyOptim -from .zero.offload_config import OffloadDeviceEnum +from deepspeed import comm as dist +from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters +from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum +from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer +from deepspeed.runtime.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException -from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload, ZeROOrderedDict, ensure_zero_ordered_dict from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION +from deepspeed.runtime.zenflow.engine import (configure_zenflow, zenflow_step, is_zenflow_update_boundary, + sync_zenflow_optimizer_lr) from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer +from deepspeed.runtime.fp16.loss_scaler import LossScaleConfig, LossScaleProfile from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer +from deepspeed.linear.optimized_linear import LoRAOptimizedLinear +from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime, collect_autotp_universal_checkpoint_info from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ - TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER + TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ + MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER + +from deepspeed.runtime.model_checkpointing.constants import ValidationMode, \ + CHECKPOINT_TAG_VALIDATION, CHECKPOINT_WRITER, CHECKPOINT_SERIALIZATION from deepspeed.runtime.dataloader import DeepSpeedDataLoader +from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \ - DATA_PARALLEL_GROUP, GLOBAL_RANK + DATA_PARALLEL_GROUP, GLOBAL_RANK, DDP_BFLOAT16 from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.compression import compression_scheduler from deepspeed.compression.constants import \ @@ -55,17 +70,25 @@ WEIGHT_QUANTIZE_ROUNDING, \ WEIGHT_QUANTIZE_VERBOSE, \ WEIGHT_QUANTIZE_KERNEL -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO +from deepspeed.checkpoint.utils import clone_tensors_for_torch_save +from deepspeed.checkpoint.ds_to_universal import dp_index_to_str from deepspeed.runtime.sparse_tensor import SparseTensor from deepspeed.runtime import lr_schedules from deepspeed.utils import groups -from deepspeed.utils import logger, log_dist, instrument_w_nvtx -from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer -from deepspeed.utils.debug import debug_extract_module_and_param_names +from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx +from deepspeed.utils.torch import required_torch_version +from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config +from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \ + FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \ + STEP_MICRO_TIMER, \ + FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ + STEP_GLOBAL_TIMER +from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from deepspeed.runtime.utils import clip_grad_norm_ +from deepspeed.runtime.utils import clip_grad_norm_, compare_tensors_in_structures, maybe_loss_for_backward from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ @@ -74,30 +97,37 @@ RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \ RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler +from deepspeed.runtime.checkpoint_engine import (create_checkpoint_engine, TorchCheckpointEngine, CheckpointCommitInfo) + from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop -from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine +from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint +from deepspeed.runtime.torch_autocast import init_autocast_params, get_default_autocast_lower_precision_modules, autocast_if_enabled from .pipe.module import PipelineModule from .utils import get_ma_status +from .compiler import is_compile_supported, compiled_autograd from ..ops.adam import FusedAdam from ..moe.sharded_moe import TopKGate, MOELayer from ..moe.layer import MoE -from ..moe.utils import is_moe_param +from ..moe.utils import is_moe_param, configure_moe_param_groups from ..git_version_info import version from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler from deepspeed.utils.logging import print_json_dist, print_configuration from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.runtime.config import DtypeEnum -# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init -dist = None +from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue +from deepspeed.compile.backend import register_compile_pass, opt_passes +from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states +from deepspeed.compile.init_z1 import init_z1 +from deepspeed.compile.init_z3 import init_z3 +from deepspeed.compile.init_sp import init_autosp MEMORY_OPT_ALLREDUCE_SIZE = 500000000 @@ -116,33 +146,25 @@ def split_half_float_double_sparse(tensors): device_type = get_accelerator().device_name() - supported_types = [ - "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type), - "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type), - SparseTensor.type() - ] + supported_types = get_accelerator().supported_dtypes() for t in tensors: - assert t.type() in supported_types, f"attempting to reduce an unsupported grad type: {t.type()}" + assert t.dtype in supported_types, f"attempting to reduce an unsupported grad type: {t.dtype}" - buckets = [] + sparse_tensor_buckets, dense_tensor_buckets = [], [] for i, dtype in enumerate(supported_types): - bucket = [t for t in tensors if t.type() == dtype] - if bucket: - buckets.append((dtype, bucket)) - return buckets - - -FORWARD_MICRO_TIMER = 'forward_microstep' -FORWARD_GLOBAL_TIMER = 'forward' -BACKWARD_MICRO_TIMER = 'backward_microstep' -BACKWARD_GLOBAL_TIMER = 'backward' -BACKWARD_INNER_MICRO_TIMER = 'backward_inner_microstep' -BACKWARD_INNER_GLOBAL_TIMER = 'backward_inner' -BACKWARD_REDUCE_MICRO_TIMER = 'backward_allreduce_microstep' -BACKWARD_REDUCE_GLOBAL_TIMER = 'backward_allreduce' -STEP_MICRO_TIMER = 'step_microstep' -STEP_GLOBAL_TIMER = 'step' + sparse_bucket, dense_bucket = [], [] + for t in tensors: + if t.dtype == dtype: + if isinstance(t, SparseTensor): + sparse_bucket.append(t) + else: + dense_bucket.append(t) + if sparse_bucket: + sparse_tensor_buckets.append((dtype, sparse_bucket)) + if dense_bucket: + dense_tensor_buckets.append((dtype, dense_bucket)) + return sparse_tensor_buckets, dense_tensor_buckets class EngineTimers(object): @@ -179,25 +201,27 @@ def __init__(self, enable_micro_timers, enable_global_timers): STEP_GLOBAL_TIMER ] + def active_timers(self): + return self.micro_timers + self.global_timers + class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" - def __init__( - self, - args, - model, - optimizer=None, - model_parameters=None, - training_data=None, - lr_scheduler=None, - mpu=None, - dist_init_required=None, - collate_fn=None, - config=None, - config_class=None, - dont_change_device=False, - ): + def __init__(self, + args, + model, + optimizer=None, + model_parameters=None, + training_data=None, + lr_scheduler=None, + mpu=None, + dist_init_required=None, + collate_fn=None, + config=None, + config_class=None, + mesh_device=None, + dont_change_device=False): super(DeepSpeedEngine, self).__init__() self.dont_change_device = dont_change_device self.client_optimizer = optimizer @@ -205,6 +229,7 @@ def __init__( self.training_data = training_data self.collate_fn = collate_fn self.mpu = mpu + self.all_to_all_group = None self.data_parallel_group = None self.global_steps = 0 self.global_samples = 0 @@ -217,6 +242,7 @@ def __init__( self.loaded_checkpoint_mp_world_size = None self.loaded_checkpoint_dp_world_size = None self.enable_backward_allreduce = True + self.inside_no_sync_ctxt = False self.progressive_layer_drop = None self.eigenvalue = None self.block_eigenvalue = None @@ -229,40 +255,31 @@ def __init__( self._step_applied = False self._global_grad_norm = None self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. - self.checkpoint_engine = None + self.optimizer = None + self.basic_optimizer = None + self.lr_scheduler = None - global dist - from deepspeed import comm as dist self._is_gradient_accumulation_boundary = None self.scale_wrt_gas = None + self.losses = None + self.mesh_device = mesh_device + + # Flag to indicate that scale() was called before manual backward pass + self._manual_backward_expected = False # for debug purposes - can then debug print: debug_get_module_name(module) debug_extract_module_and_param_names(model) - # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict - self.param_names = {param: name for name, param in model.named_parameters()} - - from deepspeed.comm import supported_torch_version - # This supported_torch_version check is for torch1.2 compatibility only - if supported_torch_version: - dist.init_distributed(dist_backend=self.dist_backend, dist_init_required=dist_init_required) - else: - if dist_init_required is None: - dist_init_required = not dist.is_initialized() - - if dist_init_required is False: - assert ( - dist.is_initialized() is True - ), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()" - else: - if not dist.is_initialized(): - dist.init_process_group(backend=self.dist_backend) + if self.mesh_device: + groups.mesh_device = self.mesh_device self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() - see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) + if self.autotp_size() > 1: + self._configure_tensor_parallel(model, self.tensor_parallel_config()) + see_memory_usage("DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) if mpu is not None: if self.elasticity_enabled(): if not self.is_elastic_model_parallel_supported(): @@ -276,32 +293,40 @@ def __init__( self.monitor = MonitorMaster(self._config.monitor_config) see_memory_usage( - f"DeepSpeed Engine: Before configure distributed model", + "DeepSpeed Engine: Before configure distributed model", force=self.memory_breakdown(), ) self.pipeline_parallelism = isinstance(model, PipelineModule) + self._deepcompile_active = False + # Configure distributed model self._configure_distributed_model(model) + # These hooks should be disabled later if DeepCompile is not active. + self.module_forward_pre_hook = self._create_module_forward_pre_hook() + self.module_forward_post_hook = self._create_module_forward_post_hook() + + # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict + self.param_names = {param: name for name, param in model.named_parameters()} + self._get_model_parameters() - see_memory_usage(f"DeepSpeed Engine: After configure distributed model") + see_memory_usage("DeepSpeed Engine: After configure distributed model") # Configure wall clock timers self.timers = SynchronizedWallClockTimer() # Throughput timer - self.tput_timer = ThroughputTimer( - batch_size=self.train_batch_size(), - steps_per_output=self.steps_per_print(), - monitor_memory=False, - ) + self.tput_timer = ThroughputTimer(self._config.timers_config, + batch_size=self.train_batch_size(), + steps_per_output=self.steps_per_print(), + monitor_memory=False) log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0]) if self.flops_profiler_enabled(): - self.flops_profiler = FlopsProfiler(self.module, self) + self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor()) if training_data: self.training_dataloader = self.deepspeed_io(training_data) @@ -309,9 +334,6 @@ def __init__( self.training_dataloader = None # Configure optimizer and scheduler - self.optimizer = None - self.basic_optimizer = None - self.lr_scheduler = None has_optimizer = False if optimizer or self.optimizer_name(): @@ -324,9 +346,24 @@ def __init__( if not isinstance(model_parameters, list): model_parameters = list(model_parameters) + # grad scaler only for Z0 (no ZeRO) + fp16 + torch_autocast + # ZeRO1/2/3 optimizers have their own grad scaler logic + self.torch_autocast_z0_gradscaler = None + if self.torch_autocast_enabled(): + init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules()) + if (not self.zero_optimization() and self.torch_autocast_dtype() == torch.float16): + self.torch_autocast_z0_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name()) + + self._configure_zenflow = lambda: configure_zenflow(self) + self._is_zenflow_update_boundary = lambda: is_zenflow_update_boundary(self) + self._zenflow_step = lambda lr_kwargs: zenflow_step(self, lr_kwargs) + self._sync_zenflow_optimizer_lr = lambda: sync_zenflow_optimizer_lr(self) + + self._configure_zenflow() + if has_optimizer: self._configure_optimizer(optimizer, model_parameters) - self._configure_lr_scheduler(lr_scheduler) + self._configure_lr_scheduler() self._report_progress(0) elif self.zero_optimization(): # no optim selected but zero is enabled @@ -334,6 +371,12 @@ def __init__( elif self.bfloat16_enabled(): self.optimizer = self._configure_bf16_optimizer(optimizer=None) + # Hook optimizer for snip_momentum pruning + if hasattr(model, 'pruners'): + from ..compression.helper import rewrite_optimizer_step + self.optimizer.pruners = model.pruners + rewrite_optimizer_step(self.optimizer) + # Bookkeeping for sparse support self.sparse_tensor_module_names = set() # if self.sparse_gradients_enabled(): @@ -342,10 +385,12 @@ def __init__( self.sparse_tensor_module_names.add(name + ".weight") logger.info("Will convert {} to sparse tensor during training".format(name)) + self._optimized_linear_offload_setup() + self.save_non_zero_checkpoint = False self.save_zero_checkpoint = False if not isinstance(self.optimizer, DeepSpeedZeRoOffload): - self._configure_checkpointing(dist_init_required) + self._configure_checkpointing() if self.eigenvalue_enabled(): self.eigenvalue = self._configure_eigenvalue() @@ -368,19 +413,257 @@ def __init__( enable_global_timers=self.wall_clock_breakdown() or self.flops_profiler_enabled()) + self.engine_timers_cache = {} + if self.global_rank == 0: self._config.print("DeepSpeedEngine configuration") if self.dump_state(): print_configuration(self, "DeepSpeedEngine") - # Load pre-installed or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten + # Use torch (un)flatten ops + self.flatten = _flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors + + self._is_compiled = False + if is_deepcompile_supported(): + # Predefined compile passes + self.register_compile_pass(zero3_compile.NAME, zero3_compile.add_z3_gather_release) + self.register_compile_pass(prefetch.NAME, prefetch.schedule_prefetch) + self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather) + self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states) + + # We now support PyTorch style backward, but it relies on the counter in ZeRO optimizers. + # However, we need some internal APIs to count the number of only used parameters. + # So we only enable this feature when those internal APIs are available. + # Otherwise, we fallback to DeepSpeed style backward only. + # See `count_used_parameters_in_backward` for more details. + self._running_engine_backward = False + self._support_torch_style_backward = False + # Flag to control whether gradients should be scaled by gradient accumulation steps + self._scale_wrt_gas = True + if isinstance(self.optimizer, ZeROOptimizer) and check_internal_apis_for_count_used_parameters(): + self._support_torch_style_backward = True + # These hooks are used for non-scalar backward support, such as `out.backward(out_grad)`, + # not for `engine.backward(loss)`. In this case, we need to ensure that the preprocessing + # and postprocessing around the backward call are handled correctly. + # However, we cannot use `register_full_backward_hook` for post-backward hooks. + # If none of the module inputs require gradients, `register_full_backward_hook` fires + # when the gradients of the module outputs are computed. Our gradient + # accumulation hooks are called later. But we want `_backward_post_hook` to be called + # only after all gradients have been computed. + # To handle this, the optimizer maintains a counter to track the number of gradients + # that have been computed. When all gradients are ready, it calls `_backward_post_hook`. + # See also: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook + self.optimizer.register_grad_acc_post_hook(self._backward_post_hook) + + self._is_compiled_autograd_enabled = False + self._compile_kwargs = {} + + if self.dist_backend is None: + self.enable_backward_allreduce = False + + def _optimized_linear_offload_setup(self): + self.optimized_linear_base_weight_sharding = False + self.optimized_linear_lora_enabled = False + offload_ratio = None + for _, module in self.module.named_modules(): + if isinstance(module, LoRAOptimizedLinear): + self.optimized_linear_lora_enabled = True + if offload_ratio is not None: + assert offload_ratio == module.lora_config.offload_ratio, \ + "all lora_config offload ratios should be the same across the model" + offload_ratio = module.lora_config.offload_ratio + if module.zero_shards > 1: + # set attr so checkpoint saving can handle BWS properly + self.optimized_linear_base_weight_sharding = True + + if offload_ratio is None: + # Nothing enabled, do nothing + return + + total_params = 0 + for _, p in self.module.named_parameters(): + if hasattr(p, 'ds_optim_param'): + total_params += p.numel() + + offload_limit = total_params * offload_ratio + logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params') + total_offloaded = 0 + for _, p in self.module.named_parameters(): + if hasattr(p, 'ds_optim_param'): + if total_offloaded < offload_limit: + total_offloaded += p.numel() + p.ds_offload = True + p.offload() + else: + p.ds_offload = False + + def _configure_tensor_parallel(self, model, tp_config): + self._configure_tensor_parallel_states(model) + configure_tensor_parallel_runtime(tp_config) + self._apply_autotp_partitioning(model, tp_config) + + def _configure_tensor_parallel_states(self, model): + """ + Configures the tensor parallel states for the model. + This includes setting up the tensor parallel groups, initializing the TP mesh, + and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks. + """ + self._set_client_model(model) + # sanity check + # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated + assert self.zero_optimization_stage( + ) <= 2, "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated" + + self.mpu = groups + self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size()) + + self.first_dataloader_check = None + + def check_dataloader_inputs_same_across_ranks(module, args, kwargs): + + def broadcast_and_check(args, bcast_rank, bcast_group): + if isinstance(args, tuple): + args = list(args) + if len(args) > 0: + if self.mpu.get_tensor_model_parallel_rank() == 0: + _src_args = [args] + dist.broadcast_object_list(object_list=_src_args, + src=bcast_rank, + group=bcast_group, + device=torch.device(get_accelerator().current_device_name())) + # Rank 0 does not need to compare with itself + is_equal = True + else: + _src_args = [None] + dist.broadcast_object_list(object_list=_src_args, + src=bcast_rank, + group=bcast_group, + device=torch.device(get_accelerator().current_device_name())) + + is_equal = compare_tensors_in_structures(args, _src_args[0]) + + equal_tensor = torch.tensor(is_equal, + dtype=self.communication_data_type, + device=torch.device(get_accelerator().current_device_name())) + dist.all_reduce(equal_tensor, group=bcast_group) + assert torch.equal( + equal_tensor, + torch.tensor(groups.get_tensor_model_parallel_world_size(), + dtype=self.communication_data_type, + device=torch.device(get_accelerator().current_device_name())) + ), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency." + + bcast_rank = self.mpu.get_tensor_model_parallel_src_rank() + bcast_group = self.mpu.get_tensor_model_parallel_group() + + broadcast_and_check(args, bcast_rank, bcast_group) + broadcast_and_check(kwargs, bcast_rank, bcast_group) + + logger.info(":The Dataloader has passed the TP group consistency check.") + self.first_dataloader_check.remove() + + self.first_dataloader_check = self.module.register_forward_pre_hook(check_dataloader_inputs_same_across_ranks, + prepend=True, + with_kwargs=True) + + def _apply_autotp_partitioning(self, model, tp_config): + if getattr(model, "ds_autotp_parsed", False): + return + if get_accelerator().is_available() and self.local_rank >= 0: + get_accelerator().set_device(self.local_rank) + + tp_size = self.autotp_size() + if tp_config.tensor_parallel.tp_size not in (1, tp_size): + raise ValueError(f"tensor_parallel.tp.tp_size ({tp_config.tensor_parallel.tp_size}) " + f"does not match tensor_parallel.autotp_size ({tp_size}).") + tp_config.tensor_parallel.tp_size = tp_size + if tp_config.tensor_parallel.tp_group is None: + tp_config.tensor_parallel.tp_group = groups.get_tensor_model_parallel_group() + + from deepspeed.module_inject.auto_tp import AutoTP + + # Tensor parallel priority: custom config > HF tp_plan > AutoTP + partition_config = None + if hasattr(tp_config, "get_partition_config_object"): + partition_config = tp_config.get_partition_config_object() + + if partition_config is not None: + autotp = AutoTP(module=model, + all_reduce_linears=(), + prefix="", + state_dict=None, + linear_layer_setting=(torch.nn.Linear, torch.nn.Embedding), + orig_layer_impl=None, + keep_module_on_host=tp_config.keep_module_on_host, + partition_config=partition_config) + autotp.set_tensor_parallel_config(tp_size, tp_config.tensor_parallel.tp_group) + autotp.update_linear_policies() + autotp._replace_module(model) + setattr(model, UNIVERSAL_CHECKPOINT_INFO, collect_autotp_universal_checkpoint_info(model)) + setattr(model, "ds_autotp_parsed", True) + return + + if tp_size <= 1: + setattr(model, "ds_autotp_parsed", True) + return + + model_config = getattr(model, "config", None) + from deepspeed.module_inject import replace_transformer_layer + + from deepspeed.runtime.tensor_parallel.config import _get_hf_tp_plan + + hf_tp_plan = _get_hf_tp_plan(model) + if hf_tp_plan: + from deepspeed.module_inject.tp_plan_converter import TPPlanConverter + from deepspeed.module_inject.autotp_config import AutoTPConfig + + layer_specs = TPPlanConverter.convert(hf_tp_plan) + if layer_specs is not None: + logger.info(f"Using HuggingFace tp_plan with {len(layer_specs)} layer specifications") + tp_plan_config = AutoTPConfig(tp_size=tp_size, layer_specs=layer_specs) + autotp = AutoTP( + module=model, + all_reduce_linears=(), + prefix="", + state_dict=None, + linear_layer_setting=(torch.nn.Linear, torch.nn.Embedding), + orig_layer_impl=None, + keep_module_on_host=tp_config.keep_module_on_host, + partition_config=tp_plan_config, + ) + autotp.set_tensor_parallel_config(tp_size, tp_config.tensor_parallel.tp_group) + autotp.update_linear_policies() + autotp._replace_module(model) + setattr(model, "ds_autotp_parsed", True) + return + + parser_dict = AutoTP.tp_parser(model) + for client_module, injection_policy in parser_dict: + tp_config.injection_policy_tuple = injection_policy + replace_transformer_layer(client_module, model, None, tp_config, model_config) + + setattr(model, UNIVERSAL_CHECKPOINT_INFO, collect_autotp_universal_checkpoint_info(model)) + setattr(model, "ds_autotp_parsed", True) + + def __del__(self): + try: + self.destroy() + except Exception as exc: + # Avoid destructor-time exceptions for partially initialized engines. + logger.debug("DeepSpeedEngine.__del__ cleanup skipped: %s", exc, exc_info=True) def destroy(self): - if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): - self.optimizer.destroy() + optimizer = getattr(self, "optimizer", None) + if optimizer is not None and hasattr(optimizer, 'destroy'): + optimizer.destroy() + if self.is_deepcompile_active(): + get_deepcompile_handle().cleanup() + debug_clear_module_and_param_names() + + checkpoint_engine = getattr(self, "checkpoint_engine", None) + if checkpoint_engine is not None and checkpoint_engine.is_decoupled(): + checkpoint_engine.cleanup() def _get_model_parameters(self): if self.autotuning_profile_model_info(): @@ -389,7 +672,7 @@ def _get_model_parameters(self): trainable_num_params = 0 for p in self.module.parameters(): - # since user code might call deepspeed.zero.Init() before deepspeed.initialize(), need to check the attrbuite to check if the parameter is partitioned in zero 3 already or not + # since user code might call deepspeed.zero.Init() before deepspeed.initialize(), need to check the attribute to check if the parameter is partitioned in zero 3 already or not n = 0 if hasattr(p, "ds_tensor"): # if the parameter is partitioned in zero 3 n += p.ds_numel @@ -432,12 +715,23 @@ def set_train_batch_size(self, train_batch_size): """ if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0: #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') - raise ValueError(f'Train batch size must be divisible by micro-batch data parallelism') + raise ValueError('Train batch size must be divisible by micro-batch data parallelism') new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size) # overwrite config self._config.train_batch_size = train_batch_size self._config.gradient_accumulation_steps = new_gas + def set_train_micro_batch_size(self, micro_batch_size): + """Adjust the micro batch size(i.e., the micro batch size in every data parallel group), + while keep the gradient accumulation steps the same. + Args: + micro_batch_size (int): The new micro batch size for training. + """ + # overwrite config + new_global_batch_size = micro_batch_size * self._config.gradient_accumulation_steps * self.dp_world_size + self._config.train_batch_size = new_global_batch_size + self._config.train_micro_batch_size_per_gpu = micro_batch_size + def set_data_post_process_func(self, post_process_func): if self.training_dataloader is not None: self.training_dataloader.post_process_func = post_process_func @@ -473,11 +767,17 @@ def __getattr__(self, name): else: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def checkpoint_serialization_enabled(self): + return self._config.checkpoint_config[CHECKPOINT_SERIALIZATION] + + def checkpoint_writer_enabled(self): + return self._config.checkpoint_config[CHECKPOINT_WRITER] is not None + def checkpoint_tag_validation_enabled(self): - return self._config.checkpoint_tag_validation_enabled + return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] != ValidationMode.IGNORE def checkpoint_tag_validation_fail(self): - return self._config.checkpoint_tag_validation_fail + return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] == ValidationMode.FAIL def elasticity_enabled(self): return self._config.elasticity_enabled @@ -574,15 +874,30 @@ def random_ltd_initialize(self): if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]: assert self.client_lr_scheduler is None - raise ValueError(f'not yet support') + raise ValueError('not yet support') #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler) + def get_data_parallel_rank(self): + return groups.get_data_parallel_rank() + + def get_tensor_parallel_rank(self): + return groups.get_tensor_model_parallel_rank() + + def get_model_parallel_rank(self): + return groups.get_model_parallel_rank() + + def get_sequence_parallel_group(self): + return self.seq_parallel_group + def wall_clock_breakdown(self): return self._config.wall_clock_breakdown def flops_profiler_enabled(self): return self._config.flops_profiler_config.enabled or self.autotuning_enabled() + def flops_profiler_recompute_fwd_factor(self): + return self._config.flops_profiler_config.recompute_fwd_factor + def flops_profiler_profile_step(self): step = self._config.flops_profiler_config.profile_step if self._config.autotuning_config.enabled: @@ -705,15 +1020,38 @@ def zero_cpu_offload(self): return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu return False + def zero_partial_offload(self): + return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0) + + def super_offload(self): + return getattr(self._config.zero_config.offload_optimizer, "super_offload", False) + + def cpuadam_cores_perc(self): + return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9) + def zero_sub_group_size(self): return self._config.zero_config.sub_group_size def zero_optimization_stage(self): return self._config.zero_optimization_stage + def compile_zero_optimization_stage(self): + """Determines if zero-pass is set in deepcompile's passes attributes.""" + return "z1" in self._config.compile_config.passes or "z3" in self._config.compile_config.passes + + def compile_autosp(self): + """Determines if AutoSP is set in deepcompile's passes attributes.""" + return "autosp" in (getattr(self._config.compile_config, "passes", None) or []) + + def mics_shard_size(self): + return self._config.mics_shard_size + def zero_reduce_bucket_size(self): return self._config.zero_config.reduce_bucket_size + def zero_multi_rank_bucket_allreduce(self): + return self._config.zero_config.use_multi_rank_bucket_allreduce + def zero_allgather_bucket_size(self): return self._config.zero_config.allgather_bucket_size @@ -723,6 +1061,13 @@ def zero_optimization_partition_gradients(self): def zero_optimization_partition_weights(self): return self.zero_optimization_stage() >= ZeroStageEnum.weights + def is_first_weights_partition_group(self): + ret = True if self.mics_shard_size() < 0 \ + and self.zero_optimization_partition_weights() else False + if self.mics_shard_size() > 0 and self.global_rank < self.mics_shard_size(): + ret = True + return ret + def zero_contiguous_gradients(self): return self._config.zero_config.contiguous_gradients @@ -732,6 +1077,9 @@ def zero_load_from_fp32_weights(self): def zero_elastic_checkpoint(self): return self._config.zero_config.elastic_checkpoint + def zero_nvme_offload_optimizer(self): + return getattr(self.optimizer, "swap_optimizer", False) + def zero_max_live_parameters(self): return self._config.zero_config.max_live_parameters @@ -741,6 +1089,9 @@ def zero_max_reuse_distance(self): def zero_prefetch_bucket_size(self): return self._config.zero_config.prefetch_bucket_size + def zero_module_granularity_threshold(self): + return self._config.zero_config.module_granularity_threshold + def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold @@ -759,14 +1110,32 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def zero_save_muon_momentum_buffer_in_memory(self): + return self._config.zero_config.save_muon_momentum_buffer_in_memory + + def tensor_parallel_config(self): + return self._config.tensor_parallel_config + + def autotp_size(self): + return self._config.tensor_parallel_config.autotp_size + + def graph_harvesting(self): + return self._config.graph_harvesting + def fp16_enabled(self): - return self._config.fp16_enabled + return self._config.float16_config.enabled def bfloat16_enabled(self): - return self._config.bfloat16_enabled + return self._config.bfloat16_config.enabled def fp16_master_weights_and_gradients(self): - return self._config.fp16_master_weights_and_gradients + return self._config.float16_config.fp16_master_weights_and_grads + + def bf16_master_weights_and_gradients(self): + return self._config.bfloat16_config.bf16_master_weights_and_grads + + def bf16_optimizer_states(self): + return self._config.bfloat16_config.bf16_optimizer_states def amp_enabled(self): return self._config.amp_enabled @@ -774,11 +1143,21 @@ def amp_enabled(self): def amp_params(self): return self._config.amp_params + def torch_autocast_enabled(self) -> bool: + return self._config.torch_autocast_enabled + + def torch_autocast_dtype(self) -> torch.dtype: + return self._config.torch_autocast_dtype + + def torch_autocast_lower_precision_safe_modules(self) -> List[str]: + module_names = self._config.torch_autocast_lower_precision_safe_modules + return get_default_autocast_lower_precision_modules() if module_names is None else module_names + def fp16_auto_cast(self): - return self._config.fp16_auto_cast + return self._config.float16_config.auto_cast def loss_scale(self): - return self._config.loss_scale + return self._config.float16_config.loss_scale def gradient_accumulation_steps(self): return self._config.gradient_accumulation_steps @@ -798,8 +1177,15 @@ def communication_data_type(self): if self.fp16_enabled(): return torch.float16 + if self.bfloat16_enabled(): + return torch.bfloat16 + return torch.float32 + @communication_data_type.setter + def communication_data_type(self, value): + self._config.communication_data_type = value + def postscale_gradients(self): return not self._config.prescale_gradients @@ -815,6 +1201,30 @@ def zero_allgather_partitions(self): def zero_round_robin_gradients(self): return self._config.zero_config.round_robin_gradients + def zero_hpz_partition_size(self): + return self._config.zero_config.zero_hpz_partition_size + + def zero_quantized_weights(self): + return self._config.zero_config.zero_quantized_weights + + def zero_quantized_nontrainable_weights(self): + return self._config.zero_config.zero_quantized_nontrainable_weights + + def zero_quantized_gradients(self): + return self._config.zero_config.zero_quantized_gradients + + def zeropp_loco_param(self): + return self._config.zero_config.zeropp_loco_param + + def zero_log_trace_cache_warnings(self): + return self._config.zero_config.log_trace_cache_warnings + + def zero_allgather_sequential(self): + return self._config.zero_config.allgather_sequential + + def is_sanity_checks_enabled(self): + return self._config.zero_config.enable_sanity_checks + def dump_state(self): return self._config.dump_state @@ -822,13 +1232,13 @@ def gradient_clipping(self): return self._config.gradient_clipping def dynamic_loss_scale(self): - return self._config.loss_scale == 0 + return self._config.float16_config.loss_scale == 0 def initial_dynamic_scale(self): - return self._config.initial_dynamic_scale + return self._config.float16_config.initial_dynamic_scale() def dynamic_loss_scale_args(self): - return self._config.dynamic_loss_scale_args + return self._config.float16_config.dynamic_loss_scale_args() def swap_tensor_config(self): return self._config.swap_tensor_config @@ -836,6 +1246,9 @@ def swap_tensor_config(self): def aio_config(self): return self._config.aio_config + def zenflow_config(self): + return self._config.zero_config.zenflow + def get_data_types(self): model_dtype = torch.float32 if self.fp16_enabled(): @@ -843,56 +1256,52 @@ def get_data_types(self): elif self.bfloat16_enabled(): model_dtype = torch.bfloat16 - if self._config.grad_accum_dtype == None: - if model_dtype == torch.bfloat16 and not self.zero_optimization(): - grad_accum_dtype = torch.float32 - else: - grad_accum_dtype = model_dtype + if self._config.grad_accum_dtype is None: + grad_accum_dtype = model_dtype else: grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value - return (model_dtype, grad_accum_dtype) - def _configure_lr_scheduler(self, client_lr_scheduler): - # First check for scheduler in json configuration - lr_scheduler = self._scheduler_from_config(self.optimizer) - if lr_scheduler: - log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0]) - self.lr_scheduler = lr_scheduler - else: - if isinstance(client_lr_scheduler, Callable): + def _optimizer_has_ckpt_event_prologue(self): + return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue') + + def _optimizer_has_ckpt_event_epilogue(self): + return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue') + + def _configure_lr_scheduler(self): + if self.client_lr_scheduler: + if isinstance(self.client_lr_scheduler, Callable): log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0]) - self.lr_scheduler = client_lr_scheduler(self.basic_optimizer) + self.lr_scheduler = self.client_lr_scheduler(self.basic_optimizer) else: log_dist('DeepSpeed using client LR scheduler', ranks=[0]) - self.lr_scheduler = client_lr_scheduler + self.lr_scheduler = self.client_lr_scheduler + else: + # load lr scheduler from json configuration if lr scheduler is not defined and passed in + lr_scheduler = self._scheduler_from_config(self.optimizer) + log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0]) + self.lr_scheduler = lr_scheduler log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) - def _configure_checkpointing(self, dist_init_required): - self.checkpoint_engine = TorchCheckpointEngine() - - if self._config is not None and self._config.nebula_config.enabled: - try: - from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ - NebulaCheckpointEngine - self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config) - except ImportError as err: - logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}") - self.checkpoint_engine = TorchCheckpointEngine() - - dp_rank = self.global_rank - if self.mpu: - dp_rank = self.mpu.get_data_parallel_rank() + def _configure_checkpointing(self): + # Enable optimization to parallelize checkpointing of DP state + optimize_dp_state = not self.zero_optimization_partition_weights() + self.checkpoint_engine = create_checkpoint_engine(config_params=self._config, + groups=groups, + zero_stage=self.zero_optimization_stage(), + has_moe_layers=self.has_moe_layers, + optimize_dp_state=optimize_dp_state) + dp_rank = groups._get_sequence_data_parallel_rank() rank = self.local_rank if self.use_node_local_storage() else dp_rank - # only the first data parallel process needs to store the model checkpoint - # if you want to use node local storage this must be done by rank 0 on each - # node - self.save_non_zero_checkpoint = (rank == 0) or self.zero_optimization_partition_weights() + # Determine if this data parallel process needs to store the model checkpoint + if self.checkpoint_engine.is_data_parallel_writer(rank) \ + or (self.zero_optimization_partition_weights() and self.is_first_weights_partition_group()): + self.save_non_zero_checkpoint = True - if self.zero_optimization() or self.bfloat16_enabled(): + if hasattr(self.optimizer, 'dp_process_group'): param_rank = dist.get_rank(group=self.optimizer.dp_process_group) # Only the first parameter parallel process needs to store the @@ -920,13 +1329,13 @@ def _set_distributed_vars(self, args): device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank if device_rank >= 0: get_accelerator().set_device(device_rank) - self.device = torch.device(get_accelerator().device_name(), device_rank) + self.device = torch.device(get_accelerator().device_name(device_rank)) self.world_size = dist.get_world_size() self.global_rank = dist.get_rank() else: self.world_size = 1 self.global_rank = 0 - self.device = torch.device(get_accelerator().device_name()) + self.device = get_accelerator().device() # Configure based on command line arguments def _configure_with_arguments(self, args, mpu): @@ -952,7 +1361,7 @@ def _do_args_sanity_check(self, args): "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." - if hasattr(args, 'local_rank') and args.local_rank != None: + if hasattr(args, 'local_rank') and args.local_rank is not None: assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" if args.local_rank >= 0: @@ -979,6 +1388,12 @@ def _supported_optims(self): # Validate configuration based on command line arguments def _do_sanity_check(self): + if self.fp16_enabled() and not get_accelerator().is_fp16_supported(): + raise ValueError("Type fp16 is not supported on your device.") + + if self.bfloat16_enabled() and not get_accelerator().is_bf16_supported(): + raise ValueError("Type bf16 is not supported on your device.") + expected_optim_types = self._supported_optims() expected_optim_types += [type(None), Callable] assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ @@ -999,30 +1414,31 @@ def _do_sanity_check(self): f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated' def _broadcast_model(self): + if self.dist_backend is None: + return def is_replicated(p): if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: return False + elif hasattr(p, 'ds_optim_param'): + # do not broadcast OptimizedLinear parameters, they are unique per base weight shard + return False return True - for p in self.module.parameters(): + for n, p in self.module.named_parameters(): # Broadcast the model for different parameters if is_moe_param(p): if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, + dist.broadcast(p.data, groups._get_expert_broadcast_src_rank(p.group_name), group=self.expert_data_parallel_group[p.group_name]) else: if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.data_parallel_group) + dist.broadcast(p.data, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) @staticmethod def __check_params(model: Module, dtype: torch.dtype) -> None: return - if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0: - raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is " - f"not {dtype}: " - f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}") def _set_client_model(self, model): # register client model in _modules so that nn.module methods work correctly @@ -1033,21 +1449,23 @@ def _set_client_model(self, model): def _configure_distributed_model(self, model): self._set_client_model(model) + apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None)) + is_zero_init_model = self.zero_optimization_partition_weights() and any( + [hasattr(param, "ds_id") for param in self.module.parameters()]) if self.fp16_enabled(): - if self.zero_optimization_partition_weights() and any( - [hasattr(param, "ds_id") for param in self.module.parameters()]): + if is_zero_init_model: self.__check_params(self.module, torch.half) self.module.half() elif self.bfloat16_enabled(): - if self.zero_optimization_partition_weights() and any( - hasattr(param, 'ds_id') for param in self.module.parameters()): + if is_zero_init_model: self.__check_params(self.module, torch.bfloat16) self.module.bfloat16() else: self.__check_params(self.module, torch.float) - if not self.dont_change_device: + # zero.Init() handles device placement of model + if not (self.dont_change_device or is_zero_init_model): self.module.to(self.device) # MoE related initialization @@ -1074,16 +1492,44 @@ def _configure_distributed_model(self, model): # Set deepspeed parallelism spec. for the model including expert parallelism for _, module in self.module.named_modules(): if hasattr(module, 'set_deepspeed_parallelism'): - module.set_deepspeed_parallelism() + module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_) # Query the groups module to get information about various parallel groups + self.local_all_to_all_group = None + if self.zero_quantized_gradients(): + message = "Using LoCo quantized gradients" if self.zeropp_loco_param() else "Using quantized gradients" + log_dist(message, ranks=[0]) + self.local_all_to_all_group = groups._get_local_all_to_all_group() self.data_parallel_group = groups._get_data_parallel_group() self.dp_world_size = groups._get_data_parallel_world_size() + self.seq_data_parallel_group = groups._get_sequence_data_parallel_group() + self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size() self.mp_world_size = groups._get_model_parallel_world_size() self.expert_parallel_group = groups._get_expert_parallel_group_dict() self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() + if self.sequence_parallel_size > 1: + # Inserted Warning for PyTorch < 2.3 + if not required_torch_version(min_version=2.3): + logger.warning( + "DeepSpeed Sequence Parallelism (Ulysses) with PyTorch < 2.3 may encounter " + "rank indexing errors during backward pass when sp_size < world_size. " + "Please use the weighted all-reduce workaround shown in the regression test " + "(https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/sequence_parallelism/test_ulysses.py) " + "or upgrade to PyTorch 2.3+.") + self.communication_data_type = self._config.seq_parallel_communication_data_type + self.seq_parallel_group = groups._get_sequence_parallel_group() - if not self.amp_enabled(): + if dist.get_rank() == 0: + summary = "********** distributed groups summary **********\n" + summary += f"\t {self.dp_world_size=}\n" + summary += f"\t {self.mp_world_size=}\n" + summary += f"\t {self.seq_dp_world_size=}\n" + summary += f"\t {self.sequence_parallel_size=}\n" + summary += "***********************************************" + logger.info(summary) + + if not (self.amp_enabled() or is_zero_init_model): self._broadcast_model() # check if parameters are duplicated in optimizer param_groups @@ -1098,7 +1544,7 @@ def ids_list(group): ids_list(group['params']).count(param_id) if param_id in ids_list(group['params']) else 0 for group in optimizer.param_groups ]) - assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour." + assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior." def _do_optimizer_sanity_check(self, basic_optimizer): model_dtype, grad_accum_dtype = self.get_data_types() @@ -1116,14 +1562,9 @@ def _do_optimizer_sanity_check(self, basic_optimizer): if self.global_rank == 0: logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****") - if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage( - ) == 1: + ) == 1 and not self.zero_cpu_offload(): return BFLOAT16 - - if model_dtype != grad_accum_dtype: - raise NotImplementedError( - "Model data type and gradient accumulation data type must be equal to use ZeRO") return ZERO_OPTIMIZATION elif amp_enabled: if model_dtype != grad_accum_dtype: @@ -1139,47 +1580,47 @@ def _do_optimizer_sanity_check(self, basic_optimizer): return AMP # data type checks elif model_dtype == grad_accum_dtype: - if model_dtype == torch.bfloat16: - raise NotImplementedError( - "Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation" + if model_dtype == torch.float32: + return None + if model_dtype == torch.bfloat16 and self.pipeline_parallelism: + logger.warning( + "**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****" ) - if model_dtype == torch.float16: - return FP16 - # else optimizer_wrapper = None - elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32: - return BFLOAT16 + return BFLOAT16 + return FP16 if model_dtype == torch.float16 else DDP_BFLOAT16 else: - raise NotImplementedError("unsupported mix of model dtype and gradient accummulation type") + raise NotImplementedError(f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}") return None # Configure optimizer def _configure_optimizer(self, client_optimizer, model_parameters): - if client_optimizer is not None: + if client_optimizer is None: + if self.has_moe_layers: + model_parameters = configure_moe_param_groups(model_parameters) + basic_optimizer = self._configure_basic_optimizer(model_parameters) + log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0]) + else: if isinstance(client_optimizer, tuple(self._supported_optims())): - client_optimizer.param_groups[:] = [ - pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0 - ] - log_dist("Removing param_group that has no 'params' in the client Optimizer", ranks=[0]) - basic_optimizer = client_optimizer log_dist('Using client Optimizer as basic optimizer', ranks=[0]) else: basic_optimizer = client_optimizer(model_parameters) log_dist('Using client callable to create basic optimizer', ranks=[0]) - if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam): + if (self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam) + and not isinstance(basic_optimizer, deepspeed.ops.lion.DeepSpeedCPULion)): if self.zero_force_ds_cpu_optimizer(): msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.' raise ZeRORuntimeException(msg) - else: - basic_optimizer = self._configure_basic_optimizer(model_parameters) - log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0]) + + basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0] + log_dist("Removing param_group that has no 'params' in the basic Optimizer", ranks=[0]) self._check_for_duplicates(basic_optimizer) self.basic_optimizer = basic_optimizer - log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0]) + log_dist(f"DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", ranks=[0]) optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer) @@ -1192,29 +1633,30 @@ def _configure_optimizer(self, client_optimizer, model_parameters): self._set_client_model(model) self._broadcast_model() # TODO: maybe need to broadcast experts differently? - elif optimizer_wrapper == FP16: - self.optimizer = self._configure_fp16_optimizer(basic_optimizer) + elif optimizer_wrapper in [FP16, DDP_BFLOAT16]: + lp_dtype = torch.float16 if optimizer_wrapper == FP16 else torch.bfloat16 + self.optimizer = self._configure_fp16_optimizer(basic_optimizer, lp_dtype) elif optimizer_wrapper == BFLOAT16: self.optimizer = self._configure_bf16_optimizer(basic_optimizer) else: self.optimizer = basic_optimizer - log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0]) + log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), ranks=[0]) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() def _configure_basic_optimizer(self, model_parameters): - optimizer_parameters = self.optimizer_params() - if optimizer_parameters is None: - optimizer_parameters = {} + # Copy so the pop() calls below (torch_adam, adam_w_mode, fp32_optimizer_states) do not + # mutate the shared config dict returned by optimizer_params(). + optimizer_parameters = dict(self.optimizer_params() or {}) # print(optimizer_parameters.keys()) if "max_grad_norm" in optimizer_parameters.keys(): raise ValueError( "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) - if self.optimizer_name() in [ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: + if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) @@ -1228,14 +1670,29 @@ def _configure_basic_optimizer(self, model_parameters): optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) else: if self.zero_use_cpu_optimizer(): - if self.optimizer_name() == ADAGRAD_OPTIMIZER: - from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad - optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters) + from deepspeed.ops.adam import DeepSpeedCPUAdam, ZenFlowCPUAdam + CPUAdam = ZenFlowCPUAdam if self.zenflow else DeepSpeedCPUAdam + + zenflow_kwargs = {'overlap_step': self.overlap_step} if self.zenflow else {} + # Pop so a user-supplied value does not collide with the keyword built below. + # None means the user did not set it, so no override warning is needed. + user_fp32_optimizer_states = optimizer_parameters.pop('fp32_optimizer_states', None) + if self.bf16_optimizer_states(): + # bf16 moments are required so the offloaded state matches the bf16 master weights. + if user_fp32_optimizer_states: + logger.warning("bf16_optimizer_states is enabled; overriding fp32_optimizer_states " + "to False so CPU Adam moments are stored in bf16.") + fp32_optimizer_states = False + elif user_fp32_optimizer_states is None: + # Default preserves the pre-existing fp32 optimizer-state behavior. + fp32_optimizer_states = True else: - from deepspeed.ops.adam import DeepSpeedCPUAdam - optimizer = DeepSpeedCPUAdam(model_parameters, - **optimizer_parameters, - adamw_mode=effective_adam_w_mode) + fp32_optimizer_states = user_fp32_optimizer_states + optimizer = CPUAdam(model_parameters, + **optimizer_parameters, + adamw_mode=effective_adam_w_mode, + fp32_optimizer_states=fp32_optimizer_states, + **zenflow_kwargs) else: from deepspeed.ops.adam import FusedAdam @@ -1245,6 +1702,12 @@ def _configure_basic_optimizer(self, model_parameters): adam_w_mode=effective_adam_w_mode, ) + elif self.optimizer_name() == ADAGRAD_OPTIMIZER: + if self.zero_use_cpu_optimizer(): + from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad + optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters) + else: + optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: from deepspeed.ops.lamb import FusedLamb @@ -1255,21 +1718,74 @@ def _configure_basic_optimizer(self, model_parameters): optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): - logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16") + logger.warning("Currently the convergence of 1-bit Adam is only verified under FP16") elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER: assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): - logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16') + logger.warning('Currently the convergence of 0/1 Adam is only verified under FP16') elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER: assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): - logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16") + logger.warning("Currently the convergence of 1-bit Lamb is only verified under FP16") + elif self.optimizer_name() == LION_OPTIMIZER: + if self.zero_use_cpu_optimizer(): + from deepspeed.ops.lion import DeepSpeedCPULion + optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters) + else: + from deepspeed.ops.lion import FusedLion + optimizer = FusedLion(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == MUADAM_OPTIMIZER: + try: + from mup import MuAdam + except ImportError: + logger.error("Install mup to use MuAdam optimizer") + optimizer = MuAdam(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == MUADAMW_OPTIMIZER: + try: + from mup import MuAdamW + except ImportError: + logger.error("Install mup to use MuAdamW optimizer") + optimizer = MuAdamW(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == MUSGD_OPTIMIZER: + try: + from mup import MuSGD + except ImportError: + logger.error("Install mup to use MuSGD optimizer") + optimizer = MuSGD(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == MUON_OPTIMIZER: + zero_stage = self.zero_optimization_stage() + if not all([hasattr(p, 'use_muon') for p in model_parameters]): + msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \ + "please set by `param.use_muon = True / False` for all params" + logger.error(msg) + muon_params = [p for p in model_parameters if p.use_muon and p.requires_grad] + non_muon_params = [p for p in model_parameters if (not p.use_muon) and p.requires_grad] + param_groups = [] + if muon_params: + accepted_parameters = dict() + for key in ["lr", "momentum", "weight_decay", "muon_lr", "ns_method"]: + if key in optimizer_parameters: + if key == "muon_lr": # muon_lr will override lr + accepted_parameters['lr'] = optimizer_parameters[key] + else: + accepted_parameters[key] = optimizer_parameters[key] + param_groups.append(dict(params=muon_params, use_muon=True, **accepted_parameters)) + if non_muon_params: + accepted_parameters = dict() + for key in ["lr", "betas", "eps", "weight_decay", "adam_lr"]: + if key in optimizer_parameters: + if key == "adam_lr": # adam_lr will override lr + accepted_parameters['lr'] = optimizer_parameters[key] + else: + accepted_parameters[key] = optimizer_parameters[key] + param_groups.append(dict(params=non_muon_params, use_muon=False, **accepted_parameters)) + optimizer = MuonWithAuxAdam(param_groups) else: torch_optimizer = getattr(torch.optim, self.optimizer_name()) optimizer = torch_optimizer(model_parameters, **optimizer_parameters) @@ -1313,50 +1829,56 @@ def _configure_quantization(self): ) return quantizer - def _configure_fp16_optimizer(self, optimizer): - initial_dynamic_scale = self.initial_dynamic_scale() + def _configure_fp16_optimizer(self, optimizer, low_precision_dtype): dynamic_loss_args = self.dynamic_loss_scale_args() clip_grad = self.gradient_clipping() + if APEX_INSTALLED: fused_opts = (apex.optimizers.FusedAdam, FusedAdam) else: fused_opts = FusedAdam - if isinstance(optimizer, fused_opts) \ - or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]: - if self.dynamic_loss_scale(): - log_dist(f'Creating fp16 optimizer with dynamic loss scale', ranks=[0]) - timers = self.timers if self.wall_clock_breakdown() else None - optimizer = FP16_Optimizer( - optimizer, - deepspeed=self, - dynamic_loss_scale=True, - initial_dynamic_scale=initial_dynamic_scale, - dynamic_loss_args=dynamic_loss_args, - mpu=self.mpu, - clip_grad=clip_grad, - fused_adam_legacy=self.optimizer_legacy_fusion(), - timers=timers, - has_moe_layers=self.has_moe_layers, - ) + + use_fused_optimizer = isinstance(optimizer, fused_opts) \ + or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER] + loss_scale_profile = LossScaleProfile.FUSED if use_fused_optimizer else LossScaleProfile.UNFUSED + initial_dynamic_scale = self.initial_dynamic_scale() if loss_scale_profile == LossScaleProfile.FUSED else None + loss_scale_config = LossScaleConfig( + low_precision_dtype=low_precision_dtype, + dynamic_loss_scale=self.dynamic_loss_scale(), + static_loss_scale=self.loss_scale(), + dynamic_loss_args=dynamic_loss_args, + profile=loss_scale_profile, + initial_dynamic_scale=initial_dynamic_scale, + ) + + if use_fused_optimizer: + if loss_scale_config.dynamic_loss_scale: + log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0]) else: - log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0]) - optimizer = FP16_Optimizer( - optimizer, - deepspeed=self, - static_loss_scale=self.loss_scale(), - mpu=self.mpu, - clip_grad=clip_grad, - fused_adam_legacy=self.optimizer_legacy_fusion(), - has_moe_layers=self.has_moe_layers, - ) + log_dist(f'Creating fp16 optimizer with static loss scale: {loss_scale_config.cur_scale}', ranks=[0]) + timers = self.timers if self.wall_clock_breakdown() else NoopTimer() + optimizer = FP16_Optimizer( + optimizer, + deepspeed=self, + loss_scale_config=loss_scale_config, + low_precision_dtype=low_precision_dtype, + mpu=self.mpu, + clip_grad=clip_grad, + fused_adam_legacy=self.optimizer_legacy_fusion(), + timers=timers, + has_moe_layers=self.has_moe_layers, + ) else: - log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0]) + if loss_scale_config.dynamic_loss_scale: + log_dist('Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0]) + else: + log_dist(f'Creating fp16 unfused optimizer with static loss scale: {loss_scale_config.cur_scale}', + ranks=[0]) optimizer = FP16_UnfusedOptimizer( optimizer, deepspeed=self, - static_loss_scale=self.loss_scale(), - dynamic_loss_scale=self.dynamic_loss_scale(), - dynamic_loss_args=dynamic_loss_args, + loss_scale_config=loss_scale_config, + low_precision_dtype=low_precision_dtype, mpu=self.mpu, clip_grad=clip_grad, fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER, @@ -1372,21 +1894,35 @@ def _configure_bf16_optimizer(self, optimizer): log_dist('Creating BF16 optimizer', ranks=[0]) - timers = self.timers if self.wall_clock_breakdown() else None + timers = self.timers if self.wall_clock_breakdown() else NoopTimer() optimizer = BF16_Optimizer(optimizer, self.param_names, + bfloat16_config=self._config.bfloat16_config, mpu=self.mpu, clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), - dp_process_group=self.data_parallel_group, - timers=timers) + dp_process_group=self.seq_data_parallel_group, + timers=timers, + grad_acc_dtype=self.get_data_types()[1], + graph_harvesting=self.graph_harvesting(), + has_moe_layers=self.has_moe_layers) return optimizer def _configure_zero_optimizer(self, optimizer): zero_stage = self.zero_optimization_stage() - model_dtype, grad_accum_dtype = self.get_data_types() - timers = self.timers if self.wall_clock_breakdown() else None + + mics_shard_size = self.mics_shard_size() + model_dtype, gradient_accumulation_dtype = self.get_data_types() + + if self.bfloat16_enabled(): + check_grad_overflow = self._config.bfloat16_config.check_grad_overflow + elif self.fp16_enabled(): + check_grad_overflow = True + else: + check_grad_overflow = False + + timers = self.timers if self.wall_clock_breakdown() else NoopTimer() if optimizer is None: optimizer = DummyOptim(list(self.module.parameters())) @@ -1403,35 +1939,34 @@ def _configure_zero_optimizer(self, optimizer): assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage) log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) - # Overlap and contiguous grads are meaningless in stage 1 and are ignored - if zero_stage == ZeroStageEnum.optimizer_states: - overlap_comm = False - round_robin_gradients = False - # Non-MoE requires contiguous grads to be disabled w. stage 1 - if not self.has_moe_layers: - contiguous_gradients = False if isinstance(self.module, PipelineModule): if overlap_comm: logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.") overlap_comm = False - optimizer = DeepSpeedZeroOptimizer( + Stage1And2ZeroOptimizer = DeepSpeedZeroOptimizer if not self.zenflow else ZenFlowZeroOptimizer.create( + zenflow_config=self.zenflow_config()) + + optimizer = Stage1And2ZeroOptimizer( optimizer, self.param_names, timers=timers, + optimizer_params=self.optimizer_params(), static_loss_scale=self.loss_scale(), dynamic_loss_scale=self.dynamic_loss_scale(), dynamic_loss_args=self.dynamic_loss_scale_args(), clip_grad=self.gradient_clipping(), contiguous_gradients=contiguous_gradients, reduce_bucket_size=self.zero_reduce_bucket_size(), + use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(), allgather_bucket_size=self.zero_allgather_bucket_size(), - dp_process_group=self.data_parallel_group, + dp_process_group=self.seq_data_parallel_group, expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None, expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None, reduce_scatter=self.zero_reduce_scatter(), overlap_comm=overlap_comm, - cpu_offload=self.zero_cpu_offload(), + offload_optimizer_config=self.zero_offload_optimizer(), + zenflow_config=self.zenflow_config(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), @@ -1441,30 +1976,61 @@ def _configure_zero_optimizer(self, optimizer): round_robin_gradients=round_robin_gradients, has_moe_layers=self.has_moe_layers, fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), + bf16_master_weights_and_gradients=self.bf16_master_weights_and_gradients(), + bf16_optimizer_states=self.bf16_optimizer_states(), + gradient_accumulation_dtype=gradient_accumulation_dtype, communication_data_type=self.communication_data_type, - elastic_checkpoint=self.zero_elastic_checkpoint()) + elastic_checkpoint=self.zero_elastic_checkpoint(), + check_grad_overflow=check_grad_overflow) elif zero_stage == ZeroStageEnum.weights: assert not self.has_moe_layers, "MoE not supported with Stage 3" if isinstance(optimizer, DummyOptim): log_dist("Creating ZeRO Offload", ranks=[0]) - optimizer = DeepSpeedZeRoOffload(self.module, - timers=timers, - ds_config=self.config, - overlap_comm=self.zero_overlap_comm(), - prefetch_bucket_size=self.zero_prefetch_bucket_size(), - max_reuse_distance=self.zero_max_reuse_distance(), - max_live_parameters=self.zero_max_live_parameters(), - param_persistence_threshold=self.zero_param_persistence_threshold(), - model_persistence_threshold=self.zero_model_persistence_threshold(), - offload_param_config=self.zero_offload_param(), - mpu=self.mpu) + zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None: + self._set_zero_group_parallelism() + zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + optimizer = DeepSpeedZeRoOffload( + self.module, + timers=timers, + ds_config=self.config, + overlap_comm=self.zero_overlap_comm(), + prefetch_bucket_size=self.zero_prefetch_bucket_size(), + max_reuse_distance=self.zero_max_reuse_distance(), + max_live_parameters=self.zero_max_live_parameters(), + param_persistence_threshold=self.zero_param_persistence_threshold(), + model_persistence_threshold=self.zero_model_persistence_threshold(), + offload_param_config=self.zero_offload_param(), + mpu=self.mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=self.zero_quantized_weights(), + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), + zero_module_granularity_threshold=self.zero_module_granularity_threshold(), + log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), + ) else: + log_dist( + f'Creating fp16 ZeRO stage {zero_stage} optimizer,' + f' MiCS is enabled {mics_shard_size>0},' + f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}', + ranks=[0]) + if mics_shard_size > 0: + return self._return_mics_optimizer(optimizer, timers) + + if self.zero_allgather_sequential(): + log_dist(f"If zero_allgather_sequential is True, set prefetch_bucket_size to 1", ranks=[0]) + self._config.zero_config.prefetch_bucket_size = 1 + log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 - optimizer = DeepSpeedZeroOptimizer_Stage3( + from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3 + Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload( + ) else SuperOffloadOptimizer_Stage3 + optimizer = Stage3ZeroOptimizer( self.module, optimizer, + self.param_names, timers=timers, ds_config=self.config, static_loss_scale=self.loss_scale(), @@ -1478,24 +2044,78 @@ def _configure_zero_optimizer(self, optimizer): max_live_parameters=self.zero_max_live_parameters(), param_persistence_threshold=self.zero_param_persistence_threshold(), model_persistence_threshold=self.zero_model_persistence_threshold(), - dp_process_group=self.data_parallel_group, + dp_process_group=self.seq_data_parallel_group, + all2all_process_group=self.local_all_to_all_group, reduce_scatter=self.zero_reduce_scatter(), overlap_comm=self.zero_overlap_comm(), offload_optimizer_config=self.zero_offload_optimizer(), offload_param_config=self.zero_offload_param(), + zenflow_config=self.zenflow_config(), sub_group_size=self.zero_sub_group_size(), + offload_ratio=self.zero_partial_offload(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), gradient_accumulation_steps=self.gradient_accumulation_steps(), aio_config=self.aio_config(), - communication_data_type=self.communication_data_type) + gradient_accumulation_dtype=gradient_accumulation_dtype, + communication_data_type=self.communication_data_type, + fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), + bf16_master_weights_and_gradients=self.bf16_master_weights_and_gradients(), + bf16_optimizer_states=self.bf16_optimizer_states(), + zero_hpz_partition_size=self.zero_hpz_partition_size(), + zero_quantized_weights=self.zero_quantized_weights(), + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), + zero_module_granularity_threshold=self.zero_module_granularity_threshold(), + zeropp_loco_param=self.zeropp_loco_param(), + log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), + enable_sanity_checks=self.is_sanity_checks_enabled(), + cpuadam_cores_perc=self.cpuadam_cores_perc(), + save_muon_momentum_buffer_in_memory=self.zero_save_muon_momentum_buffer_in_memory(), + ) else: raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) return optimizer + def _return_mics_optimizer(self, basic_optimizer, timers): + from deepspeed.runtime.zero.mics import MiCS_Optimizer + model_dtype, gradient_accumulation_dtype = self.get_data_types() + optimizer = MiCS_Optimizer(self.module, + basic_optimizer, + self.param_names, + timers=timers, + ds_config=self.config, + static_loss_scale=self.loss_scale(), + dynamic_loss_scale=self.dynamic_loss_scale(), + dynamic_loss_args=self.dynamic_loss_scale_args(), + clip_grad=self.gradient_clipping(), + contiguous_gradients=self.zero_contiguous_gradients(), + reduce_bucket_size=self.zero_reduce_bucket_size(), + prefetch_bucket_size=self.zero_prefetch_bucket_size(), + max_reuse_distance=self.zero_max_reuse_distance(), + max_live_parameters=self.zero_max_live_parameters(), + param_persistence_threshold=self.zero_param_persistence_threshold(), + model_persistence_threshold=self.zero_model_persistence_threshold(), + dp_process_group=self.seq_data_parallel_group, + reduce_scatter=self.zero_reduce_scatter(), + overlap_comm=self.zero_overlap_comm(), + offload_optimizer_config=self.zero_offload_optimizer(), + offload_param_config=self.zero_offload_param(), + sub_group_size=self.zero_sub_group_size(), + mpu=self.mpu, + postscale_gradients=self.postscale_gradients(), + gradient_predivide_factor=self.gradient_predivide_factor(), + gradient_accumulation_steps=self.gradient_accumulation_steps(), + aio_config=self.aio_config(), + gradient_accumulation_dtype=gradient_accumulation_dtype, + communication_data_type=self.communication_data_type, + fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), + bf16_master_weights_and_gradients=self.bf16_master_weights_and_gradients(), + bf16_optimizer_states=self.bf16_optimizer_states()) + return optimizer + def _configure_eigenvalue(self): eigenvalue = Eigenvalue( verbose=self.eigenvalue_verbose(), @@ -1586,7 +2206,6 @@ def deepspeed_io(self, GLOBAL_RANK: self.global_rank, DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS] } - return DeepSpeedDataLoader(dataset=dataset, batch_size=batch_size, pin_memory=pin_memory, @@ -1612,14 +2231,17 @@ def eval(self): self.warn_unscaled_loss = True self.module.train(False) - def _scale_loss_by_gas(self, prescaled_loss): + def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None): + # In pipeline evaluation, there is an option to use different micro-bs, which creates different number of + # micro batches, thus the training gas, is not valid in this case. need to use the number of eval_micro_batches + scaling_factor = self.gradient_accumulation_steps() if eval_micro_batches is None else eval_micro_batches if isinstance(prescaled_loss, torch.Tensor): - scaled_loss = prescaled_loss / self.gradient_accumulation_steps() + scaled_loss = prescaled_loss / scaling_factor elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list): scaled_loss = [] for l in prescaled_loss: if isinstance(l, torch.Tensor): - scaled_loss.append(l / self.gradient_accumulation_steps()) + scaled_loss.append(l / scaling_factor) else: scaled_loss.append(l) else: @@ -1630,17 +2252,24 @@ def _scale_loss_by_gas(self, prescaled_loss): return scaled_loss - @instrument_w_nvtx - def forward(self, *inputs, **kwargs): - r"""Execute forward propagation - Arguments: - *inputs: Variable length input list - **kwargs: variable length keyword arguments - """ + def _create_module_forward_pre_hook(self): - if self.autotuning_profile_model_info(): - ma = get_ma_status() - else: + def _module_forward_pre_hook(module, inputs, kwargs): + return self._forward_prologue(inputs, kwargs) + + return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True) + + def _create_module_forward_post_hook(self): + + def _module_forward_post_hook(module, input, output): + self._forward_epilogue() + + return self.module.register_forward_hook(_module_forward_post_hook) + + def _forward_prologue(self, inputs, kwargs): + return_modified = False + + if not self.autotuning_profile_model_info(): see_memory_usage("Engine before forward", force=self.memory_breakdown()) flops_profiler_active = (self.flops_profiler_enabled() @@ -1659,58 +2288,105 @@ def forward(self, *inputs, **kwargs): self.eigenvalue_enabled(), None, ) + return_modified = True if flops_profiler_active: self.flops_profiler.start_profile(ignore_list=None) - if self.module.training: - if self.progressive_layer_drop: - kwargs.update(self.progressive_layer_drop.get_state()) + if kwargs is not None: + if self.module.training: + if self.progressive_layer_drop: + kwargs.update(self.progressive_layer_drop.get_state()) - if self.__class__.__name__ != "PipelineEngine": - # TODO: The above if condition is a HACK since for PipelineEngine - # it's difficult to inject argument in forward pass. - if self.module.training and self.curriculum_enabled_legacy(): - self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) - if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": - kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) + if self.__class__.__name__ != "PipelineEngine": + # TODO: The above if condition is a HACK since for PipelineEngine + # it's difficult to inject argument in forward pass. + if self.module.training and self.curriculum_enabled_legacy(): + self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) + if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": + kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) + return_modified = True if self.module.training and self.random_ltd_enabled(): self.random_ltd_scheduler.update_seq(self.global_steps) + if self.training_dataloader is None: + self.tput_timer.start() + + self._start_timers(self.engine_timers.forward_timers) + if self.zero_optimization_partition_weights(): # Enable automated discovery of external parameters by indicating that # we are in a forward pass. for module in self.module.modules(): + ensure_zero_ordered_dict(module) module._parameters._in_forward = True - self._start_timers(self.engine_timers.forward_timers) - - if self.training_dataloader is None: - self.tput_timer.start() - if self.fp16_auto_cast(): inputs = self._cast_inputs_half(inputs) + return_modified = True - loss = self.module(*inputs, **kwargs) + if return_modified: + return inputs, kwargs + def _forward_epilogue(self): if self.zero_optimization_partition_weights(): # Disable automated discovery of external parameters for module in self.module.modules(): - module._parameters._in_forward = False + if isinstance(module._parameters, ZeROOrderedDict): + module._parameters._in_forward = False self._stop_timers(self.engine_timers.forward_timers) + flops_profiler_active = (self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) + if flops_profiler_active: self.flops_profiler.stop_profile() + if not self.autotuning_profile_model_info(): + see_memory_usage("Engine after forward", force=self.memory_breakdown()) + + @instrument_w_nvtx + def forward(self, *inputs, **kwargs): + r"""Execute forward propagation + Arguments: + *inputs: Variable length input list + **kwargs: variable length keyword arguments + """ + # Clear the backward seen flag at the start of each forward pass. + # This is used to track multiple gradient hook phases with reentrant checkpointing. + if isinstance(self.optimizer, ZeROOptimizer): + self.optimizer.clear_backward_seen_flag() + + if self.autotuning_profile_model_info(): + ma = get_ma_status() + + if self.is_deepcompile_enabled() and not self.is_deepcompile_active() and not self.is_compiled: + log_dist_once( + "DeepCompile is enabled but engine.compile() has not been called; executing without DeepCompile until compile() runs.", + ranks=[0]) + + if self.is_deepcompile_active() and hasattr(self, "launch_compile_passes"): + # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue. + self.launch_compile_passes(self.global_steps) + + with autocast_if_enabled(self): + loss = self.module(*inputs, **kwargs) + + # Register output backward hooks + # preprocess_once_fn is called for preprocessing + # preprocess_per_tensor_fn scales a tensor for gradient accumulation + register_output_backward_hooks(loss, + preprocess_once_fn=self._backward_prologue, + preprocess_per_tensor_fn=self._backward_prologue_per_tensor) + if self.autotuning_profile_model_info(): activation_mem = get_ma_status() - ma self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path()) exit() - else: - see_memory_usage("Engine after forward", force=self.memory_breakdown()) + return loss def _cast_inputs_half(self, inputs): @@ -1724,7 +2400,7 @@ def _cast_inputs_half(self, inputs): for k, v in inputs.items(): new_inputs[k] = self._cast_inputs_half(v) return new_inputs - elif hasattr(inputs, 'half'): + elif hasattr(inputs, 'half') and inputs.is_floating_point(): return inputs.half() else: return inputs @@ -1749,13 +2425,15 @@ def print_forward_breakdown(self, fwd_time): # if deepspeed.comm.get_rank() == 0: log_dist( - f"rank={dist.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})", + f"time (ms) | fwd: {fwd_time:.2f} (fwd_moe: {moe_time:.2f}, 1st_a2a: {falltoall:.2f}, 2nd_a2a: {salltoall:.2f}, top_k: {gate_time:.2f})", ranks=[0]) @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): - assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \ - f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled' + # Skip gradient reduction when DeepCompile is enabled + # DeepCompile handles its own gradient reduction through compiled graph operations + if self.is_deepcompile_active() and not self.compile_autosp(): + return # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() @@ -1769,89 +2447,220 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): self.optimizer, 'reduce_gradients'): self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) else: - self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) - - @instrument_w_nvtx - def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True): - r"""Execute backward pass on the loss - Arguments: - loss: Torch tensor on which to execute backward propagation - allreduce_gradients: is deprecated, ignored, and will soon be removed' - retain_graph: bool, default: false - forward on user defined choice of retain_graph - """ - - see_memory_usage("Engine before backward", force=self.memory_breakdown()) + grads = None + self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) + elif self.zenflow: + self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) - if self.scale_wrt_gas is not None: - scale_wrt_gas = self.scale_wrt_gas + def _backward_prologue(self): + self._start_timers(self.engine_timers.backward_timers) - if not allreduce_gradients: - logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed") + # When necessary internal APIs are not available, we disable direct calls to tensor.backward() + # and limit to engine.backward(loss) only. + if not self._support_torch_style_backward and not self._running_engine_backward: + raise RuntimeError("Direct calls to tensor.backward() are not supported in this configuration. " + "This occurs when either: (1) your PyTorch version lacks required internal APIs, " + "or (2) using ZeRO stage 0. " + "Please use engine.backward(loss) instead.") - # scale loss w.r.t. gradient accumulation if needed - if self.gradient_accumulation_steps() > 1 and scale_wrt_gas: - loss = self._scale_loss_by_gas(loss.float()) + see_memory_usage("Engine before backward", force=self.memory_breakdown()) - # Log training Loss - if self.monitor.enabled: - if self.is_gradient_accumulation_boundary(): - if self.global_rank == 0: - self.summary_events = [( - f"Train/Samples/train_loss", - loss.mean().item() * self.gradient_accumulation_steps(), - self.global_samples, - )] - self.monitor.write_events(self.summary_events) + assert not self.eigenvalue_enabled(), "Eigenvalue is not supported with non-scalar backward" + assert not self.amp_enabled(), "Apex AMP is not supported with non-scalar backward" - self._start_timers(self.engine_timers.backward_timers) + if self.is_deepcompile_active(): + deepcompile_backward_prologue(self.is_gradient_accumulation_boundary()) - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use backward" + if isinstance(self.optimizer, ZeROOptimizer): + self.optimizer.backward_prologue() + self.optimizer.enter_backward() + self.optimizer.queue_post_backward_callback() - self._start_timers(self.engine_timers.backward_inner_timers) + if self.zenflow and self.auto_update: + self.optimizer.zenflow_state ^= 1 if self.zero_optimization(): self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() - self.optimizer.backward(loss, retain_graph=retain_graph) - elif self.amp_enabled(): - # AMP requires delaying unscale when inside gradient accumulation boundaries - # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations - delay_unscale = not self.is_gradient_accumulation_boundary() - with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: - scaled_loss.backward(retain_graph=retain_graph) - elif self.fp16_enabled(): - if self.eigenvalue_enabled(): - self.optimizer.backward(loss, create_graph=True, retain_graph=True) - else: - self.optimizer.backward(loss, retain_graph=retain_graph) - elif self.bfloat16_enabled(): - self.optimizer.backward(loss) - else: - if self.eigenvalue_enabled(): - loss.backward(create_graph=True, retain_graph=True) - else: - loss.backward(retain_graph=retain_graph) - self._stop_timers(self.engine_timers.backward_inner_timers) + self._start_timers(self.engine_timers.backward_inner_timers) + def _backward_epilogue(self): + self._stop_timers(self.engine_timers.backward_inner_timers) self._start_timers(self.engine_timers.backward_reduce_timers) - - if allreduce_gradients and self.enable_backward_allreduce: + # BF16_Optimizer (without immediate_grad_update) accumulates low + # precision grads into a separate fp32 buffer in backward_epilogue(). + # Run it before allreduce so the boundary microbatch is reduced. + bf16_optimizer = isinstance(self.optimizer, BF16_Optimizer) + if bf16_optimizer: + self.optimizer.backward_epilogue() + + if self.enable_backward_allreduce and not self.inside_no_sync_ctxt: # Traditional code path that allreduces the module parameter grads self.allreduce_gradients() - self._stop_timers(self.engine_timers.backward_reduce_timers) + if isinstance(self.optimizer, ZeROOptimizer): + if not bf16_optimizer: + self.optimizer.backward_epilogue() + self.optimizer.exit_backward() + see_memory_usage("Engine after backward", force=self.memory_breakdown()) + self._stop_timers(self.engine_timers.backward_reduce_timers) self._stop_timers(self.engine_timers.backward_timers) - if release_loss: - # loss.data = None - pass + def _backward_prologue_per_tensor(self, grad): + # Only scale gradients if scale_wrt_gas is True, consistent with backward() parameter + if grad is not None and self._scale_wrt_gas: + return grad / self.gradient_accumulation_steps() + return grad + + def _backward_post_hook(self): + if not self._running_engine_backward: + # Check if loss scaling was required but not applied + needs_scaler = False + if isinstance(self.optimizer, ZeROOptimizer): + needs_scaler = self.optimizer.needs_scaler() + elif self.torch_autocast_z0_gradscaler is not None: + needs_scaler = True + elif self.amp_enabled(): + needs_scaler = True + + if needs_scaler and not self._manual_backward_expected: + # User called backward() directly without using engine.scale() or engine.backward() + error_msg = ("Loss scaling is required for this configuration, but backward() was called " + "directly without scaling the loss. Please use one of the following:" + " 1. engine.backward(loss)" + " 2. engine.scale(loss).backward()") + if self.amp_enabled(): + error_msg += " Note: AMP (NVIDIA Apex) only supports engine.backward(loss)." + raise RuntimeError(error_msg) + + # Clear the flag for next backward + self._manual_backward_expected = False + + self._backward_epilogue() + + @contextmanager + def no_sync(self): + r""" + Context manager to disable gradient reduction during backward pass. + This context manager has the following effects on other DeepSpeed features: + 1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. + 2. It is illegal to call engine.step() within the context manager. + 3. Tracking of gradient accumulation steps is disabled. + """ + assert not self.zero_optimization_partition_gradients(), \ + f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" - see_memory_usage("Engine after backward", force=self.memory_breakdown()) + assert not self.inside_no_sync_ctxt, "no_sync context manager reentry is unsupported" - return loss + self.inside_no_sync_ctxt = True + try: + yield + finally: + self.inside_no_sync_ctxt = False + + def scale(self, loss): + r"""Apply loss scaler for manual backward pass. + + Use this method when calling loss.backward() directly instead of engine.backward(). + This applies the appropriate loss scaler for mixed precision training, allowing you + to manually control the backward pass while still benefiting from DeepSpeed's + gradient scaling functionality. + + Example:: + + output = engine(input) + loss = criterion(output, target) + scaled_loss = engine.scale(loss) + scaled_loss.backward() # Manual backward call + engine.step() + + Arguments: + loss: Scalar loss tensor to be scaled + + Returns: + Scaled loss tensor ready for .backward() call + + Raises: + RuntimeError: If AMP (NVIDIA Apex) is enabled. AMP requires using engine.backward() + directly as it uses a context manager that cannot be separated from + the backward call. + AssertionError: If loss is not a scalar tensor with grad_fn, or if no optimizer + is configured. + """ + assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ + "must provide optimizer during init in order to use scale" + assert maybe_loss_for_backward(loss), \ + "loss must be a scalar tensor with grad_fn. For non-scalar tensors, use tensor.backward(grad)" + + # AMP (NVIDIA Apex) uses a context manager that wraps both scaling and backward, + # so it cannot be used with manual backward calls + if self.amp_enabled(): + raise RuntimeError("engine.scale() is not compatible with AMP (NVIDIA Apex). " + "When using AMP, you must call engine.backward(loss) instead of manual backward.") + + # Apply loss scaler based on optimizer type + scaled_loss = loss + if isinstance(self.optimizer, ZeROOptimizer): + scaled_loss = self.optimizer.scale_if_loss(loss) + elif self.torch_autocast_z0_gradscaler: + scaled_loss = self.torch_autocast_z0_gradscaler.scale(loss) + + # Mark that scale() was called for validation in backward hook + self._manual_backward_expected = True + + return scaled_loss + + @instrument_w_nvtx + def backward(self, loss, retain_graph=False, scale_wrt_gas=True): + r"""Execute backward pass on the loss + Arguments: + loss: Torch tensor on which to execute backward propagation + retain_graph: bool, default: false + forward on user defined choice of retain_graph + scale_wrt_gas: bool, default: true + whether to scale gradients and return value by gradient accumulation steps + """ + assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ + "must provide optimizer during init in order to use backward" + assert maybe_loss_for_backward( + loss), "loss must be a scalar tensor. If you need to pass output gradients, backward() of output tensors" + + self._running_engine_backward = True + # Store scale_wrt_gas so the hook can respect it + self._scale_wrt_gas = scale_wrt_gas + + # Set flag to prevent hooks from firing (we'll manually call prologue/epilogue) + backward_kwargs = {"retain_graph": retain_graph} + if self.eigenvalue_enabled(): + backward_kwargs["create_graph"] = True + backward_kwargs["retain_graph"] = True + + # Used only for return value + gas_scaled_loss = loss / self.gradient_accumulation_steps() if scale_wrt_gas else loss + + # TODO: handle these scaling with direct calls to loss.backward() + if isinstance(self.optimizer, ZeROOptimizer): + loss = self.optimizer.scale_if_loss(loss) + elif self.torch_autocast_z0_gradscaler: + loss = self.torch_autocast_z0_gradscaler.scale(loss) + + with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs): + if self.zero_optimization() or not self.amp_enabled(): + loss.backward(**backward_kwargs) + elif self.amp_enabled(): + # AMP requires delaying unscale when inside gradient accumulation boundaries + # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations + delay_unscale = not self.is_gradient_accumulation_boundary() + with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: + scaled_loss.backward(**backward_kwargs) + + # backward_epilogue is not called in a hook when self._support_torch_style_backward is False + self._backward_epilogue() + + self._running_engine_backward = False + + return gas_scaled_loss def is_gradient_accumulation_boundary(self): """ @@ -1864,8 +2673,10 @@ def is_gradient_accumulation_boundary(self): """ if self._is_gradient_accumulation_boundary is None: - return (self.micro_steps + 1) % \ - self.gradient_accumulation_steps() == 0 + if self.zenflow: + return self._is_zenflow_update_boundary() + else: + return (self.micro_steps + 1) % self.gradient_accumulation_steps() == 0 else: return self._is_gradient_accumulation_boundary @@ -1873,7 +2684,7 @@ def set_gradient_accumulation_boundary(self, is_boundary): """ Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional feature and should be used with care. The state should be set before to the intended - value before each forward/backward. The final fordward/backward should have the + value before each forward/backward. The final forward/backward should have the boundary state set to True. This style allows client code to only call engine.step() once after all the gradient accumulation passes are complete. See example below: .. code-block:: python @@ -1905,6 +2716,9 @@ def clip_fp32_gradients(self): def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if self.gradient_clipping() > 0.0: + if self.torch_autocast_z0_gradscaler: + # Unscale for gradient clipping + self.torch_autocast_z0_gradscaler.unscale_(self.optimizer) if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()): self.clip_fp32_gradients() elif self.amp_enabled(): @@ -1912,7 +2726,11 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): # https://nvidia.github.io/apex/advanced.html#gradient-clipping master_params = amp.master_params(self.optimizer) clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu) - self.optimizer.step() + if self.torch_autocast_z0_gradscaler: + self.torch_autocast_z0_gradscaler.step(self.optimizer) + self.torch_autocast_z0_gradscaler.update() + else: + self.optimizer.step() if hasattr(self.optimizer, '_global_grad_norm'): self._global_grad_norm = self.optimizer._global_grad_norm @@ -1929,20 +2747,18 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): block_eigenvalue, ) # zero grad in basic optimizer could be unreliable and may not exhibit - # the behaviour that we want + # the behavior that we want if self.bfloat16_enabled(): # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated - if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"): + if hasattr(self.optimizer, "zero_grad"): self.optimizer.zero_grad() else: - pass + self.zero_grad() elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): self.optimizer.zero_grad() else: self.zero_grad() - report_progress = self.global_rank == 0 if self.global_rank else True - # Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function. overflow = False if hasattr(self.optimizer, "overflow"): @@ -1960,11 +2776,14 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines. # We don't currently have a way to specify lr_kwargs from # pipe_engine.train_batch() - self.lr_scheduler.step(increment=self.train_batch_size()) + self.lr_scheduler.step(self.train_batch_size()) - if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: - self._report_progress(self.global_steps + 1) + if self.steps_per_print() is not None: + report_progress = self.global_rank == 0 if self.global_rank else True + if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: + self._report_progress(self.global_steps + 1) + self.losses = None self.global_steps += 1 self.global_samples += self.train_batch_size() @@ -1972,6 +2791,9 @@ def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. """ + assert not self.inside_no_sync_ctxt, \ + "It is illegal to call Engine.step() inside no_sync context manager" + see_memory_usage("Engine before step", force=self.memory_breakdown()) # Check early because self.global_steps is incremented at some point here. @@ -1988,15 +2810,23 @@ def step(self, lr_kwargs=None): self._step_applied = False # assume False, will flip to True + if self.zenflow: + self.optimizer._sync_selective_optimizer_lr() + if self.auto_update: + self.update_interval += 1 + # Update the model when we reach gradient accumulation boundaries if self.is_gradient_accumulation_boundary(): self.gas_boundary_ctr += 1 + if self.checkpoint_engine.is_decoupled(): + self._commit_decoupled_checkpoint() + if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0) and self.quantizer.any_precision_switch()): - log_dist(f"computing eigenvalue...", ranks=[0]) - self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device, - self.optimizer.cur_scale) + log_dist("computing eigenvalue...", ranks=[0]) + loss_scale = self._get_optimizer_loss_scale() or 1.0 + self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device, loss_scale) if self.progressive_layer_drop: self.progressive_layer_drop.update_state(self.global_steps) @@ -2009,6 +2839,9 @@ def step(self, lr_kwargs=None): report_progress = self.global_rank == 0 if self.global_rank else True + if self.zenflow: + self._zenflow_step(lr_kwargs) + self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress) self._stop_timers(self.engine_timers.step_timers) @@ -2017,12 +2850,13 @@ def step(self, lr_kwargs=None): if self.monitor.enabled: if self.is_gradient_accumulation_boundary(): if self.global_rank == 0: - self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)] + self.summary_events = [("Train/Samples/lr", self.get_lr()[0], self.global_samples)] - if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"): + loss_scale = self._get_optimizer_loss_scale() if self.fp16_enabled() else None + if loss_scale is not None: self.summary_events.append(( - f"Train/Samples/loss_scale", - self.optimizer.cur_scale, + "Train/Samples/loss_scale", + loss_scale, self.global_samples, )) @@ -2041,6 +2875,7 @@ def step(self, lr_kwargs=None): if flops_profiler_active: if self.autotuning_enabled(): self.flops = self.flops_profiler.get_total_flops() * 3 + self.fwd_duration = self.flops_profiler.get_total_duration() else: self.flops_profiler.print_model_profile( profile_step=self.global_steps, @@ -2055,6 +2890,9 @@ def step(self, lr_kwargs=None): self._autotuning_exit() if self.wall_clock_breakdown(): + # Update client accessible wall clock timers cache + self._update_wall_clock_timers() + # Log micro timing and reset self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown()) @@ -2084,6 +2922,17 @@ def _stop_timers(self, timer_names): for name in timer_names: self.timers(name).stop(record=record) + def _update_wall_clock_timers(self): + self.engine_timers_cache = {} + for name in self.engine_timers.active_timers(): + self.engine_timers_cache[name] = self.timers(name).elapsed(reset=False) + + def get_wall_clock_timers(self): + r""" + Return a dict snapshot of the Engine's wall clock timers. + """ + return self.engine_timers_cache + def _autotuning_exit(self): if self.global_rank == 0: msg = self.timers.get_mean([ @@ -2091,7 +2940,11 @@ def _autotuning_exit(self): BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER, ], reset=False) - titer = msg[FORWARD_GLOBAL_TIMER] + msg[BACKWARD_GLOBAL_TIMER] + msg[STEP_GLOBAL_TIMER] + titer = 0.0 + titer += msg[FORWARD_GLOBAL_TIMER] if FORWARD_GLOBAL_TIMER in msg else 0 + titer += msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0 + titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0 + titer *= self.gradient_accumulation_steps() msg["latency"] = titer msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer msg["throughput"] = self.train_batch_size() * 1_000_000 / \ @@ -2108,27 +2961,27 @@ def _write_monitor(self): if self.global_rank == 0: self.summary_events = [ ( - f"Train/Samples/elapsed_time_ms_forward", + "Train/Samples/elapsed_time_ms_forward", self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( - f"Train/Samples/elapsed_time_ms_backward", + "Train/Samples/elapsed_time_ms_backward", self.timers(BACKWARD_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( - f"Train/Samples/elapsed_time_ms_backward_inner", + "Train/Samples/elapsed_time_ms_backward_inner", self.timers(BACKWARD_INNER_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( - f"Train/Samples/elapsed_time_ms_backward_allreduce", + "Train/Samples/elapsed_time_ms_backward_allreduce", self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( - f"Train/Samples/elapsed_time_ms_step", + "Train/Samples/elapsed_time_ms_step", self.timers(STEP_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), @@ -2146,6 +2999,13 @@ def _get_optimizer_param(self, param_name): result.append(0.0) return result + def _get_optimizer_loss_scale(self): + if not self.optimizer: + return None + if hasattr(self.optimizer, "loss_scale_config"): + return self.optimizer.loss_scale_config.cur_scale + return getattr(self.optimizer, "cur_scale", None) + def get_lr(self): return self._get_optimizer_param("lr") @@ -2169,7 +3029,7 @@ def _report_progress(self, step): mom = self.get_mom() log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0]) - def allreduce_bucket(self, bucket, dp_group): + def allreduce_bucket(self, bucket, dp_group, dp_world_size=None): tensor = self.flatten(bucket) tensor_to_allreduce = tensor @@ -2177,16 +3037,18 @@ def allreduce_bucket(self, bucket, dp_group): if self.communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(self.communication_data_type) + if dp_world_size is None: + dp_world_size = dist.get_world_size(group=dp_group) if self.postscale_gradients(): if self.gradient_predivide_factor() != 1.0: tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor()) dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.gradient_average: - if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group): - tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group)) + if self.gradient_predivide_factor() != dp_world_size: + tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dp_world_size) else: - tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group)) + tensor_to_allreduce.mul_(1. / dp_world_size) dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: @@ -2194,23 +3056,23 @@ def allreduce_bucket(self, bucket, dp_group): return tensor - def allreduce_and_copy(self, small_bucket, dp_group): - allreduced = self.allreduce_bucket(small_bucket, dp_group) + def allreduce_and_copy(self, small_bucket, dp_group, dp_world_size=None): + allreduced = self.allreduce_bucket(small_bucket, dp_group, dp_world_size) for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) - def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000): + def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None): small_bucket = [] numel = 0 for tensor in bucket: small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, dp_group) + self.allreduce_and_copy(small_bucket, dp_group, dp_world_size) small_bucket = [] numel = 0 if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, dp_group) + self.allreduce_and_copy(small_bucket, dp_group, dp_world_size) def _get_gradients_for_reduction(self): non_expert_grads = [] @@ -2220,6 +3082,14 @@ def _get_gradients_for_reduction(self): expert_grads[key] = [] for param_name, param in self.module.named_parameters(): + if not param.requires_grad: + continue + + # Skip empty parameters (numel=0) as they contribute nothing to gradient reduction + # and cause issues with flatten/unflatten operations + if param.numel() == 0: + continue + if param.grad is None: # In cases where there is an imbalance of empty grads across # ranks we must create empty grads, this will ensure that every @@ -2241,36 +3111,56 @@ def _get_gradients_for_reduction(self): return non_expert_grads, expert_grads def _reduce_non_expert_gradients(self, grads, elements_per_buffer): - split_buckets = split_half_float_double_sparse(grads) - for _, bucket_tuple in enumerate(split_buckets): - bucket_type, bucket = bucket_tuple - - if self.pipeline_parallelism: - dp_group = self.mpu.get_data_parallel_group() - else: - dp_group = groups._get_data_parallel_group() - - if bucket_type == SparseTensor.type(): - self.sparse_allreduce_no_retain(bucket, dp_group=dp_group) - else: - self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer) + split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads) + if self.pipeline_parallelism: + dp_group = self.mpu.get_data_parallel_group() + dp_world_size = dist.get_world_size(dp_group) + else: + dp_group = groups._get_sequence_data_parallel_group() + dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size) + for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): + if sparse_bucket_tuple: + bucket_type, sparse_bucket = sparse_bucket_tuple + self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size) + + for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): + if dense_bucket_tuple: + bucket_type, dense_bucket = dense_bucket_tuple + self.allreduce_no_retain(dense_bucket, + dp_group=dp_group, + numel_per_bucket=elements_per_buffer, + dp_world_size=dp_world_size) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): + # to maintain the gradients value unaffected by ep_size setting, + # utilize dp_world_size for allreduce average + dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) for ep_name, expert_grads_group in expert_grads.items(): - expert_split_buckets = split_half_float_double_sparse(expert_grads_group) - for i, bucket_tuple in enumerate(expert_split_buckets): - bucket_type, bucket = bucket_tuple - if bucket_type == SparseTensor.type(): - self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name)) - else: + ep_dp_group = groups._get_expert_data_parallel_group(ep_name) + split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( + expert_grads_group) + + for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): + if sparse_bucket_tuple: + bucket_type, sparse_bucket = sparse_bucket_tuple + self.sparse_allreduce_no_retain(sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size) + + for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): + if dense_bucket_tuple: + bucket_type, dense_bucket = dense_bucket_tuple # Separate between diff groups - self.allreduce_no_retain(bucket, - dp_group=groups._get_expert_data_parallel_group(ep_name), - numel_per_bucket=elements_per_buffer) + self.allreduce_no_retain(dense_bucket, + dp_group=ep_dp_group, + numel_per_bucket=elements_per_buffer, + dp_world_size=dp_world_size) def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): if grads is None: - non_expert_grads, expert_grads = self._get_gradients_for_reduction() + if hasattr(self.optimizer, "get_grads_for_reduction"): + # This is currently for BF16 optimizer + non_expert_grads, expert_grads = self.optimizer.get_grads_for_reduction() + else: + non_expert_grads, expert_grads = self._get_gradients_for_reduction() else: assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE" non_expert_grads = grads @@ -2280,8 +3170,8 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) if self.has_moe_layers: self._reduce_expert_gradients(expert_grads, elements_per_buffer) - def sparse_allreduce_no_retain(self, bucket, dp_group): - allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group) + def sparse_allreduce_no_retain(self, bucket, dp_group, dp_world_size=None): + allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group, dp_world_size) # Densify sparse tensor and copy back to original location for tensor in allreduced_sparses: if tensor.is_sparse: @@ -2289,13 +3179,13 @@ def sparse_allreduce_no_retain(self, bucket, dp_group): else: tensor.orig_dense_tensor.copy_(tensor.to_dense()) - def sparse_allreduce_bucket(self, bucket, dp_group): + def sparse_allreduce_bucket(self, bucket, dp_group, dp_world_size=None): sparse_list = [] for sparse in bucket: - sparse_list.append(self.sparse_allreduce(sparse, dp_group)) + sparse_list.append(self.sparse_allreduce(sparse, dp_group, dp_world_size)) return sparse_list - def sparse_allreduce(self, sparse, dp_group): + def sparse_allreduce(self, sparse, dp_group, dp_world_size=None): original_data_type = sparse.values.dtype if self.communication_data_type != sparse.values.dtype: if self.communication_data_type in (torch.float16, torch.bfloat16): @@ -2307,11 +3197,13 @@ def sparse_allreduce(self, sparse, dp_group): indices = sparse.indices values = sparse.values + if dp_world_size is None: + dp_world_size = dist.get_world_size(group=dp_group) if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group)) + values.mul_(self.gradient_predivide_factor() / (dp_world_size)) else: - values.mul_(1. / dist.get_world_size(group=dp_group)) + values.mul_(1. / (dp_world_size)) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) @@ -2352,8 +3244,15 @@ def all_gather_scalar(self, value, dp_group): dist.all_gather(tensor_list, value, group=dp_group) return tensor_list - def module_state_dict(self, destination=None, prefix="", keep_vars=False): - sd = self.module.state_dict(destination, prefix, keep_vars) + def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False): + sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + # Remove frozen parameter weights from state_dict if specified + if exclude_frozen_parameters: + for n, p in self.module.named_parameters(): + if not p.requires_grad and n in sd: + del sd[n] + if self.random_ltd_enabled(): sd = remove_random_ltd_state_dict(sd) return sd @@ -2414,13 +3313,36 @@ def load_moe_state_dict(checkpoint_path, state_dict.update(expert_state_dict) moe_layer_id += 1 - def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): - if custom_load_fn: - custom_load_fn(src=state_dict, dst=self.module) + def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): + if fetch_z3_params: + params_to_fetch = [ + p for p in self.module.parameters() + if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE + ] else: - self.module.load_state_dict( - state_dict, # TODO - strict=strict) + params_to_fetch = [] + + with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0): + module_state_dict = checkpoint['module'] + if custom_load_fn: + custom_load_fn(src=module_state_dict, dst=self.module) + else: + self.module.load_state_dict( + module_state_dict, # TODO + strict=strict) + + if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None: + saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS] + for param in self.module.parameters(): + if param.requires_grad: + continue + if param not in self.param_names: + raise ValueError(f"failed to find frozen {param} in named params") + name = self.param_names[param] + if hasattr(param, 'ds_id'): + param.ds_tensor.data.copy_(saved_frozen_params[name].data) + else: + param.data.copy_(saved_frozen_params[name].data) def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}' @@ -2440,7 +3362,7 @@ def _get_zero_ckpt_name(self, checkpoints_path, tag): bf16_mode = self.bfloat16_enabled() return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode) - def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): + def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None, pp_placeholder=None): if mp_placeholder is not None: mp_rank_str = mp_placeholder else: @@ -2448,7 +3370,12 @@ def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): mp_rank_str = f"{mp_rank:02d}" if self.zero_optimization_partition_weights(): - filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group)) + if pp_placeholder is not None: + pp_rank = pp_placeholder + else: + pp_rank = dist.get_rank(group=self.optimizer.dp_process_group) + + filename = "zero_pp_rank_{}".format(pp_rank) ckpt_name = os.path.join( checkpoints_path, str(tag), @@ -2483,7 +3410,10 @@ def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None): def _get_all_ckpt_names(self, checkpoints_path, tag): # It is required that (checkpoints_path, tag) are consistent among all ranks. - ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*") + ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, + tag, + mp_placeholder="*", + pp_placeholder="0" if self.load_universal_checkpoint() else None) import glob ckpt_files = glob.glob(ckpt_file_pattern) @@ -2538,7 +3468,7 @@ def load_checkpoint(self, ) return None, None - if self.zero_optimization_partition_weights(): + if self._optimizer_has_ckpt_event_prologue(): # Prepare for checkpoint load by ensuring all parameters are partitioned self.optimizer.checkpoint_event_prologue() @@ -2550,15 +3480,36 @@ def load_checkpoint(self, load_module_only=load_module_only, custom_load_fn=custom_load_fn) - load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled() - if load_zero_checkpoint and load_path is not None: - success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + load_zero_checkpoint = load_path is not None and self.zero_optimization() + if load_zero_checkpoint and not self.zero_nvme_offload_optimizer(): + if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): + success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + else: + success = False if not success: self.optimizer._restore_from_bit16_weights() - if self.zero_optimization_partition_weights(): + if self.zero_nvme_offload_optimizer(): + from shutil import copytree, disk_usage + rank = self.local_rank if self.use_node_local_storage() else self.global_rank + rank_dir = "rank" + dp_index_to_str(rank) + offload_dir = self.optimizer.optimizer_swapper.swap_folder + offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors", rank_dir) + _, _, free = disk_usage(offload_dir) + logger.info( + f"Copying NVMe offload checkpoint from {offload_ckpt_dir} to {offload_dir}, {free / 1e9:,.2f} GB free on target filesystem..." + ) + copytree(offload_ckpt_dir, offload_dir, dirs_exist_ok=True) + _, _, free = disk_usage(offload_dir) + logger.info(f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem") + self.optimizer.reset_swap_buffers() + + if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() + if self.load_universal_checkpoint() and not self.zero_optimization_partition_weights(): + self.optimizer.update_lp_params() + return load_path, client_states def _load_checkpoint(self, @@ -2583,6 +3534,11 @@ def _load_checkpoint(self, if checkpoint is None: return None, None + fetch_z3_params = False + if self.zero_optimization_partition_weights() and not load_optimizer_states: + checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) + fetch_z3_params = True + if is_pipe_parallel: # Pipeline parallelism uses this to load its own checkpoint files. self._curr_ckpt_path = os.path.join(load_dir, tag) @@ -2601,32 +3557,36 @@ def _load_checkpoint(self, num_experts=self.num_experts, checkpoint_engine=self.checkpoint_engine) if not self.load_universal_checkpoint(): - self.load_module_state_dict(state_dict=checkpoint['module'], + self.load_module_state_dict(checkpoint=checkpoint, strict=load_module_strict, - custom_load_fn=custom_load_fn) + custom_load_fn=custom_load_fn, + fetch_z3_params=fetch_z3_params) self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] + optim_checkpoint = None if load_module_only: deepspeed_states = ['module'] - if self.optimizer is not None and self.fp16_enabled(): + if self.optimizer is not None and hasattr(self.optimizer, 'refresh_fp32_params'): self.optimizer.refresh_fp32_params() else: - if self.has_moe_layers: - largest_group_name = groups._get_max_expert_size_name() - expp_rank = groups._get_expert_parallel_rank(largest_group_name) - optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) - optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu')) - else: - optim_checkpoint = checkpoint - - has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() + has_zero_optimizer_state = self.zero_optimization() if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: - if self.fp16_enabled(): + if self.has_moe_layers: + largest_group_name = groups._get_max_expert_size_name() + expp_rank = groups._get_expert_parallel_rank(largest_group_name) + optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) + optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu')) + else: + optim_checkpoint = checkpoint + + if self.fp16_enabled() or self.bfloat16_enabled(): self.optimizer.load_state_dict(optim_checkpoint['optimizer'], load_optimizer_states=load_optimizer_states) else: - self.optimizer.load_state_dict(optim_checkpoint['optimizer']) + optim_checkpoint = checkpoint + + self.optimizer.load_state_dict(optim_checkpoint['optimizer']) if load_lr_scheduler_states and self.lr_scheduler is not None: self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) @@ -2681,22 +3641,32 @@ def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters if load_optimizer_states: deepspeed_states.append('optimizer') - client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states} + client_state = {key: value for key, value in checkpoint.items() if key not in deepspeed_states} - if not load_optimizer_states and not load_module_only: + if optim_checkpoint is not None: client_state['optimizer'] = optim_checkpoint['optimizer'] return load_path, client_state def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): + + load_serial = None + # When use loading checkpoint serial, checkpoint loading start from local rank 0, + # all other local rank would be paused, waiting for its rank-1 peer ready and its notification. + if self._config.zero_config.pipeline_loading_checkpoint: + assert self.zero_optimization_stage( + ) == ZeroStageEnum.weights, "Only stage3 support for pipeline checkpoint loading" + load_serial = torch.zeros(1).to(self.device) + if dist.get_local_rank() != 0: + dist.recv(tensor=load_serial, src=dist.get_rank() - 1) if self.load_universal_checkpoint(): zero_sd_list = None checkpoint_folder = f'{os.path.join(load_dir, tag)}' else: - if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size: + if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size: raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ - f"current world size is {self.dp_world_size}. Automatic adjustment " \ + f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \ "of ZeRO's optimizer state partitioning with a new world size is not " \ "currently supported.") checkpoint_folder = None @@ -2704,10 +3674,13 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): if zero_sd_list is None: return False + param_shapes = self._get_zero_param_shapes() self.optimizer.load_state_dict(state_dict_list=zero_sd_list, load_optimizer_states=load_optimizer_states, load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder) + checkpoint_folder=checkpoint_folder, + load_serial=load_serial, + param_shapes=param_shapes) if self.load_universal_checkpoint(): logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') @@ -2773,7 +3746,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag): if bf16_mode is not self.bfloat16_enabled(): checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 - logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') + logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) return None @@ -2795,7 +3768,7 @@ def _checkpoint_tag_validation(self, tag): elif not valid: logger.warning(msg) - def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True): + def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False): """Save training checkpoint Arguments: @@ -2804,14 +3777,15 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) used if not provided. Tag name must be the same across all ranks. client_state: Optional. State dictionary used for saving required training states in the client code. save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint. + exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state. Important: all processes must call this method and not just the process with rank 0. It is because each process needs to save its master weights and scheduler+optimizer states. This method will hang waiting to synchronize with other processes if it's called just for the process with rank 0. """ - if self.zero_optimization_partition_weights(): - # Prepare for checkpoint save by ensuring all parameters are partitioned + if self._optimizer_has_ckpt_event_prologue(): + # Custom preparation for checkpoint save, if applicable self.optimizer.checkpoint_event_prologue() rank = self.local_rank if self.use_node_local_storage() else self.global_rank @@ -2820,7 +3794,8 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) # There seems to be issue creating them in parallel # Ensure save_dir directory exists - self.checkpoint_engine.makedirs(save_dir, exist_ok=True) + if rank == 0: + self.checkpoint_engine.makedirs(save_dir, exist_ok=True) dist.barrier() if tag is None: @@ -2828,7 +3803,9 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) # Ensure tag is a string tag = str(tag) - self.checkpoint_engine.create(tag) + commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest) + + self.checkpoint_engine.create(commit_info) # Ensure checkpoint tag is consistent across ranks self._checkpoint_tag_validation(tag) @@ -2836,7 +3813,10 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) if self.has_moe_layers: self.save_non_zero_checkpoint = False self._create_checkpoint_file(save_dir, tag, False) - self._save_moe_checkpoint(save_dir, tag, client_state=client_state) + self._save_moe_checkpoint(save_dir, + tag, + client_state=client_state, + exclude_frozen_parameters=exclude_frozen_parameters) # We distribute the task of saving layer checkpoint files among # data parallel instances, so all procs should call _save_checkpoint. @@ -2844,25 +3824,62 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) # parallel rank 0 save the general model params. if not self.has_moe_layers: self._create_checkpoint_file(save_dir, tag, False) - self._save_checkpoint(save_dir, tag, client_state=client_state) + self._save_checkpoint(save_dir, + tag, + client_state=client_state, + exclude_frozen_parameters=exclude_frozen_parameters) if self.save_zero_checkpoint: self._create_zero_checkpoint_files(save_dir, tag) self._save_zero_checkpoint(save_dir, tag) - if self.zero_optimization_partition_weights(): + if self.zero_nvme_offload_optimizer(): + from shutil import copytree, disk_usage + rank_dir = "rank" + dp_index_to_str(rank) + offload_dir = self.optimizer.optimizer_swapper.swap_folder + offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors", rank_dir) + _, _, free = disk_usage(save_dir) + logger.info( + f"Copying NVMe offload files from {offload_dir} to {offload_ckpt_dir}, {free / 1e9:,.2f} GB free on target filesystem..." + ) + copytree(offload_dir, + offload_ckpt_dir, + ignore=lambda _, dir_list: list(filter(lambda x: 'gradient' in x, dir_list)), + dirs_exist_ok=False) + _, _, free = disk_usage(save_dir) + logger.info(f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem") + + if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() # Save latest checkpoint tag - self.checkpoint_engine.commit(tag) - if save_latest and rank == 0: - with open(os.path.join(save_dir, 'latest'), 'w') as fd: - fd.write(tag) + if not self.checkpoint_engine.is_decoupled(): + commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest) + self.checkpoint_engine.commit(commit_info) + if save_latest and self.global_rank == 0: + with open(os.path.join(save_dir, 'latest'), 'w') as fd: + fd.write(tag) dist.barrier() return True + def _commit_decoupled_checkpoint(self): + assert self.checkpoint_engine.is_decoupled(), \ + f'{self.checkpoint_engine} is not a Decoupled Checkpoint Engine' + + commit_info = self.checkpoint_engine.get_commit_info() + if commit_info is None: + return + + self.checkpoint_engine.commit(commit_info) + + if self.global_rank == 0 and commit_info.save_latest: + with open(os.path.join(commit_info.save_dir, 'latest'), 'w') as fd: + fd.write(commit_info.tag) + + dist.barrier() + def _get_non_moe_state_dict(self, full_state_dict): """ Get the state dict of the non-moe layers @@ -2873,8 +3890,9 @@ def _get_non_moe_state_dict(self, full_state_dict): return full_state_dict - def _save_moe_checkpoint(self, save_dir, tag, client_state={}): + def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): save_path = self._get_ckpt_name(save_dir, tag) + # A hack to save the checkpointing directory. Pipeline parallelism overrides # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. @@ -2888,7 +3906,8 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): expp_rank = groups._get_expert_parallel_rank(group_name) exp_dp_rank = groups._get_expert_data_parallel_rank(group_name) # print(expp_rank, exp_dp_rank) - if exp_dp_rank != 0: + # if exp_dp_rank != 0: + if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): moe_layer_id += 1 continue @@ -2906,7 +3925,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): local_expert_id = None if not m: - logger.warn(f'No expert found in key {key}.') + logger.warning(f'No expert found in key {key}.') else: local_expert_id = m.group(1) @@ -2924,7 +3943,10 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu) if self.random_ltd_enabled(): expert_state_dict = remove_random_ltd_state_dict(expert_state_dict) - self.checkpoint_engine.save(expert_state_dict, moe_save_path) + saveable_state_dict = expert_state_dict + if self.checkpoint_engine.preserves_storage_sharing(): + saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict) + self.checkpoint_engine.save(saveable_state_dict, moe_save_path) moe_layer_id += 1 self._curr_ckpt_path = os.path.join(save_dir, tag) @@ -2936,7 +3958,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): # In the case of E + D parallelism, only the # first expert parallel group should save the expert weights # since each expert parallel group is a copy of the model's experts - if exp_dp_rank != 0: + if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): return # Save optimizer states. They are different across each exp parallel rank. @@ -2945,12 +3967,20 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): } # TODO: why use BufferedWriter not the path file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) - self.checkpoint_engine.save(optimizer_state, file_path) + saveable_state_dict = optimizer_state + if self.checkpoint_engine.preserves_storage_sharing(): + saveable_state_dict = clone_tensors_for_torch_save(optimizer_state) + self.checkpoint_engine.save(saveable_state_dict, file_path) + + # Load flow uses below saved file for model parameters, RNG and more + if groups._get_data_parallel_rank() == 0: + # Get non-moe parameters + # Classes DeepSpeedEngine and PipelineEngine have different behavior for method module_state_dict. + # DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None. + # We need to get the state dict, therefore, call to DeepSpeedEngine (base class for PipelineEngine) + model_state_dict = self._get_non_moe_state_dict( + DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters)) - # get non-moe parameters - model_state_dict = self._get_non_moe_state_dict(self.module_state_dict()) - - if expp_rank == 0: # TODO: update num experts info,.. in checkpoint state = { 'module': @@ -2971,7 +4001,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): 'global_samples': self.global_samples, 'dp_world_size': - self.dp_world_size, + self.seq_dp_world_size, 'mp_world_size': self.mp_world_size, 'num_experts': @@ -2979,8 +4009,10 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): } state.update(client_state) logger.info(f'Saving model checkpoint: {save_path}') - self.checkpoint_engine.save(state, save_path) - self._curr_save_path = None + saveable_state_dict = state + if self.checkpoint_engine.preserves_storage_sharing(): + saveable_state_dict = clone_tensors_for_torch_save(state) + self.checkpoint_engine.save(saveable_state_dict, save_path) def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name) @@ -2988,7 +4020,7 @@ def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): checkpoint_name = name_function(save_dir, tag) path = os.path.dirname(checkpoint_name) self.checkpoint_engine.makedirs(path, exist_ok=True) - except: + except Exception: logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}") return False @@ -2997,32 +4029,37 @@ def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): def _create_zero_checkpoint_files(self, save_dir, tag): success = True # zero checkpoint files are created sequentially - for rank in range(self.world_size): + for rank in range(dist.get_world_size(self.optimizer.dp_process_group)): if rank == self.global_rank: success = self._create_checkpoint_file(save_dir, tag, True) - dist.barrier() - return success - def _save_checkpoint(self, save_dir, tag, client_state={}): + def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): save_path = self._get_ckpt_name(save_dir, tag) - zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() + zero_optimizer_state = self.zero_optimization() + + save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters # A hack to save the checkpointing directory. Pipeline parallelism overrides # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. The module_state_dict() implementation in # PipelineEngine expects the save path to be set in self._curr_ckpt_path. self._curr_ckpt_path = os.path.join(save_dir, tag) - module = self.module_state_dict() + module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) self._curr_ckpt_path = None state = dict(module=module, buffer_names=self._get_buffer_names(), optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None, + frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func) + if save_frozen_param else None, + shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None, + frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func) + if save_frozen_param else None, lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, data_sampler=self.training_dataloader.data_sampler.state_dict() if (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, @@ -3031,15 +4068,18 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): skipped_steps=self.skipped_steps, global_steps=self.global_steps, global_samples=self.global_samples, - dp_world_size=self.dp_world_size, + dp_world_size=self.seq_dp_world_size, mp_world_size=self.mp_world_size, ds_config=self.config, ds_version=version) + autotp_uc_info = getattr(self.module, UNIVERSAL_CHECKPOINT_INFO, None) + if autotp_uc_info is not None: + state[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info state.update(client_state) + log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0]) if self.save_non_zero_checkpoint: - log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1]) - self.checkpoint_engine.save(state, save_path) + self.checkpoint_engine.save(state_dict=state, path=save_path) def _get_buffer_names(self): buffer_names = [] @@ -3062,6 +4102,25 @@ def get_layer_named_buffers(module, prefix=""): return buffer_names + def _get_param_shape_func(self, param): + return param.ds_shape if hasattr(param, 'ds_id') else param.shape + + def _get_param_fragment_func(self, param): + return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu() + + def _get_zero_frozen_param_attributes(self, attr_func): + frozen_param_fragments = OrderedDict() + + for param in self.module.parameters(): + if param.requires_grad: + continue + if param not in self.param_names: + raise ValueError(f"failed to find frozen {param} in named params") + name = self.param_names[param] + frozen_param_fragments[name] = attr_func(param) + + return frozen_param_fragments + def _get_zero_param_shapes(self): """Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the optimizer. the names are exactly as in state_dict. The order is absolutely important, since @@ -3080,7 +4139,7 @@ def _get_zero_param_shapes(self): # if we don't use it, we get parameters ordered incorrectly if hasattr(self.optimizer, "round_robin_bit16_groups"): bit16_groups = self.optimizer.round_robin_bit16_groups - elif self.bfloat16_enabled() and not self.zero_optimization(): + elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"): bit16_groups = self.optimizer.bf16_groups else: bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( @@ -3093,7 +4152,7 @@ def _get_zero_param_shapes(self): numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape if param not in self.param_names: - raise ValueError(f"failed to find optimizer param in named params") + raise ValueError("failed to find optimizer param in named params") name = self.param_names[param] param_shapes[name] = shape @@ -3104,6 +4163,46 @@ def _get_zero_param_shapes(self): return param_group_shapes + def _get_shared_params(self): + """ + Returns a dict of shared params, which can later be used to reconstruct the original state dict, + e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name + of the variable that isn't stored and the value is the actual param holding data. + """ + shared_index = {} + shared_params_by_full_name = {} + + is_zero3_model = (self.zero_optimization_partition_weights() + and any(hasattr(param, "ds_id") for param in self.module.parameters())) + + def get_layer_state_dict(module, prefix=""): + # handle params + for name, param in module.named_parameters(recurse=False): + if param is None or (is_zero3_model and not hasattr(param, "ds_id")): + continue + key = prefix + name + + # When weights are manged by stage 3, we can't rely on param.data_ptr() as it will be reused + # as weights get gathered and reduced, but param.ds_id is unique across all zero weights + # (and shared params will have the same param.ds_id) + param_id = param.ds_id if is_zero3_model else param.data_ptr() + + if param_id in shared_index: + # shared weights + #print(f"`{key}` is shared with `{shared_index[param_id]}`") + shared_params_by_full_name[key] = shared_index[param_id] + else: + shared_index[param_id] = key + + for name, child in module.named_children(): + if child is not None: + get_layer_state_dict(child, prefix + name + ".") + + if dist.get_rank() == 0: + get_layer_state_dict(self.module, prefix="") + + return shared_params_by_full_name + def _copy_recovery_script(self, save_path): base_dir = os.path.dirname(os.path.dirname(__file__)) script = "zero_to_fp32.py" @@ -3111,8 +4210,17 @@ def _copy_recovery_script(self, save_path): dst = os.path.join(save_path, script) #logger.info(f"creating recovery script {dst}") copyfile(src, dst) - # make executable - os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) + self._change_recovery_script_permissions(dst) + + def _change_recovery_script_permissions(self, dst): + # make executable (safeguard for file shares - Azure as example) + try: + os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) + except (FileNotFoundError, PermissionError) as e: + #this message is used in unit test TestZeRONonDistributed + logger.info( + f'Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.' + ) def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) @@ -3122,9 +4230,55 @@ def _save_zero_checkpoint(self, save_path, tag): if self.global_rank == 0: self._copy_recovery_script(save_path) ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' - logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') + #logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') + + def _replace_module_consolidated_state_dict(self): + """ + Get a full non-partitioned state_dict with fp16 weights on cpu. + Important: this function must be called on all ranks and not just rank 0. + This is similar to nn.Module.state_dict (modelled after _save_to_state_dict) + This method is used for tensor parallel training. + + Returns: + OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None. + """ + #TODO: If we use both Zero3 and tensor parallel simultaneously + # we need to consolidate the gather mechanisms of both. + state_dict = OrderedDict() if dist.get_rank() == 0 else None + + def get_layer_state_dict(module, prefix=""): + with GatherReplacedLayerParams(list(module.parameters(recurse=False)), module, enabled=True): + for name, param in module.named_parameters(recurse=False): + if param is None: + continue + key = prefix + name + if (dist.get_rank() == 0): + state_dict[key] = param.detach().cpu() + # print(key,module, param.detach().cpu().shape) + + for name, child in module.named_children(): + if child is not None: + get_layer_state_dict(child, prefix + name + ".") + + get_layer_state_dict(self.module, prefix="") + + # ensure that all GPU communication tasks are completed before the process exits + get_accelerator().synchronize() + return state_dict + + def _consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): + """ + Consolidate the 16-bit state dictionary. + """ + if self.zero_optimization_stage() == ZeroStageEnum.weights: + return self._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters) + elif self.autotp_size() > 1: + return self._replace_module_consolidated_state_dict() + + raise ValueError("consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, " + "including Zero Stage 3 and tensor parallelism.") - def _zero3_consolidated_16bit_state_dict(self): + def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): """ Get a full non-partitioned state_dict with fp16 weights on cpu. Important: this function must be called on all ranks and not just rank 0. @@ -3150,7 +4304,7 @@ def get_layer_state_dict(module, prefix=""): if dist.get_rank() == 0: # handle params for name, param in module.named_parameters(recurse=False): - if param is None: + if param is None or (exclude_frozen_parameters and not param.requires_grad): continue key = prefix + name # can't rely on param.data_ptr() as it will be reused as weights gets @@ -3176,13 +4330,15 @@ def get_layer_state_dict(module, prefix=""): get_layer_state_dict(child, prefix + name + ".") # Prepare for checkpoint save by ensuring all parameters are partitioned - self.optimizer.checkpoint_event_prologue() + if self._optimizer_has_ckpt_event_prologue(): + self.optimizer.checkpoint_event_prologue() see_memory_usage("before get_layer_state_dict", force=False) get_layer_state_dict(self.module, prefix="") see_memory_usage("after get_layer_state_dict", force=False) - self.optimizer.checkpoint_event_epilogue() + if self._optimizer_has_ckpt_event_epilogue(): + self.optimizer.checkpoint_event_epilogue() return state_dict @@ -3191,7 +4347,7 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): compatibility""" return self.save_16bit_model(save_dir, save_filename) - def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): + def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", exclude_frozen_parameters=False): """ Save 16bit model weights @@ -3200,6 +4356,7 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): Arguments: save_dir: Required. Directory for saving the model save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin`` + exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state. Returns: ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if @@ -3216,25 +4373,27 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): if self.zero_optimization_partition_weights(): if self.zero_gather_16bit_weights_on_model_save(): # consolidation is expensive in time and memory and therefore isn't a default - state_dict = self._zero3_consolidated_16bit_state_dict() + state_dict = self._zero3_consolidated_16bit_state_dict( + exclude_frozen_parameters=exclude_frozen_parameters) else: # the model will be bogus if not consolidated so don't confuse the user by saving it logger.info( - f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False") + f"Did not save the model {path} because stage3_gather_16bit_weights_on_model_save is False") return False else: - state_dict = self.module.state_dict() + state_dict = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) tag = f"global_step{self.global_steps}" tag = str(tag) - self.checkpoint_engine.create(tag) + commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=False) + self.checkpoint_engine.create(commit_info) if dist.get_rank() == 0: self.checkpoint_engine.makedirs(save_dir, exist_ok=True) logger.info(f"Saving model weights to {path}, tag: {tag}") self.checkpoint_engine.save(state_dict, path) - self.checkpoint_engine.commit(tag) + self.checkpoint_engine.commit(commit_info) return True @@ -3246,3 +4405,202 @@ def empty_partition_cache(self): self.optimizer.empty_partition_cache() gc.collect() get_accelerator().empty_cache() + + def get_autosp_backend(self, compile_kwargs): + if self.compile_autosp() and self.zero_optimization_stage() not in [ + ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states + ]: + logger.info( + f"Currently AutoSP does not compose with ZeRO stage 2 and 3. Falling back to the torch compiler.") + return None + + compile_config = self._config.compile_config + compile_kwargs['fullgraph'] = True + return init_autosp(self._config) + + def get_deepcompile_backend(self, backend, compile_kwargs, schedule): + if self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ + and self.zero_optimization_stage() != ZeroStageEnum.weights \ + and self.zero_optimization_stage() != ZeroStageEnum.gradients: + logger.info( + f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." + ) + return None + + compile_config = self._config.compile_config + if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"]) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu"): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + return init_z1(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.gradients: + return init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + return init_z3(self, backend, compile_config, compile_kwargs, schedule) + return None + + def get_deepspeed_compile_backend(self, backend, compile_kwargs, schedule): + resolved_backend = None + + if schedule is not None: + + def passes_name_to_fn(passes): + for p in passes: + assert callable(p) or p in opt_passes, f"Unknown pass {p}" + return [p if callable(p) else opt_passes[p] for p in passes] + + schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] + + assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." + + if self.compile_autosp(): + resolved_backend = self.get_autosp_backend(compile_kwargs) + else: + resolved_backend = self.get_deepcompile_backend(backend, compile_kwargs, schedule) + + return resolved_backend, schedule + + def compile(self, + backend=get_accelerator().get_compile_backend(), + compile_kwargs={}, + schedule=None, + compiled_autograd_enabled=False) -> None: + """Compile the module using the specified backend and kwargs. + If a compiler_fn is set, it will be used instead of torch.compile(). + """ + # Avoid graph breaks + deepspeed.utils.nvtx.enable_nvtx = False + + if not is_compile_supported(): + raise RuntimeError("compile is not supported in your version of PyTorch.") + + if self.is_compiled: + return + + if 'backend' in compile_kwargs: + logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.") + + logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") + + resolved_backend = None + if self.is_deepcompile_enabled(): + resolved_backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule) + + is_deepspeed_compile_backend = resolved_backend is not None + + # default to torch.compiler backend if deepspeed config validation fails + backend = resolved_backend or backend + + # Hook state must align with whether DeepCompile is active. + self._set_deepcompile_active(is_deepspeed_compile_backend) + + # create new dict to avoid modifying original dict + try: + self.module.compile(**{**compile_kwargs, 'backend': backend}) + except Exception: + if is_deepspeed_compile_backend: + # Restore default hooks if compilation fails before completing. + self._set_deepcompile_active(False) + raise + + self._is_compiled = True + self._compile_kwargs = compile_kwargs + if compiled_autograd_enabled: + if not self._deepcompile_active: + self._is_compiled_autograd_enabled = compiled_autograd_enabled + else: + logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.") + self._is_compiled_autograd_enabled = False + + def _set_deepcompile_active(self, active: bool) -> None: + """Toggle DeepCompile runtime state and manage forward hooks accordingly.""" + if self._deepcompile_active == active: + return + + if active: + if self.module_forward_pre_hook is not None: + self.module_forward_pre_hook.remove() + self.module_forward_pre_hook = None + if self.module_forward_post_hook is not None: + self.module_forward_post_hook.remove() + self.module_forward_post_hook = None + else: + if self.module_forward_pre_hook is None: + self.module_forward_pre_hook = self._create_module_forward_pre_hook() + if self.module_forward_post_hook is None: + self.module_forward_post_hook = self._create_module_forward_post_hook() + + self._deepcompile_active = active + + def get_compile_time(self): + from deepspeed.compile.backend import opt_pass_times + return opt_pass_times + + def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None: + register_compile_pass(pass_name, pass_fn) + + def is_deepcompile_enabled(self) -> bool: + return self._config.compile_config.deepcompile + + def is_deepcompile_active(self) -> bool: + return getattr(self, "_deepcompile_active", False) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + def _refine_include_states(self, include: Container[OffloadStateTypeEnum]) -> Container[OffloadStateTypeEnum]: + if include is None: + include = list(OffloadStateTypeEnum) + + if self.zero_use_cpu_optimizer(): + exclude_states = [OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states] + if self.zero_optimization_partition_weights(): + exclude_states.append(OffloadStateTypeEnum.lp_grads) + include = [x for x in include if x not in exclude_states] + + return include + + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False) -> None: + """Offload the engine's states to the specified device. + + Arguments: + include: Optional. The set of states to offload. If not provided, all states are offloaded. + device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported. + pin_memory: Optional. Whether to pin the memory of the offloaded states. + non_blocking: Optional. Whether to offload the states asynchronously. + """ + include = self._refine_include_states(include) + param_offload_config = self.zero_offload_param() + assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters." + + assert not isinstance( + self.optimizer, + DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." + + if device == OffloadDeviceEnum.none: + logger.warning("No device specified for offloading states.") + return + + if device == OffloadDeviceEnum.nvme: + raise ValueError("NVMe offload is not supported for offloading states.") + + self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking) + + def reload_states(self, non_blocking: bool = False) -> None: + """Reload the engine states to the original device. + + Arguments: + non_blocking: Optional. Whether to offload the states asynchronously. + """ + assert not isinstance( + self.optimizer, + DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." + + self.optimizer.reload_states(non_blocking=non_blocking) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 7fb9c9daf5c9..7a24235f57d0 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -9,14 +9,25 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -from deepspeed.runtime import DeepSpeedOptimizer -from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm -from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE -from deepspeed.utils import groups, logger, log_dist -from deepspeed import comm as dist +from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer +from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter +from deepspeed.runtime.fp16.loss_scaler import LossScaleConfig, LossScaleProfile +from deepspeed.utils import logger, log_dist +from deepspeed.utils.torch import required_torch_version from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD from deepspeed.accelerator import get_accelerator +from deepspeed.moe.utils import is_moe_param_group +from deepspeed.runtime.constants import PIPE_REPLICATED +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank + +OVERFLOW_CHECK_TIMER = 'overflow_check' +COMPUTE_NORM_TIMER = 'compute_norm' +UNSCALE_AND_CLIP_TIMER = 'unscale_and_clip' +BASIC_STEP_TIMER = 'basic_step' +UPDATE_FP16_TIMER = 'update_fp16' + +OVERFLOW_TIMERS = [COMPUTE_NORM_TIMER, OVERFLOW_CHECK_TIMER] +STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP_TIMER, BASIC_STEP_TIMER, UPDATE_FP16_TIMER] class FP16_Optimizer(DeepSpeedOptimizer): @@ -29,6 +40,8 @@ class FP16_Optimizer(DeepSpeedOptimizer): def __init__(self, init_optimizer, deepspeed=None, + loss_scale_config=None, + low_precision_dtype=torch.float16, static_loss_scale=1.0, dynamic_loss_scale=False, initial_dynamic_scale=2**32, @@ -42,11 +55,23 @@ def __init__(self, self.fused_adam_legacy = fused_adam_legacy self.timers = timers - self.deepspeed = deepspeed self.has_moe_layers = has_moe_layers - self.using_pipeline = self.deepspeed.pipeline_parallelism + self.deepspeed = deepspeed + self.using_pipeline = getattr(self.deepspeed, 'pipeline_parallelism', False) + self.low_precision_dtype = low_precision_dtype + if loss_scale_config is None: + loss_scale_config = LossScaleConfig( + low_precision_dtype=low_precision_dtype, + dynamic_loss_scale=dynamic_loss_scale, + static_loss_scale=static_loss_scale, + dynamic_loss_args=dynamic_loss_args, + profile=LossScaleProfile.FUSED, + initial_dynamic_scale=initial_dynamic_scale, + ) + self.loss_scale_config = loss_scale_config + if not get_accelerator().is_available(): - raise SystemError("Cannot use fp16 without accelerator.") + raise SystemError("Cannot use {low_precision_dtype} without accelerator.") self.optimizer = init_optimizer # param flattened by groups @@ -54,6 +79,8 @@ def __init__(self, self.fp16_groups_flat = [] self.fp32_groups_flat = [] + self.flatten_grad_norm_mask_list = [] + self.has_executed_step = False self._global_grad_norm = 0. # loop to deal with groups @@ -72,25 +99,6 @@ def __init__(self, self.fp32_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it param_group['params'] = [self.fp32_groups_flat[i]] - # we may have a way of fusing dynamic scale. Do not support for now - if dynamic_loss_scale: - self.dynamic_loss_scale = True - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = 2 - - if dynamic_loss_args is None: - self.cur_scale = initial_dynamic_scale - self.scale_window = 1000 - self.min_loss_scale = 1 - else: - self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE] - self.scale_window = dynamic_loss_args[SCALE_WINDOW] - self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE] - else: - self.dynamic_loss_scale = False - self.cur_iter = 0 - self.cur_scale = static_loss_scale self.verbose = verbose self.custom_loss_scaler = False @@ -98,11 +106,8 @@ def __init__(self, self.clip_grad = clip_grad self.norm_type = 2 - self.step_count = 0 - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - if TORCH_MAJOR == 0 and TORCH_MINOR <= 4: + if required_torch_version(max_version=0.4): self.clip_grad_norm = torch.nn.utils.clip_grad_norm else: self.clip_grad_norm = torch.nn.utils.clip_grad_norm_ @@ -126,7 +131,7 @@ def initialize_optimizer_states(self): return - def zero_grad(self, set_to_none=False): + def zero_grad(self, set_to_none=True): """ Zero FP16 parameter grads. """ @@ -156,21 +161,22 @@ def step_fused_adam(self, closure=None): norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu)) self.overflow = self.overflow_checker.check_using_norm(norm_groups) - prev_scale = self.cur_scale - self._update_scale(self.overflow) + if self.loss_scale_config.use_grad_scaling: + prev_scale = self.loss_scale_config.cur_scale + self._update_scale(self.overflow) - if self.overflow: - if self.verbose: - logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " - "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) - return self.overflow + if self.overflow: + if self.verbose: + logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " + "scale: {}, reducing to {}".format(prev_scale, self.loss_scale_config.cur_scale)) + return self.overflow scaled_grad_norm = get_global_norm(norm_list=norm_groups) combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False) # Stash unscaled gradient norm - self._global_grad_norm = scaled_grad_norm / self.cur_scale + self._global_grad_norm = scaled_grad_norm / self.loss_scale_config.cur_scale # norm is in fact norm*cur_scale self.optimizer.step(grads=[[g] for g in grads_groups_flat], @@ -184,20 +190,6 @@ def step_fused_adam(self, closure=None): p.data = q.data return self.overflow - def start_timers(self, name_list): - if self.timers is not None: - for name in name_list: - self.timers(name).start() - - def stop_timers(self, name_list): - if self.timers is not None: - for name in name_list: - self.timers(name).stop() - - def log_timers(self, name_list): - if self.timers is not None: - self.timers.log(name_list) - def set_lr(self, lr): """Set the learning rate.""" for param_group in self.optimizer.param_groups: @@ -208,11 +200,47 @@ def get_lr(self): return self.optimizer.param_groups[0]["lr"] def override_loss_scale(self, loss_scale): + assert self.loss_scale_config.use_grad_scaling, f"Loss scale overriding only supported for torch.float16, rather than {self.low_precision_dtype}" + if loss_scale != self.external_loss_scale: logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}') self.custom_loss_scaler = True self.external_loss_scale = loss_scale + def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank): + # for filtering replicated tensors from tensor + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + return True + if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p): + return True + + def _get_norm_mask_idx(self, group): + """The function preserves the parallel information for norm + from unflattened gradients. + + Args: + group (Iterable[Tensor] ): params group + + Returns: + torch.Tensor: A 2D tensor containing index ranges for each group, + where each row represents a [start index, end index]. + """ + group_mask_idx_list = [] + grad_flat_st_idx = 0 + grad_flat_en_idx = 0 + + for p in group: + grad_flat_en_idx = grad_flat_st_idx + p.numel() + if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): + # merge range + if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]: + group_mask_idx_list[-1][-1] = grad_flat_en_idx + else: + group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) + grad_flat_st_idx = grad_flat_en_idx + + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name()) + def step(self, closure=None): """ Not supporting closure. @@ -221,38 +249,38 @@ def step(self, closure=None): if self.fused_adam_legacy: return self.step_fused_adam() - COMPUTE_NORM = "compute_norm" - OVERFLOW_CHECK = 'overflow_check' - OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK] - UNSCALE_AND_CLIP = 'unscale_and_clip' - BASIC_STEP = 'basic_step' - UPDATE_FP16 = 'update_fp16' - STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16] - # First determine if there is overflow. - self.start_timers([OVERFLOW_CHECK]) + if self.timers: + self.timers(OVERFLOW_CHECK_TIMER).start() fp16_params = [] for i, group in enumerate(self.fp16_groups): fp16_params.extend([p for p in group if p.grad is not None]) self.overflow = self.overflow_checker.has_overflow(fp16_params) - self.stop_timers([OVERFLOW_CHECK]) - prev_scale = self.cur_scale - self._update_scale(self.overflow) - if self.overflow: - if self.verbose: - log_dist( - "Overflow detected. Skipping step. Attempted loss " - f"scale: {prev_scale}, reducing to {self.cur_scale}", - ranks=[0]) - # Clear gradients - for i, group in enumerate(self.fp16_groups): - for p in group: - p.grad = None + if self.timers: + self.timers(OVERFLOW_CHECK_TIMER).stop() - self.log_timers(OVERFLOW_TIMERS) - return self.overflow + if self.loss_scale_config.use_grad_scaling: + prev_scale = self.loss_scale_config.cur_scale + self._update_scale(self.overflow) + if self.overflow: + if self.verbose: + log_dist( + "Overflow detected. Skipping step. Attempted loss " + f"scale: {prev_scale}, reducing to {self.loss_scale_config.cur_scale}", + ranks=[0]) + # Clear gradients + for i, group in enumerate(self.fp16_groups): + for p in group: + p.grad = None + if self.timers: + self.timers.log(OVERFLOW_TIMERS) + return self.overflow grads_groups_flat = [] + non_experts_grads_for_norm = [] + expert_grads_for_norm = {} + assert len(self.fp16_groups) == len(self.optimizer.param_groups) + for i, group in enumerate(self.fp16_groups): data_type = self.fp32_groups_flat[i].dtype @@ -262,74 +290,87 @@ def step(self, closure=None): for p in group ])) - for p in group: - p.grad = None - self.fp32_groups_flat[i].grad = grads_groups_flat[i] + param_group = self.optimizer.param_groups[i] - self.start_timers([COMPUTE_NORM]) + # split expert and non_expert grads for norm + if self.has_moe_layers and is_moe_param_group(param_group): + if param_group['name'] not in expert_grads_for_norm: + expert_grads_for_norm[param_group['name']] = [] - all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i]) + else: + # retrieves the required mask for calculating the norm of flat_grad + # perform this collect operation only once + if not self.has_executed_step: + cur_flat_grad_norm_mask = self._get_norm_mask_idx(group) + self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) - self.stop_timers([COMPUTE_NORM]) + non_experts_grads_for_norm.append(self.fp32_groups_flat[i]) + + for p in group: + p.grad = None + + if self.timers: + self.timers(COMPUTE_NORM_TIMER).start() + + all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm, + mpu=self.mpu, + grad_norm_mask=self.flatten_grad_norm_mask_list) if self.has_moe_layers: - all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm) + all_groups_norm = get_norm_with_moe_layers(all_groups_norm, + mpu=self.mpu, + expert_tensors=expert_grads_for_norm, + norm_type=self.norm_type) scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + if self.timers: + self.timers(COMPUTE_NORM_TIMER).stop() # Stash unscaled gradient norm - self._global_grad_norm = scaled_global_grad_norm / self.cur_scale + self._global_grad_norm = scaled_global_grad_norm / self.loss_scale_config.cur_scale - self.start_timers([UNSCALE_AND_CLIP]) + if self.timers: + self.timers(UNSCALE_AND_CLIP_TIMER).start() self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm) - self.stop_timers([UNSCALE_AND_CLIP]) + if self.timers: + self.timers(UNSCALE_AND_CLIP_TIMER).stop() - self.start_timers([BASIC_STEP]) + if self.timers: + self.timers(BASIC_STEP_TIMER).start() self.optimizer.step() - self.stop_timers([BASIC_STEP]) + if self.timers: + self.timers(BASIC_STEP_TIMER).stop() #get rid of the fp32 gradients. Not needed anymore for group in self.fp32_groups_flat: group.grad = None - self.start_timers([UPDATE_FP16]) + if self.timers: + self.timers(UPDATE_FP16_TIMER).start() for i in range(len(self.fp16_groups)): updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i]) for p, q in zip(self.fp16_groups[i], updated_params): p.data.copy_(q.data) + self.has_executed_step = True + if self.timers: + self.timers(UPDATE_FP16_TIMER).stop() - self.stop_timers([UPDATE_FP16]) - - self.log_timers(STEP_TIMERS) - - self.step_count += 1 + if self.timers: + self.timers.log(STEP_TIMERS) return self.overflow - def _get_norm_with_moe_layers(self, all_groups_norm): - #all_groups_norm_old = all_groups_norm - # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce - if self.using_pipeline: - pg = self.deepspeed.mpu.get_data_parallel_group() - else: - pg = groups._get_data_parallel_group() - scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - all_groups_norm = scaled_norm_tensor.item() - #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") - return all_groups_norm - def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True): # compute combined scale factor for this group - combined_scale = self.cur_scale + combined_scale = self.loss_scale_config.cur_scale if self.clip_grad > 0.: # norm is in fact norm*scale - clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad + clip = ((total_norm / self.loss_scale_config.cur_scale) + 1e-6) / self.clip_grad if clip > 1: - combined_scale = clip * self.cur_scale + combined_scale = clip * self.loss_scale_config.cur_scale if apply_scale: for grad in grad_groups_flat: @@ -349,31 +390,34 @@ def backward(self, loss, create_graph=False, retain_graph=False): scaled_loss = self.external_loss_scale * loss scaled_loss.backward() else: - scaled_loss = (loss.float()) * self.cur_scale + scaled_loss = (loss.float()) * self.loss_scale_config.cur_scale scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph) def _update_scale(self, skip): - if self.dynamic_loss_scale: - prev_scale = self.cur_scale + if self.loss_scale_config.dynamic_loss_scale: + prev_scale = self.loss_scale_config.cur_scale if skip: - self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale) - self.last_overflow_iter = self.cur_iter + self.loss_scale_config.cur_scale = max( + self.loss_scale_config.cur_scale / self.loss_scale_config.scale_factor, + self.loss_scale_config.min_loss_scale) + self.loss_scale_config.last_overflow_iter = self.loss_scale_config.cur_iter if self.verbose: - logger.info(f"\nGrad overflow on iteration {self.cur_iter}") - logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}") + logger.info(f"\nGrad overflow on iteration {self.loss_scale_config.cur_iter}") + logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.loss_scale_config.cur_scale}") else: - # Ensure self.scale_window updates since last overflow - stable_interval = (self.cur_iter - self.last_overflow_iter) - 1 - if (stable_interval > 0) and (stable_interval % self.scale_window == 0): - self.cur_scale *= self.scale_factor + # Ensure self.loss_scale_config.scale_window updates since last overflow + stable_interval = (self.loss_scale_config.cur_iter - self.loss_scale_config.last_overflow_iter) - 1 + if (stable_interval > 0) and (stable_interval % self.loss_scale_config.scale_window == 0): + self.loss_scale_config.cur_scale *= self.loss_scale_config.scale_factor if self.verbose: - logger.info(f"No Grad overflow for {self.scale_window} iterations") - logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}") + logger.info(f"No Grad overflow for {self.loss_scale_config.scale_window} iterations") + logger.info( + f"Increasing dynamic loss scale from {prev_scale} to {self.loss_scale_config.cur_scale}") else: if skip: - logger.info("Grad overflow on iteration: %s", self.cur_iter) - logger.info("Using static loss scale of: %s", self.cur_scale) - self.cur_iter += 1 + logger.info("Grad overflow on iteration: %s", self.loss_scale_config.cur_iter) + logger.info("Using static loss scale of: %s", self.loss_scale_config.cur_scale) + self.loss_scale_config.cur_iter += 1 return # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" @@ -407,13 +451,14 @@ def state_dict(self): torch.save(checkpoint, "saved.pth") """ state_dict = {} - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['cur_scale'] = self.cur_scale - state_dict['cur_iter'] = self.cur_iter - if state_dict['dynamic_loss_scale']: - state_dict['last_overflow_iter'] = self.last_overflow_iter - state_dict['scale_factor'] = self.scale_factor - state_dict['scale_window'] = self.scale_window + if self.loss_scale_config.use_grad_scaling: + state_dict['dynamic_loss_scale'] = self.loss_scale_config.dynamic_loss_scale + state_dict['cur_scale'] = self.loss_scale_config.cur_scale + state_dict['cur_iter'] = self.loss_scale_config.cur_iter + if state_dict['dynamic_loss_scale']: + state_dict['last_overflow_iter'] = self.loss_scale_config.last_overflow_iter + state_dict['scale_factor'] = self.loss_scale_config.scale_factor + state_dict['scale_window'] = self.loss_scale_config.scale_window state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict() state_dict['fp32_groups_flat'] = self.fp32_groups_flat state_dict[CLIP_GRAD] = self.clip_grad @@ -441,13 +486,14 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): optimizer.load_state_dict(checkpoint['optimizer']) """ # I think it should actually be ok to reload the optimizer before the model. - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.cur_scale = state_dict['cur_scale'] - self.cur_iter = state_dict['cur_iter'] - if state_dict['dynamic_loss_scale']: - self.last_overflow_iter = state_dict['last_overflow_iter'] - self.scale_factor = state_dict['scale_factor'] - self.scale_window = state_dict['scale_window'] + if self.loss_scale_config.use_grad_scaling: + self.loss_scale_config.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.loss_scale_config.cur_scale = state_dict['cur_scale'] + self.loss_scale_config.cur_iter = state_dict['cur_iter'] + if state_dict['dynamic_loss_scale']: + self.loss_scale_config.last_overflow_iter = state_dict['last_overflow_iter'] + self.loss_scale_config.scale_factor = state_dict['scale_factor'] + self.loss_scale_config.scale_window = state_dict['scale_window'] if load_optimizer_states: self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT]) self.clip_grad = state_dict[CLIP_GRAD] @@ -473,12 +519,16 @@ def __repr__(self): # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" def _get_loss_scale(self): + if not self.loss_scale_config.use_grad_scaling: + return None + if self.custom_loss_scaler: return self.external_loss_scale else: - return self.cur_scale + return self.loss_scale_config.cur_scale def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value + if self.loss_scale_config.use_grad_scaling: + self.loss_scale_config.cur_scale = value loss_scale = property(_get_loss_scale, _set_loss_scale) diff --git a/deepspeed/runtime/fp16/loss_scaler.py b/deepspeed/runtime/fp16/loss_scaler.py index e12ee92fdf98..b0fc8d942d99 100755 --- a/deepspeed/runtime/fp16/loss_scaler.py +++ b/deepspeed/runtime/fp16/loss_scaler.py @@ -22,15 +22,105 @@ """ import torch +from dataclasses import dataclass +from typing import Optional +from enum import Enum +from deepspeed.runtime.config_utils import DeepSpeedConfigObject from deepspeed import comm as dist from deepspeed.utils import logger INITIAL_LOSS_SCALE = 'init_scale' SCALE_WINDOW = 'scale_window' DELAYED_SHIFT = 'delayed_shift' +CONSECUTIVE_HYSTERESIS = 'consecutive_hysteresis' MIN_LOSS_SCALE = 'min_scale' +class LossScaleProfile(str, Enum): + FUSED = "fused" + UNFUSED = "unfused" + + +@dataclass(frozen=True) +class LossScaleProfileDefaults: + initial_dynamic_scale: float + default_scale_window: int + default_min_loss_scale: float + scale_factor: float + + +LOSS_SCALE_PROFILE_DEFAULTS = { + LossScaleProfile.FUSED: + LossScaleProfileDefaults( + initial_dynamic_scale=2**32, + default_scale_window=1000, + default_min_loss_scale=1, + scale_factor=2.0, + ), + LossScaleProfile.UNFUSED: + LossScaleProfileDefaults( + initial_dynamic_scale=1.0 * 2**16, + default_scale_window=1000, + default_min_loss_scale=0.25, + scale_factor=2.0, + ), +} + + +@dataclass +class LossScaleConfig: + use_grad_scaling: bool + dynamic_loss_scale: bool + cur_iter: int + cur_scale: float + last_overflow_iter: Optional[int] = None + scale_factor: Optional[float] = None + scale_window: Optional[int] = None + min_loss_scale: Optional[float] = None + + def __init__(self, + low_precision_dtype, + dynamic_loss_scale, + static_loss_scale, + dynamic_loss_args, + *, + profile: LossScaleProfile = LossScaleProfile.FUSED, + initial_dynamic_scale: Optional[float] = None): + defaults = LOSS_SCALE_PROFILE_DEFAULTS[profile] + use_grad_scaling = low_precision_dtype == torch.float16 + self.use_grad_scaling = use_grad_scaling + self.dynamic_loss_scale = False + self.cur_iter = 0 + self.cur_scale = 1.0 + self.last_overflow_iter = None + self.scale_factor = None + self.scale_window = None + self.min_loss_scale = None + + if not use_grad_scaling: + return + + self.cur_scale = static_loss_scale + if not dynamic_loss_scale: + return + + if initial_dynamic_scale is None: + initial_dynamic_scale = defaults.initial_dynamic_scale + + self.dynamic_loss_scale = True + self.last_overflow_iter = -1 + self.scale_factor = defaults.scale_factor + if dynamic_loss_args is None: + self.cur_scale = initial_dynamic_scale + self.scale_window = defaults.default_scale_window + self.min_loss_scale = defaults.default_min_loss_scale + return + + self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE] + self.scale_window = dynamic_loss_args[SCALE_WINDOW] + self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE] + + # item() is a recent addition, so this helps with backward compatibility. def to_python_float(t): if hasattr(t, 'item'): @@ -38,12 +128,13 @@ def to_python_float(t): return t[0] -class LossScalerBase: +class LossScalerBase(DeepSpeedConfigObject): """LossScalarBase Base class for a loss scaler """ def __init__(self, cur_scale): + super(LossScalerBase, self).__init__() self.cur_scale = cur_scale self.dynamic = False @@ -57,8 +148,14 @@ def scale_gradient(self, module, grad_in, grad_out): def update_scale(self, overflow): pass + def scale_loss(self, loss): + """ Scales the loss by the current loss scale. + We need this function to scale loss without calling backward on it. + """ + return loss * self.loss_scale + def backward(self, loss, retain_graph=False): - scaled_loss = loss * self.loss_scale + scaled_loss = self.scale_loss(loss) scaled_loss.backward(retain_graph=retain_graph) # print(f'LossScalerBackward: {scaled_loss=}') @@ -111,21 +208,21 @@ class DynamicLossScaler(LossScalerBase): init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. + consecutive_hysteresis (bool, optional, default=False): Whether to refill hysteresis if we reach an iteration that doesn't overflow """ def __init__(self, - init_scale=2**32, - scale_factor=2., - scale_window=1000, - min_scale=1, - delayed_shift=1, - consecutive_hysteresis=False, + init_scale, + scale_window, + min_scale, + delayed_shift, + consecutive_hysteresis, raise_error_at_min_scale=True, dtype=torch.half): super(DynamicLossScaler, self).__init__(init_scale) self.cur_iter = 0 self.last_overflow_iter = -1 - self.scale_factor = scale_factor + self.scale_factor = 2.0 self.scale_window = scale_window self.min_scale = min_scale self.delayed_shift = delayed_shift @@ -190,8 +287,13 @@ def update_scale(self, overflow): self.last_overflow_iter = self.cur_iter else: if self.consecutive_hysteresis: + if dist.get_rank() == 0: + hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}" + logger.info(hysteresis_msg) self.cur_hysteresis = self.delayed_shift - if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + + stable_interval = (self.cur_iter - self.last_overflow_iter) - 1 + if (stable_interval > 0) and (stable_interval % self.scale_window == 0): if not self.consecutive_hysteresis: self.cur_hysteresis = self.delayed_shift self.cur_scale *= self.scale_factor @@ -202,8 +304,7 @@ def update_scale(self, overflow): # we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling. def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args): if dtype == torch.half and dynamic_scaling: - if dynamic_loss_args is None: - return DynamicLossScaler(dtype=dtype) + assert dynamic_loss_args is not None, "Dynamic loss scaling parameters must be defined." return DynamicLossScaler(dtype=dtype, **dynamic_loss_args) loss_scale_value = static_loss_scale if dtype == torch.half else 1.0 diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py index 3854e2d2cd66..fa817573f734 100644 --- a/deepspeed/runtime/fp16/onebit/adam.py +++ b/deepspeed/runtime/fp16/onebit/adam.py @@ -7,6 +7,7 @@ import torch import numpy as np from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version from deepspeed import comm as dist @@ -69,8 +70,6 @@ def __init__(self, super(OnebitAdam, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.comm_time = 0.0 self.step_time = 0.0 self.ave_step = 1 @@ -85,24 +84,27 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) assert ( - (TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2 + required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) - + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) + elif self.comm_backend_name == 'compressed': + from deepspeed.runtime.comm.compressed import CompressedBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size self.divider = int(self.size * 8 / np.gcd(self.size, 8)) diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py index e8a45480701f..54f7fd56abfd 100644 --- a/deepspeed/runtime/fp16/onebit/lamb.py +++ b/deepspeed/runtime/fp16/onebit/lamb.py @@ -7,8 +7,10 @@ import torch import numpy as np from deepspeed import comm as dist +from deepspeed.utils.torch import required_torch_version from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import filter_empty_parameters class OnebitLamb(torch.optim.Optimizer): @@ -81,6 +83,9 @@ def __init__(self, if amsgrad: raise RuntimeError('1-bit Lamb does not support the AMSGrad variant.') + # Filter out empty parameters (numel == 0) to avoid NaN in scaling calculations + filtered_params = filter_empty_parameters(params) + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, @@ -90,10 +95,8 @@ def __init__(self, max_coeff=max_coeff, min_coeff=min_coeff) - super(OnebitLamb, self).__init__(params, defaults) + super(OnebitLamb, self).__init__(filtered_params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.deepspeed = deepspeed self.lamb_freeze_key = False self.initialize = False @@ -107,23 +110,27 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) assert ( - (TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2 + required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) + elif self.comm_backend_name == 'compressed': + from deepspeed.runtime.comm.compressed import CompressedBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size @@ -162,7 +169,7 @@ def step(self, closure=None, grads=None): else: grads_group = grads - #remove the previous stats + # remove the previous stats del self.lamb_coeffs[:] if self.lamb_freeze_key: @@ -174,10 +181,9 @@ def step(self, closure=None, grads=None): # This is used to reduce compression error during compression stage. momentum_scales = [] for group in self.param_groups: - momentum_scales.append([ - (torch.norm(self.state[p]['exp_avg']) / np.sqrt(torch.numel(self.state[p]['exp_avg']))).item() - for p in group['params'] - ]) + momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) / + np.sqrt(torch.numel(self.state[p]['exp_avg']))).item() + for p in group['params']]) united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales]) for i, group in enumerate(self.param_groups): for j, p in enumerate(group['params']): diff --git a/deepspeed/runtime/fp16/onebit/zoadam.py b/deepspeed/runtime/fp16/onebit/zoadam.py index fb2d2a061e38..70282ec41714 100644 --- a/deepspeed/runtime/fp16/onebit/zoadam.py +++ b/deepspeed/runtime/fp16/onebit/zoadam.py @@ -7,13 +7,16 @@ import torch import numpy as np from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version from deepspeed import comm as dist class ZeroOneAdam(torch.optim.Optimizer): - """Implements the 0/1 Adam algorithm. Currently GPU-only. + """ + Implements the 0/1 Adam algorithm. Currently GPU-only. For usage example please see https://www.deepspeed.ai/tutorials/zero-one-adam/ For technical details please read https://arxiv.org/abs/2202.06009 + Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups. @@ -82,8 +85,6 @@ def __init__(self, super(ZeroOneAdam, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.deepspeed = deepspeed self.initialize = False self.cuda_aware = cuda_aware @@ -98,24 +99,27 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) assert ( - (TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2 + required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) - + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) + elif self.comm_backend_name == 'compressed': + from deepspeed.runtime.comm.compressed import CompressedBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size self.divider = int(self.size * 8 / np.gcd(self.size, 8)) diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index 1c57e2048771..bb48b133498f 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -11,10 +11,11 @@ import torch from torch._utils import _flatten_dense_tensors -from deepspeed.runtime import DeepSpeedOptimizer +from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm -from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE +from deepspeed.runtime.fp16.loss_scaler import LossScaleConfig, LossScaleProfile from deepspeed.utils import logger +from deepspeed.utils.torch import required_torch_version from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist @@ -30,6 +31,8 @@ class FP16_UnfusedOptimizer(DeepSpeedOptimizer): def __init__(self, init_optimizer, deepspeed=None, + loss_scale_config=None, + low_precision_dtype=torch.float16, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, @@ -44,8 +47,19 @@ def __init__(self, if dist.get_rank() == 0: logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ') + self.low_precision_dtype = low_precision_dtype + if loss_scale_config is None: + loss_scale_config = LossScaleConfig( + low_precision_dtype=low_precision_dtype, + dynamic_loss_scale=dynamic_loss_scale, + static_loss_scale=static_loss_scale, + dynamic_loss_args=dynamic_loss_args, + profile=LossScaleProfile.UNFUSED, + ) + self.loss_scale_config = loss_scale_config + if not get_accelerator().is_available(): - raise SystemError("Cannot use fp16 without accelerator.") + raise SystemError(f"Cannot use {self.low_precision_dtype} without accelerator.") self.optimizer = init_optimizer # param groups @@ -71,25 +85,6 @@ def __init__(self, self.fp32_groups.append(fp32_group) param_group['params'] = self.fp32_groups[i] - # we may have a way of fusing dynamic scale. Do not support for now - if dynamic_loss_scale: - self.dynamic_loss_scale = True - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = 2.0 - if dynamic_loss_args is None: - self.cur_scale = 1.0 * 2**16 - self.scale_window = 1000 - self.min_loss_scale = 0.25 - else: - self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE] - self.scale_window = dynamic_loss_args[SCALE_WINDOW] - self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE] - else: - self.dynamic_loss_scale = False - self.cur_iter = 0 - self.cur_scale = static_loss_scale - self.custom_loss_scaler = False self.external_loss_scale = None @@ -98,9 +93,7 @@ def __init__(self, self.clip_grad = clip_grad self.norm_type = 2 - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - if TORCH_MAJOR == 0 and TORCH_MINOR <= 4: + if required_torch_version(max_version=0.4): self.clip_grad_norm = torch.nn.utils.clip_grad_norm else: self.clip_grad_norm = torch.nn.utils.clip_grad_norm_ @@ -112,7 +105,7 @@ def __init__(self, self.initialize_optimizer_states() - def zero_grad(self, set_to_none=False): + def zero_grad(self, set_to_none=True): """ Zero FP16 parameter grads. """ @@ -153,13 +146,13 @@ def step_fused_lamb(self, closure=None): expert_norm_groups.append(expert_norm_group_value) self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups) - prev_scale = self.cur_scale + prev_scale = self.loss_scale_config.cur_scale self._update_scale(self.overflow) if self.overflow: if self.verbose: logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " - "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) + "scale: {}, reducing to {}".format(prev_scale, self.loss_scale_config.cur_scale)) return self.overflow self._global_grad_norm = get_global_norm(norm_list=norm_groups) @@ -201,13 +194,13 @@ def step(self, closure=None): return self.step_fused_lamb() self.overflow = self.overflow_checker.check() - prev_scale = self.cur_scale + prev_scale = self.loss_scale_config.cur_scale self._update_scale(self.overflow) if self.overflow: if self.verbose: logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " - "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) + "scale: {}, reducing to {}".format(prev_scale, self.loss_scale_config.cur_scale)) return self.overflow norm_groups = [] @@ -218,7 +211,7 @@ def step(self, closure=None): norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu) norm_groups.append(norm_group_value) - # copying gradients to fp32 to wor k with fp32 parameters + # copying gradients to fp32 to work with fp32 parameters for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]): if fp16_param.grad is None: fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device) @@ -243,12 +236,12 @@ def step(self, closure=None): def unscale_and_clip_grads(self, total_norm, apply_scale=True): # compute combined scale factor for this group - combined_scale = self.cur_scale + combined_scale = self.loss_scale_config.cur_scale if self.clip_grad > 0.: # norm is in fact norm*scale - clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad + clip = ((total_norm / self.loss_scale_config.cur_scale) + 1e-6) / self.clip_grad if clip > 1: - combined_scale = clip * self.cur_scale + combined_scale = clip * self.loss_scale_config.cur_scale if apply_scale: for group in self.fp32_groups: @@ -269,32 +262,37 @@ def backward(self, loss, create_graph=False, retain_graph=False): if self.custom_loss_scaler: scaled_loss = self.external_loss_scale * loss scaled_loss.backward() - else: - scaled_loss = (loss.float()) * self.cur_scale + elif self.loss_scale_config.use_grad_scaling: + scaled_loss = (loss.float()) * self.loss_scale_config.cur_scale scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph) + else: + loss.backward(create_graph=create_graph, retain_graph=retain_graph) def _update_scale(self, skip): - if self.dynamic_loss_scale: - prev_scale = self.cur_scale + if self.loss_scale_config.dynamic_loss_scale: + prev_scale = self.loss_scale_config.cur_scale if skip: - self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale) - self.last_overflow_iter = self.cur_iter + self.loss_scale_config.cur_scale = max( + self.loss_scale_config.cur_scale / self.loss_scale_config.scale_factor, + self.loss_scale_config.min_loss_scale) + self.loss_scale_config.last_overflow_iter = self.loss_scale_config.cur_iter if self.verbose: - logger.info("Grad overflow on iteration: %s", self.cur_iter) - logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}") + logger.info("Grad overflow on iteration: %s", self.loss_scale_config.cur_iter) + logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.loss_scale_config.cur_scale}") else: - # Ensure self.scale_window updates since last overflow - stable_interval = (self.cur_iter - self.last_overflow_iter) - 1 - if (stable_interval > 0) and (stable_interval % self.scale_window == 0): - self.cur_scale *= self.scale_factor + # Ensure self.loss_scale_config.scale_window updates since last overflow + stable_interval = (self.loss_scale_config.cur_iter - self.loss_scale_config.last_overflow_iter) - 1 + if (stable_interval > 0) and (stable_interval % self.loss_scale_config.scale_window == 0): + self.loss_scale_config.cur_scale *= self.loss_scale_config.scale_factor if self.verbose: - logger.info(f"No Grad overflow for {self.scale_window} iterations") - logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}") + logger.info(f"No Grad overflow for {self.loss_scale_config.scale_window} iterations") + logger.info( + f"Increasing dynamic loss scale from {prev_scale} to {self.loss_scale_config.cur_scale}") else: if skip: - logger.info("Grad overflow on iteration %s", self.cur_iter) - logger.info("Using static loss scale of %s", self.cur_scale) - self.cur_iter += 1 + logger.info("Grad overflow on iteration %s", self.loss_scale_config.cur_iter) + logger.info("Using static loss scale of %s", self.loss_scale_config.cur_scale) + self.loss_scale_config.cur_iter += 1 return # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" @@ -321,10 +319,10 @@ def _get_loss_scale(self): if self.custom_loss_scaler: return self.external_loss_scale else: - return self.cur_scale + return self.loss_scale_config.cur_scale def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value + self.loss_scale_config.cur_scale = value loss_scale = property(_get_loss_scale, _set_loss_scale) @@ -340,13 +338,13 @@ def state_dict(self): torch.save(checkpoint, "saved.pth") """ state_dict = {} - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['cur_scale'] = self.cur_scale - state_dict['cur_iter'] = self.cur_iter + state_dict['dynamic_loss_scale'] = self.loss_scale_config.dynamic_loss_scale + state_dict['cur_scale'] = self.loss_scale_config.cur_scale + state_dict['cur_iter'] = self.loss_scale_config.cur_iter if state_dict['dynamic_loss_scale']: - state_dict['last_overflow_iter'] = self.last_overflow_iter - state_dict['scale_factor'] = self.scale_factor - state_dict['scale_window'] = self.scale_window + state_dict['last_overflow_iter'] = self.loss_scale_config.last_overflow_iter + state_dict['scale_factor'] = self.loss_scale_config.scale_factor + state_dict['scale_window'] = self.loss_scale_config.scale_window state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict() state_dict['fp32_groups'] = self.fp32_groups return state_dict @@ -374,13 +372,13 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): optimizer.load_state_dict(checkpoint['optimizer']) """ # I think it should actually be ok to reload the optimizer before the model. - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.cur_scale = state_dict['cur_scale'] - self.cur_iter = state_dict['cur_iter'] + self.loss_scale_config.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.loss_scale_config.cur_scale = state_dict['cur_scale'] + self.loss_scale_config.cur_iter = state_dict['cur_iter'] if state_dict['dynamic_loss_scale']: - self.last_overflow_iter = state_dict['last_overflow_iter'] - self.scale_factor = state_dict['scale_factor'] - self.scale_window = state_dict['scale_window'] + self.loss_scale_config.last_overflow_iter = state_dict['last_overflow_iter'] + self.loss_scale_config.scale_factor = state_dict['scale_factor'] + self.loss_scale_config.scale_window = state_dict['scale_window'] if load_optimizer_states: self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT]) diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py index 3d7538ac4f3b..de8c34fa039f 100644 --- a/deepspeed/runtime/hybrid_engine.py +++ b/deepspeed/runtime/hybrid_engine.py @@ -12,21 +12,19 @@ from deepspeed.runtime.zero import GatheredParameters import time import gc - +import math from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator from torch import nn from deepspeed.utils import logger - -from deepspeed.ops.op_builder import InferenceBuilder - from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding +from ..ops.transformer.inference.op_binding.workspace import WorkspaceOp + try: import transformers OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding -except: +except Exception: OPTLearnedPositionalEmbedding = None -inference_cuda_module = None class DeepSpeedHybridEngine(DeepSpeedEngine): @@ -61,12 +59,8 @@ def __init__(self, args, model, **kwargs): self._total_batch_size = None self._gather_latency = 0 - global inference_cuda_module - if inference_cuda_module is None: - builder = InferenceBuilder() - inference_cuda_module = builder.load() - self.is_lora_fused = False + self.workspace = WorkspaceOp() def convert_to_linear_transposed(self, model): @@ -83,18 +77,36 @@ def _replace_linear_layer(r_module, parent_type=None, prev_type=None): def new_inference_container(self, orig_layer, policy_cls, layer_id): policy = policy_cls(orig_layer, inference=True) + + if self._config.float16_config.enabled: + inference_dtype = torch.float16 + elif self._config.bfloat16_config.enabled: + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + _container = policy_to_ds_container( policy=policy, - config=DeepSpeedInferenceConfig(set_empty_params=True, - max_out_tokens=self._config.hybrid_engine.max_out_tokens, - min_out_tokens=self._config.hybrid_engine.max_out_tokens, - transposed_mode=True), + config=DeepSpeedInferenceConfig( + set_empty_params=True, + dtype=inference_dtype, + max_out_tokens=self._config.hybrid_engine.max_out_tokens, + min_out_tokens=self._config.hybrid_engine.max_out_tokens, + transposed_mode=True, + ), model_config=self.module.config if hasattr(self.module, 'config') else None, layer_id=layer_id, child=orig_layer) - _container.set_dtype(self._config.fp16_enabled) - _container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group) + if self.mpu is not None: + if hasattr(self.mpu, 'get_model_parallel_world_size'): + _container.set_tensor_parallel_config(self.mpu.get_model_parallel_world_size(), + self.mpu.get_model_parallel_group()) + else: + _container.set_tensor_parallel_config(self.mpu.get_tensor_model_parallel_world_size(), + self.mpu.get_tensor_model_parallel_group()) + else: + _container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group) _container.initialize_tensors(enable_training=True) _container.create_ds_model_config() _container.create_module() @@ -117,31 +129,19 @@ def populate_all_inference_policies(self): OPTLearnedPositionalEmbedding: (OPTEmbedding, ) }) - def _fuse_lora(self, params, lora_params): - maybe_has_lora_params = [p for p in params if len(p.shape) > 1] - for lora_param, weight in zip(lora_params, maybe_has_lora_params): - if len(lora_params) > 0: - lora_right_weight, \ - lora_left_weight, \ - lora_scaling = lora_param - weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t()) + def _fuse_lora_layer(self, layer_id): + self._inference_containers[layer_id].fuse_lora() def fuse_lora_weight(self): for layer_id in range(len(self.layer_params)): - self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) + self._fuse_lora_layer(layer_id) - def _unfuse_lora(self, params, lora_params): - maybe_has_lora_params = [p for p in params if len(p.shape) > 1] - for lora_param, weight in zip(lora_params, maybe_has_lora_params): - if len(lora_params) > 0: - lora_right_weight, \ - lora_left_weight, \ - lora_scaling = lora_param - weight.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t()) + def _unfuse_lora_layer(self, layer_id): + self._inference_containers[layer_id].unfuse_lora() def unfuse_lora_weight(self): for layer_id in range(len(self.layer_params)): - self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) + self._unfuse_lora_layer(layer_id) def unfuse_lora_weight_non_pinned(self): for layer_id in range(len(self.layer_params)): @@ -150,17 +150,17 @@ def unfuse_lora_weight_non_pinned(self): non_active_params.extend(non_active_lora_params) with GatheredParameters(non_active_params): - self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) + self._unfuse_lora_layer(layer_id) def retake_inference_cache(self): if self._config.hybrid_engine.release_inference_cache: - retake_success = inference_cuda_module.retake_workspace() + retake_success = self.workspace.retake_workspace() if not retake_success: - logger.warning("Unable to acquire workspace on first attempt, emtpying cache and retrying.") + logger.warning("Unable to acquire workspace on first attempt, emptying cache and retrying.") gc.collect() get_accelerator().empty_cache() - retake_success = inference_cuda_module.retake_workspace() + retake_success = self.workspace.retake_workspace() if not retake_success: raise RuntimeError("Unable to retake inference workspace.") @@ -181,7 +181,7 @@ def generate(self, *inputs, **kwargs): partition_size = self._config.hybrid_engine.tp_gather_partition_size - layer_groups = len(self.layer_params) // partition_size + layer_groups = math.ceil(len(self.layer_params) / partition_size) for lg in range(layer_groups): non_active_params = [] non_active_lora_params = [] @@ -194,9 +194,11 @@ def generate(self, *inputs, **kwargs): for layer_id in range(lg * partition_size, min(len(self.layer_params), (lg + 1) * partition_size), 1): if len(self.all_lora_params) > 0: - self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) - self._inference_containers[layer_id].apply_tensor_parallelism( - mp_group=self.mp_group, tp_size=self._config.hybrid_engine.inference_tp_size) + self._fuse_lora_layer(layer_id) + + if self.mpu is not None: + self._inference_containers[layer_id].apply_tensor_parallelism(self.mp_replace, + reversed_dim=True) # TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache # is enabled. @@ -215,7 +217,7 @@ def generate(self, *inputs, **kwargs): dist.all_gather_into_tensor(output, input_cont, group=self.mp_group) if len(inputs) > 0: - inputs = (output, ) + inputs = (output, *inputs[1:]) else: kwargs['input_ids'] = output @@ -261,7 +263,7 @@ def generate(self, *inputs, **kwargs): self.is_lora_fused = False if self._config.hybrid_engine.release_inference_cache: - inference_cuda_module.release_workspace() + self.workspace.release_workspace() gc.collect() get_accelerator().empty_cache() @@ -288,8 +290,13 @@ def create_inference_containers(self, module, layer_id=0): layer_id += 1 else: - self._other_layers.append(self.inference_policies[child.__class__][0]( - weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None)) + if self.inference_policies[child.__class__][0] == LinearLayer: + self._other_layers.append(self.inference_policies[child.__class__][0](module=child, + mp_group=None, + skip_partition=True)) + else: + self._other_layers.append(self.inference_policies[child.__class__][0]( + weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None)) self._orig_modules_others.append(child) self._orig_fwds_others.append(child.forward) else: @@ -306,27 +313,49 @@ def create_inference_module(self): self._orig_fwds_others = [] if self._config.hybrid_engine.inference_tp_size > 1: - global_rank = dist.get_rank() - world_size = dist.get_world_size() - mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size - num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size - for mp_group_id in range(num_mp_groups): - ranks = list( - range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \ - (mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \ - 1) - ) - mp_group = dist.new_group(ranks) - if global_rank in ranks: - self.mp_group = mp_group + if self.mpu is None: + global_rank = dist.get_rank() + world_size = dist.get_world_size() + mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size + num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size + for mp_group_id in range(num_mp_groups): + ranks = list( + range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \ + (mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \ + 1) + ) + mp_group = dist.new_group(ranks) + if global_rank in ranks: + # mp_group is used for broader collective + self.mp_group = mp_group + + # mp_replace is used for container tensor slicing + from deepspeed.module_inject import ReplaceWithTensorSlicing + self.mp_replace = ReplaceWithTensorSlicing( + mp_group=self.mp_group, + mp_size=self._config.hybrid_engine.inference_tp_size, + out_dim=0, + in_dim=1) + + else: + self.mp_group = self.mpu.get_model_parallel_group() if hasattr(self.mpu, 'get_model_parallel_group') else \ + self.mpu.get_tensor_model_parallel_group() + + from deepspeed.module_inject import ReplaceWithTensorSlicing + self.mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, + mp_size=self._config.hybrid_engine.inference_tp_size, + out_dim=0, + in_dim=1) else: self.mp_group = None + self.mp_replace = None self.populate_all_inference_policies() self.all_layers_params = list(self.module.parameters()) self.create_inference_containers(self.module) - self._generate = self.module.generate - self.module.generate = self.generate + if len(self._inference_containers) > 0: + self._generate = self.module.generate + self.module.generate = self.generate self._t0 = time.time() @@ -341,7 +370,7 @@ def run_forward(*inputs, **kwargs): if len(self.all_lora_params) > 0: # Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory if not self.is_lora_fused: - self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) + self._fuse_lora_layer(layer_id) # Set the is_lora_fused to true when reaching the last layer if layer_id == len(self.layer_params) - 1: self.is_lora_fused = True @@ -355,14 +384,20 @@ def eval(self): self._total_latency = self._total_latency + latency self._iters = self._iters + 1 if not dist.is_initialized() or dist.get_rank() == 0: + if self._total_batch_size is not None: + cur_samples_p_sec = f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + avg_samples_p_sec = f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}' + else: + cur_samples_p_sec = '' + avg_samples_p_sec = '' others = latency - (self._generate_latency + self._training_latency) print(f'|E2E latency={(latency):.2f}s ' + \ f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) ' f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \ f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \ - f'|Others={others:.2f} ({(others / latency * 100):.2f}%)' - f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + \ - f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}') + f'|Others={others:.2f} ({(others / latency * 100):.2f}%)' + \ + cur_samples_p_sec + \ + avg_samples_p_sec) self._t_start = time.time() self._training_latency = 0 super().eval() @@ -374,6 +409,8 @@ def eval(self): else: orig_module.forward = inference_container.module.forward + inference_container.transform_for_inference() + if not self.Z3_enabled or self.gather_all_layers: for orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers): orig_module.forward = inference_layer.forward @@ -385,7 +422,9 @@ def eval(self): def train(self, mode=True): if mode and len(self._orig_modules) > 0: - for orig_module, orig_fwd in zip(self._orig_modules, self._orig_fwds): + for inference_container, orig_module, orig_fwd in zip(self._inference_containers, self._orig_modules, + self._orig_fwds): + inference_container.transform_for_training() orig_module.forward = orig_fwd for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others): orig_module.forward = orig_fwd @@ -395,10 +434,12 @@ def train(self, mode=True): def step(self, lr_kwargs=None): super().step(lr_kwargs=lr_kwargs) - if(self._inference_containers[0].module.attention.attn_qkvw is not None and \ - self._inference_containers[0].q_k_v is not None): - for inference_container in self._inference_containers: - inference_container.reset_qkv() + + if len(self._inference_containers) > 0: + if not self.Z3_enabled: + for inference_container in self._inference_containers: + inference_container.reset_params() + if self._training_start_time is not None: self._training_latency += (time.time() - self._training_start_time) self._training_start_time = time.time() diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index d2bd93d8ee31..f9a1fa7ad162 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -13,13 +13,15 @@ from torch.optim import Optimizer import math from deepspeed.utils import logger +from torch import tensor, is_tensor LR_SCHEDULE = 'lr_schedule' LR_RANGE_TEST = 'LRRangeTest' ONE_CYCLE = 'OneCycle' WARMUP_LR = 'WarmupLR' WARMUP_DECAY_LR = 'WarmupDecayLR' -VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR] +WARMUP_COSINE_LR = 'WarmupCosineLR' +VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR, WARMUP_COSINE_LR] LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr' LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate' @@ -50,6 +52,9 @@ WARMUP_LOG_RATE = 'log' WARMUP_LINEAR_RATE = 'linear' +WARMUP_MIN_RATIO = 'warmup_min_ratio' +COS_MIN_RATIO = 'cos_min_ratio' + TOTAL_NUM_STEPS = 'total_num_steps' @@ -109,6 +114,11 @@ def add_tuning_arguments(parser): type=str, default=WARMUP_LOG_RATE, help='WarmupLR increasing function during warmup') + + # WarmUP cos LR + group.add_argument("--warmup_min_ratio", type=float, default=0.01, help='Cosine LR lower bound.') + group.add_argument("--cos_min_ratio", type=float, default=0.01, help='Cosine LR lower bound.') + return parser @@ -200,7 +210,7 @@ def get_config_from_args(args): if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None: return None, '--{} not specified on command line'.format(LR_SCHEDULE) - if not args.lr_schedule in VALID_LR_SCHEDULES: + if args.lr_schedule not in VALID_LR_SCHEDULES: return None, '{} is not supported LR schedule'.format(args.lr_schedule) config = {} @@ -218,16 +228,16 @@ def get_config_from_args(args): def get_lr_from_config(config): - if not 'type' in config: + if 'type' not in config: return None, 'LR schedule type not defined in config' - if not 'params' in config: + if 'params' not in config: return None, 'LR schedule params not defined in config' lr_schedule = config['type'] lr_params = config['params'] - if not lr_schedule in VALID_LR_SCHEDULES: + if lr_schedule not in VALID_LR_SCHEDULES: return None, '{} is not a valid LR schedule'.format(lr_schedule) if lr_schedule == LR_RANGE_TEST: @@ -238,6 +248,15 @@ def get_lr_from_config(config): return lr_params[WARMUP_MAX_LR], '' +def update_lr(param_groups, lrs): + for param_group, lr in zip(param_groups, lrs): + # new LR should match the type of current LR for scalar and Tensor LR support + if is_tensor(param_group['lr']): + lr = tensor([lr], device=param_group['lr'].device) + param_group['lr'] = lr + return [group['lr'] for group in param_groups] + + """ Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped optimizer to see if requirement is satisfied. @@ -259,7 +278,7 @@ class LRRangeTest(object): """Sets the learning rate of each parameter group according to learning rate range test (LRRT) policy. The policy increases learning rate starting from a base value with a constant frequency, as detailed in - the paper `A disciplined approach to neural network hyper-parameters: Part1`_. + the paper `A disciplined approach to neural network hyper-parameters: Part 1 `_ LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to configure the LR boundaries for Cyclic LR schedules. @@ -319,7 +338,7 @@ def __init__(self, self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval if last_batch_iteration == -1: - self._update_optimizer(self.min_lr) + self._last_lr = update_lr(self.optimizer.param_groups, self.min_lr) def _staircase_interval(self): return math.floor(float(self.last_batch_iteration + 1) / self.step_size) @@ -340,16 +359,11 @@ def get_last_lr(self): assert getattr(self, '_last_lr', None) is not None, "need to call step() first" return self._last_lr - def _update_optimizer(self, group_lrs): - for param_group, lr in zip(self.optimizer.param_groups, group_lrs): - param_group['lr'] = lr - def step(self, batch_iteration=None): if batch_iteration is None: batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration - self._update_optimizer(self.get_lr()) - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -369,7 +383,7 @@ class OneCycle(object): 1CLR policy changes the learning rate after every batch. `step` should be called after a batch has been used for training. - This implementation was adapted from the github repo: `pytorch/pytorch`_ + This implementation was adapted from the github repo: `PyTorch `_. Args: optimizer (Optimizer): Wrapped optimizer. @@ -457,7 +471,6 @@ def __init__(self, if cycle_momentum: self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration) - # Initialize batch iteration tracker self.last_batch_iteration = last_batch_iteration @@ -499,7 +512,7 @@ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, l def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration): if 'betas' not in optimizer.defaults: optimizer_name = type(optimizer).__name__ - logger.warn( + logger.warning( f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults" ) self.cycle_momentum = False @@ -607,9 +620,7 @@ def step(self, batch_iteration=None): batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration - for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): - param_group['lr'] = lr - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) if self.cycle_momentum: momentums = self.get_mom() @@ -648,13 +659,16 @@ class WarmupLR(object): def __init__(self, optimizer: Optimizer, warmup_min_lr: float = 0.0, - warmup_max_lr: float = 0.001, + warmup_max_lr: float = None, warmup_num_steps: int = 1000, warmup_type: str = WARMUP_LOG_RATE, last_batch_iteration: int = -1): self.optimizer = get_torch_optimizer(optimizer) + if warmup_max_lr is None: + warmup_max_lr = [group['lr'] for group in self.optimizer.param_groups][0] + self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr") self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr") self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)] @@ -667,11 +681,14 @@ def __init__(self, self.warmup_type = warmup_type self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps) self.last_batch_iteration = last_batch_iteration + # Initialize lr in optimizer + if last_batch_iteration == -1: + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) def get_lr(self): if self.last_batch_iteration < 0: logger.warning("Attempting to get learning rate from scheduler before it has started") - return [0.0] + return self.min_lrs gamma = self._get_gamma() return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)] @@ -685,9 +702,7 @@ def step(self, last_batch_iteration=None): if last_batch_iteration is None: last_batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = last_batch_iteration - for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): - param_group['lr'] = lr - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -761,3 +776,110 @@ def _get_gamma(self): 0.0, float(self.total_num_steps - self.last_batch_iteration) / float(max(1.0, self.total_num_steps - self.warmup_num_steps))) + + +class WarmupCosineLR(object): + """Increase the learning rate of each parameter group from min lr ratio to max lr ratio + over warmup_num_steps steps, and then decay at cosine rate over the remaining training steps to min cosine ratio. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_num_steps (int): total number of training steps + warmup_min_ratio (float or list): warmup start learning rate ratio. Default: 0 + warmup_num_steps (int): number of steps to warm up from warmup_min_ratio to 1.0. Default: 1000 + warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log + cos_min_ratio (float): cosine end learning rate ratio. Default: 0.0001 + last_batch_iteration (int): The index of the last batch. Default: -1. + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = WarmupCosineLR(optimizer, 1000000) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + """ + + def __init__(self, + optimizer: Optimizer, + total_num_steps: int, + warmup_min_ratio: float = 0.0, + warmup_num_steps: int = 1000, + cos_min_ratio: float = 0.0001, + warmup_type: str = WARMUP_LOG_RATE, + last_batch_iteration: int = -1): + + self.optimizer = get_torch_optimizer(optimizer) + + self.total_num_steps = total_num_steps + self.last_batch_iteration = last_batch_iteration + self.cos_min_ratio = cos_min_ratio + + self.warmup_type = warmup_type + self.warmup_min_ratio = warmup_min_ratio + self.warmup_num_steps = max(2, warmup_num_steps) + self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps) + + if self.total_num_steps < self.warmup_num_steps: + logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format( + total_num_steps, warmup_num_steps)) + self.org_lrs = [group['lr'] for group in self.optimizer.param_groups] + + # Initialize lrs in optimizer groups + if last_batch_iteration == -1: + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) + + def get_lr_ratio(self): + if self.last_batch_iteration < 0: + logger.warning("Attempting to get learning rate from scheduler before it has started") + return 0.0 + + if self.last_batch_iteration < self.warmup_num_steps: + if self.warmup_type == WARMUP_LOG_RATE: + ratio = self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1) + elif self.warmup_type == WARMUP_LINEAR_RATE: + ratio = self.last_batch_iteration / self.warmup_num_steps + ratio_delta = 1. - self.warmup_min_ratio + ratio = self.warmup_min_ratio + ratio * ratio_delta + return ratio + + real_last_step = self.last_batch_iteration - self.warmup_num_steps + 1 + real_total_steps = self.total_num_steps - self.warmup_num_steps + ratio_delta = 1. - self.cos_min_ratio + ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2 + ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio) + return ratio + + def step(self, last_batch_iteration=None): + if last_batch_iteration is None: + last_batch_iteration = self.last_batch_iteration + 1 + self.last_batch_iteration = last_batch_iteration + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) + + def get_lr(self): + if self.last_batch_iteration < 0: + logger.warning("Attempting to get learning rate from scheduler before it has started") + return [0.0 for _ in self.org_lrs] + lr_ratio = self.get_lr_ratio() + return [org_lr * lr_ratio for org_lr in self.org_lrs] + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + + def state_dict(self): + return {'last_batch_iteration': self.last_batch_iteration} + + def load_state_dict(self, sd): + self.last_batch_iteration = sd['last_batch_iteration'] + + def _format_param(self, optimizer, param_value, param_name): + if isinstance(param_value, list) or isinstance(param_value, tuple): + if len(param_value) != len(optimizer.param_groups): + raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name, + FileNotFoundError(param_value))) + return list(param_value) + return [param_value] * len(optimizer.param_groups) diff --git a/deepspeed/runtime/model_checkpointing/__init__.py b/deepspeed/runtime/model_checkpointing/__init__.py new file mode 100644 index 000000000000..5e60b03ac671 --- /dev/null +++ b/deepspeed/runtime/model_checkpointing/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .constants import * +from .writer_factory import CheckpointWriterFactory diff --git a/deepspeed/runtime/model_checkpointing/config.py b/deepspeed/runtime/model_checkpointing/config.py new file mode 100644 index 000000000000..d5a579fe31f0 --- /dev/null +++ b/deepspeed/runtime/model_checkpointing/config.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.config_utils import get_scalar_param +from .constants import * + +VALID_VALUES = { + CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_MODES, + CHECKPOINT_WRITER_TYPE: CHECKPOINT_WRITER_TYPES, + CHECKPOINT_DATA_PARALLEL: CHECKPOINT_DATA_PARALLEL_UNITS +} + +CHECKPOINT_DEFAULT_DICT = { + CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_DEFAULT, + CHECKPOINT_SERIALIZATION: CHECKPOINT_SERIALIZATION_DEFAULT, + CHECKPOINT_WRITER: CHECKPOINT_WRITER_DEFAULT +} + + +def _validate_config_values(config_name, config_dict, valid_values): + for key, value in config_dict.items(): + if value is None: + continue + if key in valid_values.keys(): + assert value in valid_values[key], \ + f"{config_name} contains invalid value {value} for {key}, expecting one of {valid_values[key]}" + + +def _make_upper_case(value): + return value if value is None else value.upper() + + +def get_checkpoint_writer_config(param_dict): + writer_dict = param_dict.get(CHECKPOINT_WRITER, None) + if writer_dict is None: + return CHECKPOINT_WRITER_DEFAULT + + writer_config = { + CHECKPOINT_WRITER_TYPE: + _make_upper_case(get_scalar_param(writer_dict, CHECKPOINT_WRITER_TYPE, CHECKPOINT_WRITER_TYPE_DEFAULT)), + CHECKPOINT_IO_BUFFER_SIZE: + get_scalar_param(writer_dict, CHECKPOINT_IO_BUFFER_SIZE, CHECKPOINT_IO_BUFFER_SIZE_DEFAULT), + CHECKPOINT_IO_BUFFER_DOUBLE: + get_scalar_param(writer_dict, CHECKPOINT_IO_BUFFER_DOUBLE, CHECKPOINT_IO_BUFFER_DOUBLE_DEFAULT), + CHECKPOINT_IO_STATISTICS: + get_scalar_param(writer_dict, CHECKPOINT_IO_STATISTICS, CHECKPOINT_IO_STATISTICS_DEFAULT), + CHECKPOINT_DATA_PARALLEL: + _make_upper_case(get_scalar_param(writer_dict, CHECKPOINT_DATA_PARALLEL, CHECKPOINT_DATA_PARALLEL_DEFAULT)), + CHECKPOINT_WRITER_DECOUPLED: + get_scalar_param(writer_dict, CHECKPOINT_WRITER_DECOUPLED, CHECKPOINT_WRITER_DECOUPLED_DEFAULT), + CHECKPOINT_IO_MULTIPLIER: + get_scalar_param(writer_dict, CHECKPOINT_IO_MULTIPLIER, CHECKPOINT_IO_MULTIPLIER_DEFAULT), + } + _validate_config_values(CHECKPOINT_WRITER, writer_config, VALID_VALUES) + + return writer_config + + +def get_checkpoint_config(param_dict): + checkpoint_dict = param_dict.get(CHECKPOINT, None) + if checkpoint_dict is None: + return CHECKPOINT_DEFAULT_DICT + + checkpoint_config = { + CHECKPOINT_TAG_VALIDATION: + get_scalar_param(checkpoint_dict, CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT).upper(), + CHECKPOINT_SERIALIZATION: + get_scalar_param(checkpoint_dict, CHECKPOINT_SERIALIZATION, CHECKPOINT_SERIALIZATION_DEFAULT), + CHECKPOINT_WRITER: + get_checkpoint_writer_config(checkpoint_dict) + } + + _validate_config_values(CHECKPOINT, checkpoint_config, VALID_VALUES) + + return checkpoint_config diff --git a/deepspeed/runtime/model_checkpointing/constants.py b/deepspeed/runtime/model_checkpointing/constants.py new file mode 100644 index 000000000000..3b9bd549af92 --- /dev/null +++ b/deepspeed/runtime/model_checkpointing/constants.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + + +######################################### +# Validation modes +######################################### +class ValidationMode: + WARN = "WARN" + IGNORE = "IGNORE" + FAIL = "FAIL" + + +######################################### +# Checkpoint config params +######################################### +# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]} +CHECKPOINT_FORMAT = ''' +"checkpoint": { + "tag_validation": [Ignore|Warn|Fail], + "checkpoint_serialization": False, + "writer": { + "type": [mock|python|fast], + "decoupled": [True|False] + "io_buffer_size": 64e6, + "io_buffer_double": True, + "show_statistics": False, + "data_parallel": [replica|socket|machine], + "io_multiplier": 1, + } +} +''' +CHECKPOINT = "checkpoint" +CHECKPOINT_TAG_VALIDATION = "tag_validation" +CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN +CHECKPOINT_TAG_VALIDATION_MODES = [ValidationMode.WARN, ValidationMode.IGNORE, ValidationMode.FAIL] + +CHECKPOINT_SERIALIZATION = "checkpoint_serialization" +CHECKPOINT_SERIALIZATION_DEFAULT = True + +CHECKPOINT_WRITER = "writer" +CHECKPOINT_WRITER_DEFAULT = None + +CHECKPOINT_WRITER_TYPE = "type" + + +class CheckpointWriterType: + MOCK = "MOCK" + PYTHON = "PYTHON" + FAST = "FAST" + + +CHECKPOINT_WRITER_TYPE_DEFAULT = CheckpointWriterType.FAST +CHECKPOINT_WRITER_TYPES = [CheckpointWriterType.MOCK, CheckpointWriterType.PYTHON, CheckpointWriterType.FAST] + +CHECKPOINT_IO_BUFFER_SIZE = "io_buffer_size" +CHECKPOINT_IO_BUFFER_SIZE_DEFAULT = 64 * (1024**2) + +CHECKPOINT_IO_BUFFER_DOUBLE = "io_buffer_double" +CHECKPOINT_IO_BUFFER_DOUBLE_DEFAULT = True + +CHECKPOINT_IO_MULTIPLIER = "io_multiplier" +CHECKPOINT_IO_MULTIPLIER_DEFAULT = 1 + +CHECKPOINT_IO_STATISTICS = "show_statistics" +CHECKPOINT_IO_STATISTICS_DEFAULT = False + +CHECKPOINT_DATA_PARALLEL = "data_parallel" +CHECKPOINT_DATA_PARALLEL_DEFAULT = None + + +class CheckpointDataParallel: + REPLICA = "REPLICA" + SOCKET = "SOCKET" + MACHINE = "MACHINE" + + +CHECKPOINT_DATA_PARALLEL_UNITS = [ + CheckpointDataParallel.REPLICA, CheckpointDataParallel.SOCKET, CheckpointDataParallel.MACHINE +] + +CHECKPOINT_WRITER_DECOUPLED = "decoupled" +CHECKPOINT_WRITER_DECOUPLED_DEFAULT = False diff --git a/deepspeed/runtime/model_checkpointing/data_parallel_writer_factory.py b/deepspeed/runtime/model_checkpointing/data_parallel_writer_factory.py new file mode 100644 index 000000000000..0a9740a381ab --- /dev/null +++ b/deepspeed/runtime/model_checkpointing/data_parallel_writer_factory.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass +from deepspeed.checkpoint.reshape_utils import partition_data +from deepspeed.runtime.zero.config import ZeroStageEnum +from .constants import * + + +@dataclass +class DataParallelWriterConfig(object): + world_size: int + rank: int + global_rank: int + local_rank: int + pure_dp: bool + + +class DataParallelWriterFactory(object): + + def __init__(self, uni_parallel_info, parallel_unit): + self._uni_parallel_info = uni_parallel_info + self._parallel_unit = parallel_unit + if parallel_unit == CheckpointDataParallel.SOCKET: + self._num_resources = uni_parallel_info.num_sockets + else: + self._num_resources = uni_parallel_info.num_machines + self._ranks_per_resource = max(1, self._uni_parallel_info.global_world_size // self._num_resources) + + def create_config(self, zero_stage, has_moe_layers): + if zero_stage == ZeroStageEnum.weights: + return self._create_config(1, 0) + + if has_moe_layers: + writer_config = self._get_expert_data_parallel_config() + else: + writer_config = self._get_data_parallel_config() + + if writer_config is None and zero_stage >= ZeroStageEnum.optimizer_states: + return self._create_config(1, 0) + + return writer_config + + def _create_config(self, world_size, rank): + return DataParallelWriterConfig(world_size=world_size, + rank=rank, + global_rank=self._uni_parallel_info.global_rank, + local_rank=self._uni_parallel_info.local_rank, + pure_dp=self._uni_parallel_info.pure_dp) + + def _get_expert_data_parallel_config(self): + ep_info = self._uni_parallel_info.ep_info + if self._parallel_unit is None: + dp_rank = ep_info.dp_rank + return self._create_config(1, 0) if dp_rank == 0 else None + + assert self._uni_parallel_info.pure_dp, \ + '3D parallelism is not yet supported for data parallel checkpointing.' + + if self._parallel_unit == CheckpointDataParallel.REPLICA or ep_info.ep_world_size == 1: + return self._get_parallel_write_for_ddp(ep_info.dp_world_size, ep_info.dp_rank) + + return self._get_expert_parallel_write_for_2d() + + def _get_expert_parallel_write_for_2d(self): + ep_info = self._uni_parallel_info.ep_info + + def _get_expert_slice_resources(expert_resources, resource_name): + ep_world_size = ep_info.ep_world_size + slices_per_resource = min(self._ranks_per_resource, ep_world_size) + assert slices_per_resource <= len(expert_resources) + + ep_num_resources = len(expert_resources) + assert ep_num_resources % slices_per_resource == 0, f'{resource_name}: Expected ep_num_resources={ep_num_resources} to multiple of slices_per_resource={slices_per_resource} for ep_world_size={ep_world_size}' + + slice_partitions = partition_data(expert_resources, slices_per_resource) + # print( + # f'edp_resource_partition: self._uni_parallel_info.global_rank={self._uni_parallel_info.global_rank} expert_resources={expert_resources} slices_per_resource={slices_per_resource} ep_world_size={ep_world_size} slice_partitions={slice_partitions}' + # ) + resource_index = ep_info.ep_rank % slice_resources + return slice_partitions[resource_index] + + dp_ranks = ep_info.dp_peer_ranks + expert_resources = [r // self._ranks_per_resource for r in dp_ranks] + slice_resources = _get_expert_slice_resources(expert_resources, self._parallel_unit) + assert all([idx < self._num_resources for idx in expert_resources]), \ + f'Detected invalid resource index in expert_resources={expert_resources}, self._num_resources={self._num_resources}' + return self._assign_resources_to_tensor_slice(slice_resources, ep_info.ep_rank, dp_ranks) + + def _get_data_parallel_config(self): + mpu_info = self._uni_parallel_info.mpu_info + if self._parallel_unit is None: + dp_rank = self._uni_parallel_info.dp_rank if mpu_info is None else mpu_info.dp_rank + return self._create_config(1, 0) if dp_rank == 0 else None + + if self._uni_parallel_info.pure_dp: + return self._get_parallel_write_for_ddp(self._uni_parallel_info.global_world_size, + self._uni_parallel_info.global_rank) + + if self._parallel_unit == CheckpointDataParallel.REPLICA: + return self._create_config(mpu_info.dp_world_size, mpu_info.dp_rank) + + return self._get_parallel_write_for_3d() + + def _get_parallel_write_for_3d(self): + mpu_info = self._uni_parallel_info.mpu_info + my_global_rank = self._uni_parallel_info.global_rank + + def _expand_resources(resource_list, new_size): + old_size = len(resource_list) + if old_size >= new_size: + return resource_list + + assert new_size % old_size == 0, f'Expect new_size={new_size} to be multiple of old_size={old_size}' + multiplier = new_size // old_size + new_resource_list = [] + for r in resource_list: + new_resource_list += [r] * multiplier + # print(f'expand_resources: {my_global_rank=} {old_size=} {new_size=} {resource_list=} {new_resource_list=}') + return new_resource_list + + # Getting resource partition for a tensor slice is a 2-step process + # 1. Get resource partitions for all pipeline stages. A pipeline stage is a 2D grid of size TP x DP + def _get_pipeline_stage_resources(resource_indices): + num_resources = len(resource_indices) + pp_world_size = mpu_info.pp_world_size + if num_resources < pp_world_size: + resource_indices = _expand_resources(resource_indices, pp_world_size) + num_resources = pp_world_size + global_resource_partitions = partition_data(resource_indices, pp_world_size) + pp_rank = mpu_info.pp_rank + return global_resource_partitions[pp_rank] + + # 2. Get resource partition for tensor slice. A tensor slice is a 1D vector of size DP + def _get_tensor_slice_resources(resource_indices, resource_name): + pipe_stage_resources = _get_pipeline_stage_resources(resource_indices) + tp_world_size = mpu_info.tp_world_size + if len(pipe_stage_resources) < tp_world_size: + pipe_stage_resources = _expand_resources(pipe_stage_resources, tp_world_size) + tp_num_resources = len(pipe_stage_resources) + assert tp_num_resources % tp_world_size == 0, \ + f'{resource_name}: Expected tp_num_resources={tp_num_resources} to multiple of tp_world_size={tp_world_size}' + + pipe_stage_resource_partitions = partition_data(pipe_stage_resources, tp_world_size) + tp_rank = mpu_info.tp_rank + return pipe_stage_resource_partitions[tp_rank] + + def _get_model_parallel_slice_resources(): + # Get resources of my dp peer ranks + resources = [(r // self._ranks_per_resource) for r in mpu_info.dp_peer_ranks] + if len(resources) < self._ranks_per_resource: + resources = _expand_resources(resources, self._ranks_per_resource) + + resource_partitions = partition_data(resources, self._ranks_per_resource) + mp_rank = (mpu_info.pp_rank * mpu_info.tp_world_size) + mpu_info.tp_rank + slice_rank = mp_rank % self._ranks_per_resource + return resource_partitions[slice_rank] + + num_slices = mpu_info.tp_world_size * mpu_info.pp_world_size + if num_slices > self._ranks_per_resource: + slice_resources = _get_model_parallel_slice_resources() + else: + all_resources = list(range(self._num_resources)) + slice_resources = _get_tensor_slice_resources(all_resources, self._parallel_unit) + + return self._assign_resources_to_tensor_slice(slice_resources, mpu_info.tp_rank, mpu_info.dp_peer_ranks) + + def _get_slice_writers(self, slice_resources, my_dp_ranks): + resource_map = {} + for res in slice_resources: + resource_map[res] = [r for r in my_dp_ranks if (r // self._ranks_per_resource) == res] + + # Only one writer per resource, and we conventionally pick the first rank as writer. + return [ranks[0] for ranks in resource_map.values()] + + def _assign_resources_to_tensor_slice(self, slice_resources, my_slice_index, my_dp_ranks): + my_global_rank = self._uni_parallel_info.global_rank + slice_writer_ranks = self._get_slice_writers(slice_resources, my_dp_ranks) + my_resource_index = my_global_rank // self._ranks_per_resource + print( + f'resource_assign: my_global_rank={my_global_rank} my_slice_index={my_slice_index} my_dp_ranks={my_dp_ranks} slice_resources={slice_resources} slice_writer_ranks={slice_writer_ranks}' + ) + if my_resource_index in slice_resources and my_global_rank in slice_writer_ranks: + my_writer_index = (my_global_rank - slice_writer_ranks[0]) // self._ranks_per_resource + num_slice_writers = len(slice_writer_ranks) + print( + f'slice_writer: my_global_rank={my_global_rank} my_writer_index={my_writer_index} num_slice_writers={num_slice_writers}' + ) + return self._create_config(num_slice_writers, my_writer_index) + + return None + + def _get_parallel_write_for_ddp(self, dp_world_size, dp_rank): + if self._parallel_unit == CheckpointDataParallel.REPLICA: + return self._create_config(dp_world_size, dp_rank) + + num_machines = self._uni_parallel_info.num_machines + if self._parallel_unit == CheckpointDataParallel.SOCKET: + if dp_world_size == num_machines: + # There is one rank per machine + return self._create_config(num_machines, dp_rank) + + num_sockets = self._uni_parallel_info.num_sockets + ranks_per_socket = dp_world_size // num_sockets + if dp_rank % ranks_per_socket == 0: + return self._create_config(num_sockets, dp_rank // ranks_per_socket) + else: + return None + + ranks_per_machine = dp_world_size // num_machines + if dp_rank % ranks_per_machine == 0: + return self._create_config(num_machines, self._uni_parallel_info.machine_rank) + + return None diff --git a/deepspeed/runtime/model_checkpointing/utils.py b/deepspeed/runtime/model_checkpointing/utils.py new file mode 100644 index 000000000000..e212008a9277 --- /dev/null +++ b/deepspeed/runtime/model_checkpointing/utils.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from dataclasses import dataclass +from deepspeed import comm as dist +from deepspeed.constants import CROSS_RANK, CROSS_SIZE, LOCAL_RANK +from .data_parallel_writer_factory import DataParallelWriterFactory + +# TODO: parse socket number from env. +SOCKETS_PER_MACHINE = 2 + + +@dataclass +class MPUInfo(object): + pp_world_size: int + pp_rank: int + tp_world_size: int + tp_rank: int + dp_world_size: int + dp_peer_ranks: list + dp_rank: int + + +def _create_model_parallel_info(mpu): + return MPUInfo(pp_world_size=mpu.get_pipeline_model_parallel_world_size(), + pp_rank=mpu.get_pipeline_model_parallel_rank(), + tp_world_size=mpu.get_tensor_model_parallel_world_size(), + tp_rank=mpu.get_tensor_model_parallel_rank(), + dp_world_size=mpu.get_data_parallel_world_size(), + dp_peer_ranks=mpu.get_data_parallel_group_ranks(), + dp_rank=mpu.get_data_parallel_rank()) + + +@dataclass +class ExpertParallelInfo(object): + ep_world_size: int + ep_rank: int + dp_world_size: int + dp_peer_ranks: list + dp_rank: int + + +def _create_expert_parallel_info(groups): + group_name = groups._get_max_expert_size_name() + return ExpertParallelInfo(ep_world_size=groups._get_expert_parallel_world_size(group_name), + ep_rank=groups._get_expert_parallel_rank(group_name), + dp_world_size=groups._get_expert_data_parallel_world_size(group_name), + dp_peer_ranks=groups._get_expert_data_parallel_group_ranks(group_name), + dp_rank=groups._get_expert_data_parallel_rank(group_name)) + + +@dataclass +class UniversalParallelInfo(object): + global_world_size: int + global_rank: int + local_rank: int + mpu_info: MPUInfo + ep_info: ExpertParallelInfo + pure_dp: bool + num_machines: int + machine_rank: int + num_sockets: int + + +def create_universal_parallel_info(groups, has_moe_layers): + return UniversalParallelInfo(global_world_size=dist.get_world_size(), + global_rank=dist.get_rank(), + local_rank=int(os.environ[LOCAL_RANK]), + mpu_info=None if groups.mpu is None else _create_model_parallel_info(groups.mpu), + ep_info=_create_expert_parallel_info(groups) if has_moe_layers else None, + pure_dp=groups.mpu is None + or groups.mpu.get_data_parallel_world_size() == dist.get_world_size(), + num_machines=int(os.environ[CROSS_SIZE]), + machine_rank=int(os.environ[CROSS_RANK]), + num_sockets=int(os.environ[CROSS_SIZE]) * SOCKETS_PER_MACHINE) + + +def create_data_parallel_writer_config(groups, parallel_unit, zero_stage, has_moe_layers): + uni_parallel_info = create_universal_parallel_info(groups, has_moe_layers) + writer_factory = DataParallelWriterFactory(uni_parallel_info, parallel_unit) + return writer_factory.create_config(zero_stage, has_moe_layers) diff --git a/deepspeed/runtime/model_checkpointing/writer_factory.py b/deepspeed/runtime/model_checkpointing/writer_factory.py new file mode 100644 index 000000000000..a8c324530ae5 --- /dev/null +++ b/deepspeed/runtime/model_checkpointing/writer_factory.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder +from deepspeed.io import MockFileWriter, PyFileWriter, FastFileWriter, FastFileWriterConfig +from deepspeed.runtime.swap_tensor.constants import * +from .constants import * +from deepspeed.accelerator import get_accelerator + + +class CheckpointWriterFactory(object): + + def __init__(self, writer_config, aio_config, dp_writer_config): + self._type = writer_config[CHECKPOINT_WRITER_TYPE] + self._io_buffer_size = writer_config[CHECKPOINT_IO_BUFFER_SIZE] + self._io_buffer_double = writer_config[CHECKPOINT_IO_BUFFER_DOUBLE] + self._data_parallel_writer = dp_writer_config + self._io_multiplier = writer_config[CHECKPOINT_IO_MULTIPLIER] + if self._data_parallel_writer.pure_dp: + self._show_statistics = writer_config[CHECKPOINT_IO_STATISTICS] and self._data_parallel_writer is not None + else: + self._show_statistics = writer_config[CHECKPOINT_IO_STATISTICS] and self._data_parallel_writer is not None + self._io_buffer = None + self._dnvme_handle = None + self._writer = None + self._use_gds = False + + if self._type == CheckpointWriterType.FAST: + self._use_gds = aio_config[AIO_USE_GDS] + if self._use_gds: + self._setup_for_gds(aio_config) + else: + self._setup_for_aio(aio_config) + print( + f'WriterFactory: self._data_parallel_writer={self._data_parallel_writer} self._show_statistics={self._show_statistics}' + ) + + def create_writer(self, file_path, optimize_dp_state): + assert self._writer is None, \ + f'Cannot create checkpoint writer for {file_path} because writer is currently used for {self._writer.file_path()}.\ + Must call writer.release() before reusing to avoid this error.' + + if self._type == CheckpointWriterType.MOCK: + self._writer = MockFileWriter(file_path) + elif self._type == CheckpointWriterType.PYTHON: + self._writer = PyFileWriter(file_path) + else: + if optimize_dp_state: + num_parallel_writers = self._data_parallel_writer.world_size * self._io_multiplier + writer_rank = self._data_parallel_writer.rank + file_path = f'{file_path}-{writer_rank}.{num_parallel_writers}' + # print(f'create_dp_writer: {self._data_parallel_writer.global_rank=} {writer_rank=} {num_parallel_writers=} {file_path=}') + else: + num_parallel_writers = 1 + writer_rank = 0 + # print(f'create_rank0_writer: {self._data_parallel_writer.global_rank=} {writer_rank=} {num_parallel_writers=} {file_path=}') + + config = FastFileWriterConfig(dnvme_handle=self._dnvme_handle, + pinned_tensor=self._io_buffer, + double_buffer=self._io_buffer_double, + num_parallel_writers=num_parallel_writers, + writer_rank=writer_rank, + global_rank=self._data_parallel_writer.global_rank) + self._writer = FastFileWriter(file_path=file_path, config=config) + + return self._writer + + def release_writer(self): + self._writer.close() + if self._show_statistics: + self._writer._dump_state() + self._writer = None + + def _setup_for_aio(self, aio_config): + self._io_buffer = torch.zeros(self._io_buffer_size, dtype=torch.uint8, device='cpu').pin_memory() + self._dnvme_handle = AsyncIOBuilder().load().aio_handle( + block_size=aio_config[AIO_BLOCK_SIZE], + queue_depth=aio_config[AIO_QUEUE_DEPTH], + single_submit=aio_config[AIO_SINGLE_SUBMIT], + overlap_events=aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM]) + + def _setup_for_gds(self, aio_config): + self._io_buffer = torch.zeros(self._io_buffer_size, + dtype=torch.uint8, + device=get_accelerator().current_device_name()) + self._dnvme_handle = GDSBuilder().load().gds_handle(block_size=aio_config[AIO_BLOCK_SIZE], + queue_depth=aio_config[AIO_QUEUE_DEPTH], + single_submit=aio_config[AIO_SINGLE_SUBMIT], + overlap_events=aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM]) + self._dnvme_handle.pin_device_tensor(self._io_buffer) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 4873b6f48453..46fbb4ebaf05 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -4,17 +4,27 @@ # DeepSpeed Team from types import MethodType +from collections import OrderedDict +from functools import reduce +from operator import mul import torch from deepspeed import comm as dist from deepspeed.utils import logger from deepspeed.utils.timer import ThroughputTimer -from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE +from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \ + BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_INNER_GLOBAL_TIMER, \ + BACKWARD_REDUCE_MICRO_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ + STEP_MICRO_TIMER, STEP_GLOBAL_TIMER + from ..utils import PartitionedTensor from ..dataloader import RepeatingLoader +from ..zero.config import ZeroStageEnum +from ..activation_checkpointing import checkpointing as ds_checkpointing from .module import PipelineModule, PipelineError from . import p2p @@ -24,6 +34,16 @@ LOG_STAGE = -2 DATA_PARALLEL_ID = -2 +BATCH_INPUT_TIMER = 'batch_input' +TRAIN_BATCH_TIMER = 'train_batch' +PIPE_SEND_OUTPUT_TIMER = 'pipe_send_output' +PIPE_SEND_GRAD_TIMER = 'pipe_send_grad' +PIPE_RECV_INPUT_TIMER = 'pipe_recv_input' +PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad' + +# The buffer size to store the meta data for each tensor. +TENSOR_META_SIZE = 256 + def is_even(number): return number % 2 == 0 @@ -53,13 +73,16 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) assert isinstance(self.module, PipelineModule), "model must base PipelineModule" - assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" + assert self.zero_optimization_stage( + ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" # We schedule the all-reduces, so disable it in super().backward() self.enable_backward_allreduce = False self.has_bool_tensors = has_bool_tensors self.eval_return_logits = False self.outputs = None + # BF16 Optimizer is hardcoded for fp32 gradient accumulation + self.using_bf16_optimizer = type(self.optimizer) == BF16_Optimizer # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB self.pipeline_enable_backward_allreduce = True @@ -98,7 +121,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self._force_grad_boundary = False - self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(), + self.batch_timer = ThroughputTimer(self._config.timers_config, + batch_size=self.train_batch_size(), logging_fn=self.tput_log, monitor_memory=False, steps_per_output=self.steps_per_print()) @@ -114,8 +138,12 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): # Partition input/output buffers # XXX temporarily disable while I revert some partition hacks. - self.is_pipe_partitioned = self.is_model_parallel - self.is_grad_partitioned = self.is_model_parallel + assert isinstance(self._config.pipeline['pipe_partitioned'], bool) + assert isinstance(self._config.pipeline['grad_partitioned'], bool) + self.is_pipe_partitioned = self.is_model_parallel and self._config.pipeline['pipe_partitioned'] + self.is_grad_partitioned = self.is_model_parallel and self._config.pipeline['grad_partitioned'] + logger.info(f'is_pipe_partitioned= {self.is_pipe_partitioned} ' + f'is_grad_partitioned= {self.is_grad_partitioned}') model_parameters = filter(lambda p: p.requires_grad, self.module.parameters()) num_params = sum([p.numel() for p in model_parameters]) @@ -155,22 +183,43 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): } self.pipe_recv_buf = None self.grad_layer = None + self._grad_layer_buf = [] self.meta_buffer = None self.first_output_send = True self.first_gradient_send = True + self.pipe_partition_input_meta_cache = None + self.pipe_partition_output_meta_cache = None + self.pipe_partition_grad_meta_cache = None + self.grad_partition_grad_layer_meta_cache = None #stores the loss for the current micro batch being processed self.loss = torch.tensor(0.0).to(self.device) #stores the loss for the entire batch self.total_loss = None + self.total_additional_losses = None self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device) self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device) + # stores aggregated-DP train final loss and aggregated-DP additional losses, if any + # additional losses are stored as dict: {loss-name: agg-loss} + self.agg_train_loss = None + self.agg_additional_losses = None + if self._config.pipeline['activation_checkpoint_interval'] > 0: self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval'] + # set use_reentrant default to True. + if self._config.pipeline.get('use_reentrant') is None: + self._config.pipeline['use_reentrant'] = True + if self._config.pipeline['use_reentrant'] is False: + # set activation_checkpoint_func to non_reentrant_checkpoint func. + self.module.activation_checkpoint_func = ds_checkpointing.non_reentrant_checkpoint + if self.grid.get_global_rank() == 0: + logger.info('CONFIG: activation_checkpoint_func=non_reentrant_checkpoint') + if self.module.activation_checkpoint_interval > 0: + self.module._precompute_checkpointable_values() self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline @@ -193,18 +242,20 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): # XXX look into timer reporting timing # Initialize some timers because of early weirdness. if self.wall_clock_breakdown(): - self.timers('forward_microstep').start() - self.timers('forward_microstep').stop() - self.timers('backward_microstep').start() - self.timers('backward_microstep').stop() - self.timers('backward_inner_microstep').start() - self.timers('backward_inner_microstep').stop() - self.timers('backward_allreduce_microstep').start() - self.timers('backward_allreduce_microstep').stop() - self.timers('backward_allreduce').start() - self.timers('backward_allreduce').stop() - self.timers('step_microstep').start() - self.timers('step_microstep').stop() + self.timers(FORWARD_MICRO_TIMER).start() + self.timers(FORWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_MICRO_TIMER).start() + self.timers(BACKWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_INNER_MICRO_TIMER).start() + self.timers(BACKWARD_INNER_MICRO_TIMER).stop() + self.timers(BACKWARD_REDUCE_MICRO_TIMER).start() + self.timers(BACKWARD_REDUCE_MICRO_TIMER).stop() + self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).start() + self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).stop() + self.timers(STEP_MICRO_TIMER).start() + self.timers(STEP_MICRO_TIMER).stop() + + self.dynamic_shape = self.module.dynamic_shape def set_has_attention_mask(self, value): assert isinstance(value, bool) @@ -234,27 +285,22 @@ def _exec_reduce_tied_grads(self): weight_group_list = self.module.get_tied_weights_and_groups() for weight, group in weight_group_list: - grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad - dist.all_reduce(grad, group=group) + grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad + if grad is not None: + dist.all_reduce(grad, group=group) def _exec_reduce_grads(self): self._force_grad_boundary = True if self.pipeline_enable_backward_allreduce: - if self.bfloat16_enabled(): - if self.zero_optimization_stage() == 0: - self._bf16_reduce_grads() - else: - assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported" - raise NotImplementedError() + if self.using_bf16_optimizer: + # PP+BF16 work for ZeRO Stage 1 + self._bf16_reduce_grads() else: self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) self._force_grad_boundary = False def _bf16_reduce_grads(self): - # Make our own list of gradients from the optimizer's FP32 grads - grads = [] - self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(), - elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) + self.buffered_allreduce_fallback(grads=None, elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) def _reserve_pipe_buffers(self, num_buffers): """Ensure that each pipeline buffer has at least ``num_buffers`` slots. @@ -280,8 +326,14 @@ def reset_activation_shape(self): self.first_output_send = True self.pipe_recv_buf = None self.grad_layer = None + self._grad_layer_buf = [] self.meta_buffer = None + self.pipe_partition_input_meta_cache = None + self.pipe_partition_output_meta_cache = None + self.pipe_partition_grad_meta_cache = None + self.grad_partition_grad_layer_meta_cache = None + def train_batch(self, data_iter=None): """Progress the pipeline to train the next batch of data. The engine will ingest ``self.train_batch_size()`` total samples collectively across all workers. @@ -307,7 +359,7 @@ def train_batch(self, data_iter=None): The arithmetic mean of the losses computed this batch. """ if not torch._C.is_grad_enabled(): - raise RuntimeError(f'train_batch() requires gradients enabled. Use eval_batch() instead.') + raise RuntimeError('train_batch() requires gradients enabled. Use eval_batch() instead.') # Curriculum learning could change activation shape if self.curriculum_enabled_legacy(): @@ -320,46 +372,65 @@ def train_batch(self, data_iter=None): self.global_steps): self.reset_activation_shape() - if data_iter: + if data_iter is not None: self.set_dataiterator(data_iter) self.module.train() self.total_loss = None + self.total_additional_losses = None self._compute_loss = True # Do the work - self.timers('train_batch').start() + self.timers(TRAIN_BATCH_TIMER).start() sched = schedule.TrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id) self._exec_schedule(sched) - self.agg_train_loss = self._aggregate_total_loss() - self.timers('train_batch').stop() + with torch.no_grad(): + self.agg_train_loss = self._aggregate_total_loss() - if self.global_steps % self.steps_per_print() == 0: + self.timers(TRAIN_BATCH_TIMER).stop() + + if self.steps_per_print() is not None and self.global_steps % self.steps_per_print() == 0: if self.global_rank == 0: - elapsed = self.timers('train_batch').elapsed(reset=True) / 1000.0 + elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0 iter_time = elapsed / self.steps_per_print() tput = self.train_batch_size() / iter_time - print(f'steps: {self.global_steps} ' - f'loss: {self.agg_train_loss:0.4f} ' - f'iter time (s): {iter_time:0.3f} ' - f'samples/sec: {tput:0.3f}') + log_str = f'steps: {self.global_steps} loss: {self.agg_train_loss:0.4f} ' + if self.agg_additional_losses is not None: + for loss_name, loss_value in self.agg_additional_losses.items(): + log_str += f'{loss_name}: {loss_value.item():0.4f} ' + log_str += f'iter time (s): {iter_time:0.3f} samples/sec: {tput:0.3f}' + print(log_str) + else: + self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) # Monitoring if self.global_rank == 0 and self.monitor.enabled: - self.summary_events = [(f'Train/Samples/train_loss', self.agg_train_loss.mean().item(), - self.global_samples)] + self.summary_events = [('Train/Samples/train_loss', self.agg_train_loss.mean().item(), self.global_samples) + ] self.monitor.write_events(self.summary_events) - if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0: - self.timers.log(['pipe_send_output', 'pipe_send_grad', 'pipe_recv_input', 'pipe_recv_grad']) + if self.steps_per_print() is not None and self.wall_clock_breakdown( + ) and self.global_steps % self.steps_per_print() == 0: + self.timers.log([ + PIPE_SEND_OUTPUT_TIMER, + PIPE_SEND_GRAD_TIMER, + PIPE_RECV_INPUT_TIMER, + PIPE_RECV_GRAD_TIMER, + ]) # TODO: should return precisely what loss returned and allow others to be queried? return self.agg_train_loss - def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg'): + def eval_batch(self, + data_iter, + return_logits=False, + compute_loss=True, + reduce_output='avg', + bcast_loss=True, + num_micro_batches=None): """Evaluate the pipeline on a batch of data from ``data_iter``. The engine will evaluate ``self.train_batch_size()`` total samples collectively across all workers. @@ -408,10 +479,11 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o train_iterator = self.data_iterator self.set_dataiterator(data_iter) + # set the number micro batches in case the user chose value than training + micro_batches = self.micro_batches if num_micro_batches is None else num_micro_batches + # Do the work - sched = schedule.InferenceSchedule(micro_batches=self.micro_batches, - stages=self.num_stages, - stage_id=self.stage_id) + sched = schedule.InferenceSchedule(micro_batches=micro_batches, stages=self.num_stages, stage_id=self.stage_id) # prevent dead-lock with multiple evals sequence dist.barrier() @@ -420,13 +492,13 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o self._exec_schedule(sched) if self.is_last_stage(): - eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output) + eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output, micro_batches=micro_batches) - if compute_loss: + if compute_loss and (bcast_loss or self.monitor.enabled): eval_output = self._bcast_pipe_scalar(eval_output) if self.global_rank == 0 and self.monitor.enabled: - self.summary_events = [(f'Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)] + self.summary_events = [('Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)] self.monitor.write_events(self.summary_events) # Restore the training iterator @@ -462,7 +534,10 @@ def is_last_stage(self): """True if this process is in the last stage in the pipeline.""" return self.stage_id == self.num_stages - 1 - def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True): + def get_pipeline_parallel_rank(self): + return self.stage_id + + def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None): if reduce is None: return outputs @@ -477,7 +552,7 @@ def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True): reduced[idx] += out # Average over the microbatches - reduced = self._scale_loss_by_gas(reduced) + reduced = self._scale_loss_by_gas(reduced, eval_micro_batches=micro_batches) # Average over DP groups if reduce_dp and self.is_data_parallel: @@ -500,7 +575,7 @@ def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32): assert src_rank in self.grid.pp_group if self.global_rank == src_rank: - result = data.clone().detach() + result = data.clone().detach().type(dtype).to(self.device) else: result = torch.Tensor([0.]).type(dtype).to(self.device) @@ -511,29 +586,67 @@ def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32): def _aggregate_total_loss(self): # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group if self.is_last_stage(): + # Scale loss and additional losses, if any loss = self._scale_loss_by_gas(self.total_loss) - self.dp_group_loss = loss.clone().detach() + self.agg_additional_losses = self.total_additional_losses + if self.agg_additional_losses is not None: + self.agg_additional_losses = OrderedDict({ + loss_name: self._scale_loss_by_gas(_loss.clone().detach()) + for loss_name, _loss in self.agg_additional_losses.items() + }) - ## Average loss across all data-parallel groups + self.dp_group_loss = loss.clone().detach() agg_loss = self.dp_group_loss.clone().detach() #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True) + + # Average loss across all data-parallel groups if self.is_data_parallel: - dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group()) - agg_loss /= self.dp_world_size + if self.agg_additional_losses is None: + dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group()) + agg_loss /= self.dp_world_size + else: + # use a single reduce op for agg_loss and additional losses, if any + assert '__train_loss__' not in self.agg_additional_losses.keys() + tensors = OrderedDict({'__train_loss__': agg_loss}) + tensors.update(self.agg_additional_losses.items()) + flat_tensor = torch.cat([t.clone().reshape(-1).detach() for t in tensors.values()]) + dist.all_reduce(flat_tensor, group=self.mpu.get_data_parallel_group()) + flat_tensor /= self.dp_world_size + offset = 0 + reduced_tensor = {} + for name, t in tensors.items(): + n_elem = t.numel() + reduced_tensor[name] = flat_tensor[offset:offset + n_elem].clone().detach().reshape(t.shape) + offset += n_elem + agg_loss = reduced_tensor['__train_loss__'] + self.agg_additional_losses = OrderedDict( + {name: reduced_tensor[name] + for name in self.agg_additional_losses.keys()}) assert self.global_rank in self.grid.pp_group - losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device) - dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) - + losses = [self.dp_group_loss, agg_loss] + if self.agg_additional_losses is not None: + losses += list(self.agg_additional_losses.values()) + losses = torch.stack(losses).float() + if self.is_pipe_parallel: + dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) else: # Get loss from last stage src_rank = self.grid.stage_to_global(self.num_stages - 1) assert src_rank in self.grid.pp_group - losses = torch.Tensor([0., 0.]).to(self.device) + # losses to reduce are: dp_group_loss, agg_loss, model additional losses + # therefore: 2 + n_additional_losses + additional_losses = self.module.get_additional_losses() + n_additional_losses = 0 if additional_losses is None else len(additional_losses) + losses = torch.Tensor([0.] * (2 + n_additional_losses)).to(self.device) dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group()) self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() - + if additional_losses is not None: + self.agg_additional_losses = OrderedDict({ + name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys()) + }) return agg_loss def set_dataloader(self, loader): @@ -598,7 +711,6 @@ def _next_batch(self): def _exec_forward_pass(self, buffer_id): self.tput_timer.start() - self.mem_status('BEFORE FWD', reset_max=True) if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple): inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id]) @@ -607,7 +719,9 @@ def _exec_forward_pass(self, buffer_id): # collect the partitioned input from the previous stage if self.is_pipe_partitioned and not self.is_first_stage(): - part_input = PartitionedTensor.from_meta(meta=inputs[0], + if self.pipe_partition_input_meta_cache is None: + self.pipe_partition_input_meta_cache = inputs[0].to('cpu') + part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache, local_part=inputs[1], group=self.grid.get_slice_parallel_group()) @@ -619,12 +733,14 @@ def _exec_forward_pass(self, buffer_id): inputs = inputs[0] if len(inputs) == 1 else inputs self.pipe_buffers['inputs'][buffer_id] = inputs - # Zero out the gradients each time we use the tensor because only the data in - # tensor changes across batches - self._zero_grads(inputs) - + # inputs has no gradient because it is from a cloned tensor outputs = super().forward(inputs) + # Reset activation checkpointing buffers. + # Need to call this between evaluation iterations + if not self.module.training: + ds_checkpointing.reset() + # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): if isinstance(outputs, tuple): @@ -639,7 +755,7 @@ def _exec_forward_pass(self, buffer_id): raise ValueError("expecting a tensor or a tuple of tensors") part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group()) # Clear the large output data, but save the computation graph - first_output.data = torch.zeros(1) + first_output.data = torch.zeros(1, device=first_output.data.device) self.pipe_buffers['output_tensors'][buffer_id] = first_output # Inject the partitioned tensor into the output before sending outputs = (part.to_meta(), part.data(), *outputs_tail) @@ -657,46 +773,60 @@ def _exec_forward_pass(self, buffer_id): self.loss = outputs if self.eval_return_logits: self.outputs = outputs + if isinstance(self.loss, torch.Tensor): self.fwd_outputs.append(self.loss.detach()) - - if self.total_loss is None: - self.total_loss = torch.zeros_like(self.loss) - self.total_loss += self.loss.detach() else: self.fwd_outputs.append([l.detach() for l in self.loss]) - if self.total_loss is None: - self.total_loss = [torch.zeros_like(l) for l in self.loss] - for idx, l in enumerate(self.loss): - self.total_loss[idx] += l.detach() + def add_to_total_loss(_total_loss, _loss): + if isinstance(_loss, torch.Tensor): + if _total_loss is None: + _total_loss = torch.zeros_like(_loss) + _total_loss += _loss.detach() + else: + if _total_loss is None: + _total_loss = [torch.zeros_like(_l) for _l in _loss] + for _idx, _l in enumerate(_loss): + _total_loss[_idx] += _l.detach() + return _total_loss + + self.total_loss = add_to_total_loss(self.total_loss, self.loss) + + # aggregate additional losses across gradient accumulation steps + additional_losses = self.module.get_additional_losses() + if additional_losses is not None: + if self.total_additional_losses is None: + self.total_additional_losses = OrderedDict() + for name, loss in additional_losses.items(): + total = self.total_additional_losses[name] if name in self.total_additional_losses else None + self.total_additional_losses[name] = add_to_total_loss(total, loss) def _exec_backward_pass(self, buffer_id): assert self.optimizer is not None, "must provide optimizer during " \ "init in order to use backward" - self.mem_status('BEFORE BWD', reset_max=True) - # The last stage just runs backward on the loss using DeepSpeed's typical # mechanisms. if self.is_last_stage(): super().backward(self.loss) - self.mem_status('AFTER BWD') return outputs = self.pipe_buffers['outputs'][buffer_id] if self.wall_clock_breakdown(): - self.timers('backward_microstep').start() - self.timers('backward').start() - self.timers('backward_inner_microstep').start() - self.timers('backward_inner').start() + self.timers(BACKWARD_MICRO_TIMER).start() + self.timers(BACKWARD_GLOBAL_TIMER).start() + self.timers(BACKWARD_INNER_MICRO_TIMER).start() + self.timers(BACKWARD_INNER_GLOBAL_TIMER).start() # Reconstruct if we previously partitioned the output. We must be # careful to also restore the computational graph of the tensors we partitioned. if self.is_pipe_partitioned: if self.is_grad_partitioned: - part_output = PartitionedTensor.from_meta(meta=outputs[0], + if self.pipe_partition_output_meta_cache is None: + self.pipe_partition_output_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache, local_part=outputs[1], group=self.grid.get_slice_parallel_group()) self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full() @@ -709,28 +839,40 @@ def _exec_backward_pass(self, buffer_id): grad_tensors = self.grad_layer if self.is_grad_partitioned: #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') - part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0], + if self.grad_partition_grad_layer_meta_cache is None: + self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu') + part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache, local_part=self.grad_layer[1], group=self.grid.get_slice_parallel_group()) grad_tensors = (part_grad.full(), *grad_tensors[2:]) part_grad = None #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') - if self.bfloat16_enabled() and not self.is_last_stage(): + if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward() self.optimizer.clear_lp_grads() - # This handles either a single tensor or tuple of tensors. - if isinstance(outputs, tuple): - out_tensors = [t for t in outputs if t.is_floating_point()] - assert len(out_tensors) == len(grad_tensors) - torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors) - else: - torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + # Set _running_engine_backward to avoid RuntimeError in post-backward hook + # when needs_scaler=True (the hook checks this flag to skip error checking) + self._running_engine_backward = True + try: + # Use tensor.backward(gradient) style which is now supported by DeepSpeed. + # This properly integrates with DeepSpeed's hooks and loss scaling. + if isinstance(outputs, tuple): + out_tensors = [t for t in outputs if t.is_floating_point()] + assert len(out_tensors) == len(grad_tensors) + # For multiple tensors, use retain_graph for all but the last + for i, (out, grad) in enumerate(zip(out_tensors, grad_tensors)): + out.backward(gradient=grad, retain_graph=(i < len(out_tensors) - 1)) + else: + outputs.backward(gradient=grad_tensors) + finally: + self._running_engine_backward = False - if self.bfloat16_enabled() and not self.is_last_stage(): + if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward() - self.optimizer.update_hp_grads(clear_lp_grads=False) + if not self._config.bfloat16_config.immediate_grad_update: + self.optimizer.update_hp_grads(clear_lp_grads=False) # Free up the memory from the output of forward() self.pipe_buffers['output_tensors'][buffer_id] = None @@ -738,16 +880,14 @@ def _exec_backward_pass(self, buffer_id): grad_tensors = None if self.wall_clock_breakdown(): - self.timers('backward_inner').stop() - self.timers('backward_inner_microstep').stop() - self.timers('backward').stop() - self.timers('backward_microstep').stop() - - self.mem_status('AFTER BWD') + self.timers(BACKWARD_INNER_MICRO_TIMER).stop() + self.timers(BACKWARD_INNER_GLOBAL_TIMER).stop() + self.timers(BACKWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_GLOBAL_TIMER).stop() def _exec_load_micro_batch(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('batch_input').start() + self.timers(BATCH_INPUT_TIMER).start() batch = self._next_batch() @@ -755,7 +895,9 @@ def _exec_load_micro_batch(self, buffer_id): loaded = None if torch.is_tensor(batch[0]): loaded = batch[0].clone().to(self.device).detach() - loaded.requires_grad = loaded.is_floating_point() + if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[ + 'use_reentrant']: + loaded.requires_grad = loaded.is_floating_point() else: assert isinstance(batch[0], (tuple, list)) # Assume list or tuple @@ -763,7 +905,9 @@ def _exec_load_micro_batch(self, buffer_id): for x in batch[0]: assert torch.is_tensor(x) mine = x.clone().detach().to(self.device) - mine.requires_grad = mine.is_floating_point() + if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[ + 'use_reentrant']: + mine.requires_grad = mine.is_floating_point() loaded.append(mine) loaded = tuple(loaded) @@ -773,7 +917,8 @@ def _exec_load_micro_batch(self, buffer_id): loaded = batch[1] if torch.is_tensor(batch[1]): loaded = batch[1].to(self.device) - elif isinstance(batch[1], tuple): + # XXX: torch 1.6.0 DataLoader will auto convert tuple to list + elif isinstance(batch[1], (tuple, list)): loaded = [] for x in batch[1]: assert torch.is_tensor(x) @@ -784,7 +929,7 @@ def _exec_load_micro_batch(self, buffer_id): self.pipe_buffers['labels'][buffer_id] = loaded if self.wall_clock_breakdown(): - self.timers('batch_input').stop() + self.timers(BATCH_INPUT_TIMER).stop() def _send_tensor_meta(self, buffer, recv_stage): """ Communicate metadata about upcoming p2p transfers. @@ -796,51 +941,38 @@ def _send_tensor_meta(self, buffer, recv_stage): * ndims * shape """ - send_bytes = 0 + meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) if isinstance(buffer, torch.Tensor): - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.send(type_tensor, recv_stage) - send_shape = torch.LongTensor(data=buffer.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(buffer) - elif isinstance(buffer, list): - assert (False) - type_tensor = torch.LongTensor(data=[1]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for tensor in buffer: - assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(tensor) + meta_buf_list = [ + 0, # type of data (0: tensor, 1: list (unused), 2: tuple) + self.DTYPE_TO_ID[buffer.dtype], # dtype + len(buffer.size()) # ndims + ] + meta_buf_list.extend(buffer.size()) + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + elif isinstance(buffer, tuple): - type_tensor = torch.LongTensor(data=[2]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for idx, tensor in enumerate(buffer): + meta_buf_list = [ + 2, # type of data (0: tensor, 1: list (unused), 2: tuple) + len(buffer) # num_tensors + ] + + for tensor in buffer: assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device) - p2p.send(send_dtype, recv_stage) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - # Useful for performance debugging. - ''' - new_bytes = _tensor_bytes(tensor) - send_bytes += _tensor_bytes(tensor) - # Useful for performance debugging. - if self.grid.data_parallel_id == 0: - print( - f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB' - ) - ''' + meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype]) + meta_buf_list.append(len(tensor.size())) + meta_buf_list.extend(tensor.size()) + + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + else: raise NotImplementedError(f'Could not send meta type {type(buffer)}') @@ -853,49 +985,35 @@ def _send_tensor_meta(self, buffer, recv_stage): def _recv_tensor_meta(self, send_stage): """Receive metadata about upcoming p2p transfers and return allocated buffers. - Metadata is communicated in this order: - * type (0: tensor, 1: list) - * num_tensors if type=list - foreach tensor in buffer: - * ndims - * shape - Returns: Allocated buffer for receiving from send_stage. """ + buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) + p2p.recv(buffer, send_stage) - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(type_tensor, send_stage) - recv_type = type_tensor.item() + recv_type = buffer[0].item() # A single tensor will be sent. if recv_type == 0: - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shape = recv_shape.tolist() - return self._allocate_buffer(recv_shape, num_buffers=1)[0] - - # List or tuple of tensors + recv_dtype = self.ID_TO_DTYPE[buffer[1].item()] + recv_ndims = buffer[2].item() + recv_shape = buffer[3:3 + recv_ndims].tolist() + return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype) + + # List or tuple of tensors (recv_type == 1 (list) is currently unused) elif recv_type == 1 or recv_type == 2: - count_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(count_tensor, send_stage) - num_tensors = count_tensor.item() - recv_shapes_and_dtypes = [] + num_tensors = buffer[1].item() + + buffers = [] + offset = 2 for idx in range(num_tensors): - recv_dtype = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_dtype, send_stage) - recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()] - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype)) - - buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0] + recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()] + recv_ndims = buffer[offset + 1].item() + recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist() + offset += 2 + recv_ndims + + buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype)) + # Convert to tuples if requested. if recv_type == 2: buffers = tuple(buffers) @@ -906,7 +1024,7 @@ def _recv_tensor_meta(self, send_stage): def _exec_send_activations(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_send_output').start() + self.timers(PIPE_SEND_OUTPUT_TIMER).start() outputs = self.pipe_buffers['outputs'][buffer_id] @@ -918,7 +1036,7 @@ def _exec_send_activations(self, buffer_id): outputs[-1] = outputs[-1].half() outputs = tuple(outputs) - if self.first_output_send: + if self.dynamic_shape or self.first_output_send: self.first_output_send = False self._send_tensor_meta(outputs, self.next_stage) @@ -938,11 +1056,11 @@ def _exec_send_activations(self, buffer_id): outputs = tuple(outputs) if self.wall_clock_breakdown(): - self.timers('pipe_send_output').stop() + self.timers(PIPE_SEND_OUTPUT_TIMER).stop() def _exec_send_grads(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_send_grad').start() + self.timers(PIPE_SEND_GRAD_TIMER).start() inputs = self.pipe_buffers['inputs'][buffer_id] @@ -951,7 +1069,7 @@ def _exec_send_grads(self, buffer_id): if isinstance(inputs, tuple): first_input = inputs[0] assert all([torch.is_tensor(elt) for elt in inputs[1:]]) - inputs_grad_tail = [elt.grad for elt in inputs[1:] if elt.grad is not None] + inputs_grad_tail = [elt.grad for elt in inputs[1:]] elif torch.is_tensor(inputs): first_input = inputs inputs_grad_tail = [] @@ -994,16 +1112,16 @@ def _exec_send_grads(self, buffer_id): self.pipe_buffers['inputs'][buffer_id] = None if self.wall_clock_breakdown(): - self.timers('pipe_send_grad').stop() + self.timers(PIPE_SEND_GRAD_TIMER).stop() def _exec_recv_activations(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_recv_input').start() + self.timers(PIPE_RECV_INPUT_TIMER).start() recvd = None # Allocate the buffer if necessary - if self.pipe_recv_buf is None: + if self.dynamic_shape or self.pipe_recv_buf is None: self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage) if isinstance(self.pipe_recv_buf, torch.Tensor): @@ -1037,17 +1155,19 @@ def _exec_recv_activations(self, buffer_id): self.pipe_buffers['inputs'][buffer_id] = recvd if self.wall_clock_breakdown(): - self.timers('pipe_recv_input').stop() + self.timers(PIPE_RECV_INPUT_TIMER).stop() def _exec_recv_grads(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_recv_grad').start() + self.timers(PIPE_RECV_GRAD_TIMER).start() outputs = self.pipe_buffers['outputs'][buffer_id] # XXX these shapes are hardcoded for Megatron # Restore partitioned output if it was partitioned and we are sending full gradients if self.is_pipe_partitioned and not self.is_grad_partitioned: - part_output = PartitionedTensor.from_meta(meta=outputs[0], + if self.pipe_partition_grad_meta_cache is None: + self.pipe_partition_grad_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache, local_part=outputs[1], group=self.grid.get_slice_parallel_group()) outputs[0].data = part_output.full() @@ -1056,10 +1176,9 @@ def _exec_recv_grads(self, buffer_id): self.pipe_buffers['outputs'][buffer_id] = outputs # Allocate gradient if necessary - if self.grad_layer is None: + if self.dynamic_shape or self.grad_layer is None: if isinstance(outputs, torch.Tensor): - s = list(outputs.size()) - self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0] + self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype) else: # XXX This is a HACK # When we exchange activations/gradients, the two pipe stages @@ -1081,7 +1200,11 @@ def _exec_recv_grads(self, buffer_id): for t in outputs[2:] if t.is_floating_point()] else: sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()] - self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0] + + self.grad_layer = [ + self._allocate_or_extend_buffers(i, size, dtype) + for i, (size, dtype) in enumerate(sizes_and_dtypes) + ] if isinstance(self.grad_layer, torch.Tensor): p2p.recv(self.grad_layer, self.next_stage) @@ -1094,46 +1217,44 @@ def _exec_recv_grads(self, buffer_id): p2p.recv(buffer, self.next_stage) if self.wall_clock_breakdown(): - self.timers('pipe_recv_grad').stop() + self.timers(PIPE_RECV_GRAD_TIMER).stop() def _exec_optimizer_step(self, lr_kwargs=None): if self.wall_clock_breakdown(): - self.timers('step_microstep').start() - self.timers('step').start() - self.mem_status('BEFORE STEP', reset_max=True) + self.timers(STEP_MICRO_TIMER).start() + self.timers(STEP_GLOBAL_TIMER).start() self._force_grad_boundary = True self._take_model_step(lr_kwargs) self._force_grad_boundary = False - self.mem_status('AFTER STEP') - if self.global_rank == 0 and self.monitor.enabled: - self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)] - if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'): - self.summary_events.append( - (f'Train/Samples/loss_scale', self.optimizer.cur_scale, self.global_samples)) + self.summary_events = [('Train/Samples/lr', self.get_lr()[0], self.global_samples)] + loss_scale = self._get_optimizer_loss_scale() if self.fp16_enabled() else None + if loss_scale is not None: + self.summary_events.append(('Train/Samples/loss_scale', loss_scale, self.global_samples)) self.monitor.write_events(self.summary_events) if self.wall_clock_breakdown(): - self.timers('step_microstep').stop() - self.timers('step').stop() + self.timers(STEP_MICRO_TIMER).stop() + self.timers(STEP_GLOBAL_TIMER).stop() if self.global_steps % self.steps_per_print() == 0: self.timers.log([ - 'batch_input', 'forward_microstep', 'backward_microstep', 'backward_inner_microstep', - 'backward_allreduce_microstep', 'backward_tied_allreduce_microstep', 'step_microstep' + BATCH_INPUT_TIMER, + FORWARD_MICRO_TIMER, + BACKWARD_MICRO_TIMER, + BACKWARD_INNER_MICRO_TIMER, + BACKWARD_REDUCE_MICRO_TIMER, + STEP_MICRO_TIMER, ]) if self.global_steps % self.steps_per_print() == 0: - self.timers.log(['forward', 'backward', 'backward_inner', 'backward_allreduce', 'step']) - - def _zero_grads(self, inputs): - if isinstance(inputs, torch.Tensor): - if inputs.grad is not None: - inputs.grad.data.zero_() - else: - for t in inputs: - if t.grad is not None: - t.grad.data.zero_() + self.timers.log([ + FORWARD_GLOBAL_TIMER, + BACKWARD_GLOBAL_TIMER, + BACKWARD_INNER_GLOBAL_TIMER, + BACKWARD_REDUCE_GLOBAL_TIMER, + STEP_GLOBAL_TIMER, + ]) def _allocate_zeros(self, shape, **kwargs): """ Allocate a tensor of zeros on the engine's device. @@ -1161,16 +1282,17 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs): buffers.append(self._allocate_zeros(shape, **kwargs)) return buffers - def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1): - buffers = [] - if num_buffers == -1: - num_buffers = self.num_pipe_buffers - for count in range(num_buffers): - buffer = [] - for shape, dtype in shapes_and_dtypes: - buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad)) - buffers.append(buffer) - return buffers + def _allocate_or_extend_buffers(self, idx, shape, dtype): + numel = reduce(mul, shape) if len(shape) > 0 else 1 + if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel: + new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0] + if len(self._grad_layer_buf) <= idx: + self._grad_layer_buf.append(new_buf) + else: + self._grad_layer_buf[idx] = new_buf + return self._grad_layer_buf[idx] + else: + return self._grad_layer_buf[idx].flatten()[:numel].view(shape) def forward(self, *args, **kwargs): """Disabled for pipeline parallel training. See ``train_batch()``. """ @@ -1184,54 +1306,7 @@ def step(self, *args, **kwargs): """Disabled for pipeline parallel training. See ``train_batch()``. """ raise PipelineError("Only train_batch() is accessible in pipeline mode.") - def mem_status(self, msg, print_rank=-1, reset_max=False): - return - global mem_alloced, mem_cached - if not self.global_steps == 0 or not self.global_steps == 9: - #return - pass - if self.mpu.get_data_parallel_rank() != 0: - return - - if self.global_rank != 0: - return - - rank = self.global_rank - if print_rank != -1 and rank != print_rank: - return - - get_accelerator().synchronize() - - if reset_max: - get_accelerator().reset_max_memory_cached() - get_accelerator().reset_max_memory_allocated() - - new_alloced = get_accelerator().memory_allocated() - new_cached = get_accelerator().memory_cached() - - delta_alloced = new_alloced - mem_alloced - delta_cached = new_cached - mem_cached - - mem_cached = new_cached - mem_alloced = new_alloced - - max_alloced = get_accelerator().max_memory_allocated() - max_cached = get_accelerator().max_memory_cached() - - # convert to GB for printing - new_alloced /= 1024**3 - new_cached /= 1024**3 - delta_alloced /= 1024**3 - delta_cached /= 1024**3 - max_alloced /= 1024**3 - max_cached /= 1024**3 - - print( - f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg, - f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) ' - f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)') - - def module_state_dict(self): + def module_state_dict(self, exclude_frozen_parameters=False): """Override hack to save a pipe model and return the directory path of the save. This method should only be called by DeepSpeed's ``save_checkpoint()``. The @@ -1245,10 +1320,12 @@ def module_state_dict(self): assert self._curr_ckpt_path is not None, \ "PipelineEngine expects module_state_dict() to be called from save_checkpoint()" - self.module.save_state_dict(self._curr_ckpt_path, checkpoint_engine=self.checkpoint_engine) + self.module.save_state_dict(self._curr_ckpt_path, + checkpoint_engine=self.checkpoint_engine, + exclude_frozen_params=exclude_frozen_parameters) return None - def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): + def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): """Override hack to instead use a directory path. This is important because pipeline models checkpoint by layer instead of rank. @@ -1260,6 +1337,7 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): strict (bool, optional): Strict state loading. Defaults to True. """ assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism" + state_dict = checkpoint if self.has_moe_layers else checkpoint['module'] if (state_dict is not None) and (not isinstance(state_dict, str)): super().load_module_state_dict(state_dict, strict) return @@ -1298,3 +1376,6 @@ def _exec_schedule(self, pipe_schedule): # Equivalent to: self._exec_forward_pass(buffer_id=0) self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self) self._exec_instr(**cmd.kwargs) + + def get_additional_losses(self): + return self.agg_additional_losses diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 7bf9c7d973b1..36f54facc05c 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -20,6 +20,7 @@ from .topology import PipeDataParallelTopology, PipelineParallelGrid from deepspeed.runtime.state_dict_factory import SDLoaderFactory from deepspeed.accelerator import get_accelerator +from deepspeed.checkpoint.utils import clone_tensors_for_torch_save class PipelineError(Exception): @@ -75,11 +76,11 @@ def build(self, log=False): class TiedLayerSpec(LayerSpec): - def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs): + def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr=['weight'], **module_kwargs): super().__init__(typename, *module_args, **module_kwargs) self.key = key self.forward_fn = forward_fn - self.tied_weight_attr = tied_weight_attr + self.tied_weight_attr = [tied_weight_attr] if type(tied_weight_attr) == str else tied_weight_attr class PipelineModule(nn.Module): @@ -115,7 +116,10 @@ def forward(self, inputs): partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'. activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. - checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering. + checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models, + ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are + considered checkpointable. Defaults to None. + dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact. """ def __init__(self, @@ -129,7 +133,8 @@ def __init__(self, partition_method='parameters', activation_checkpoint_interval=0, activation_checkpoint_func=checkpointing.checkpoint, - checkpointable_layers=None): + checkpointable_layers=None, + dynamic_shape=False): super().__init__() @@ -159,7 +164,7 @@ def __init__(self, self.global_rank = dist.get_rank(group=self.world_group) self.world_size = dist.get_world_size(group=self.world_group) self.local_rank = int(os.environ.get("LOCAL_RANK", None)) - assert self.local_rank != None + assert self.local_rank is not None if topology: self._topo = topology @@ -195,6 +200,16 @@ def __init__(self, #newseed = get_accelerator().initial_seed() + self._grid.get_stage_id() #ds_utils.set_random_seed(newseed) + self.activation_checkpoint_interval = activation_checkpoint_interval + + self.activation_checkpoint_func = activation_checkpoint_func + + #storage for precomputed checkpointeble results + self.is_checkpointable_results = [] + self.is_checkpointable_results_interval = None + + # if configuration use_reentrant = False, self.activation_checkpoint_func will be set to ``checkpointing.non_reentrant_checkpoint`` + #with torch.random.fork_rng(devices=[get_accelerator().current_device_name()]): self._build() self.to(get_accelerator().device_name(self.local_rank)) @@ -202,8 +217,17 @@ def __init__(self, self.tied_comms = self._index_tied_modules() self._synchronize_tied_weights() - self.activation_checkpoint_interval = activation_checkpoint_interval - self.activation_checkpoint_func = activation_checkpoint_func + self.dynamic_shape = dynamic_shape + + def _precompute_checkpointable_values(self): + if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval: + num_layers = len(self.forward_funcs) + self.interval_was_zero = False + for start_idx in range(0, num_layers, self.activation_checkpoint_interval): + end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers) + funcs = self.forward_funcs[start_idx:end_idx] + self.is_checkpointable_results.append(self._is_checkpointable(funcs)) + self.is_checkpointable_results_interval = self.activation_checkpoint_interval def _build(self): specs = self._layer_specs @@ -259,6 +283,20 @@ def _build(self): for p in self.parameters(): p.ds_pipe_replicated = False + def _get_frozen_parameter_names(self, layer): + """ Get names of frozen parameters in the layer. + + Returns: + A list of frozen parameter names + """ + if isinstance(layer, LayerSpec): + l = layer.build() + return [n for n, p in l.named_parameters() if not p.requires_grad] + elif isinstance(layer, nn.Module): + return [n for n, p in layer.named_parameters() if not p.requires_grad] + + return [] + def _count_layer_params(self): """Count the trainable parameters in individual layers. @@ -335,7 +373,9 @@ def exec_func(*inputs): else: num_layers = len(self.forward_funcs) x = forward_input - for start_idx in range(0, num_layers, self.activation_checkpoint_interval): + for start_idx, is_checkpointable_result in \ + zip(range(0, num_layers, self.activation_checkpoint_interval), self.is_checkpointable_results): + end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers) funcs = self.forward_funcs[start_idx:end_idx] @@ -344,7 +384,7 @@ def exec_func(*inputs): if not isinstance(x, tuple): x = (x, ) - if self._is_checkpointable(funcs): + if is_checkpointable_result: x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x) else: x = exec_range_func(start_idx, end_idx)(*x) @@ -403,26 +443,37 @@ def _partition_layers(self, method='uniform'): self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1]) + @staticmethod + def _recursive_getattr(module: torch.nn.Module, attr_name: str) -> torch.Tensor: + '''Allow getting an attribute like "linear.weight"''' + weight = module + for item in attr_name.split("."): + weight = getattr(weight, item) + return weight + def allreduce_tied_weight_gradients(self): '''All reduce the gradients of the tied weights between tied stages''' for key, comm in self.tied_comms.items(): - weight = getattr(self.tied_modules[key], comm['weight_attr']) - dist.all_reduce(weight.grad, group=comm['group']) + for attr_name in comm['weight_attr']: + weight = self._recursive_getattr(self.tied_modules[key], attr_name) + dist.all_reduce(weight.grad, group=comm['group']) def get_tied_weights_and_groups(self): weight_group_list = [] for key, comm in self.tied_comms.items(): - weight = getattr(self.tied_modules[key], comm['weight_attr']) - weight_group_list.append((weight, comm['group'])) + for attr_name in comm['weight_attr']: + weight = self._recursive_getattr(self.tied_modules[key], attr_name) + weight_group_list.append((weight, comm['group'])) return weight_group_list def _synchronize_tied_weights(self): for key, comm in self.tied_comms.items(): - dist.broadcast( - getattr(comm['module'], comm['weight_attr']), - src=min(comm['ranks']), - group=comm['group'], - ) + for attr_name in comm['weight_attr']: + dist.broadcast( + self._recursive_getattr(comm['module'], attr_name), + src=min(comm['ranks']), + group=comm['group'], + ) def _index_tied_modules(self): ''' Build communication structures for tied modules. ''' @@ -432,7 +483,10 @@ def _index_tied_modules(self): specs = self._layer_specs tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec)) - for key in tie_keys: + # Since Python 3.7, "Dictionary order is guaranteed to be insertion order." + # Sort tie_keys here so that orders of self.tied_comms.items() are consistent + # among ranks. + for key in sorted(tie_keys): # Find the layers that the tied module appears in tied_layers = [] for idx, layer in enumerate(specs): @@ -544,7 +598,9 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx): ckpt_files.sort() return ckpt_files - def save_state_dict(self, save_dir, checkpoint_engine): + def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False): + # TODO: Need to validate interaction of checkpoint_parallel_write_pipeline and fastwriter + # Processes having the same model parallel rank on different data parallel instances # have identical layer weights. We can distribute the task of saving the layer weights # among the data parallel ranks. For example, if a pipeline stage has 9 layers and @@ -565,19 +621,20 @@ def save_state_dict(self, save_dir, checkpoint_engine): layer_list = self.forward_funcs[start:end] checkpoint_engine.makedirs(save_dir, exist_ok=True) + should_clone = checkpoint_engine.preserves_storage_sharing() for idx, layer in enumerate(layer_list): model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx) if not hasattr(layer, 'state_dict'): continue - # We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save() - # saves the underlying storage rather than the slice of the storage corresponding to individual tensors. - # This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers. - # Tensor cloning helps to avoid this problem because the storage of cloned tensors are closer to the true size. - # It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat. - # See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing + orig_state_dict = layer.state_dict() - final_state_dict = type(orig_state_dict)({k: v.clone() for k, v in orig_state_dict.items()}) - checkpoint_engine.save(final_state_dict, model_ckpt_path) + if exclude_frozen_params: + for n in self._get_frozen_parameter_names(layer): + del orig_state_dict[n] + final_state_dict = orig_state_dict + if should_clone: + final_state_dict = clone_tensors_for_torch_save(orig_state_dict) + checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path) def load_state_dir(self, load_dir, checkpoint_engine, strict=True): for idx, layer in enumerate(self.forward_funcs): @@ -595,7 +652,7 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True): checkpoint_engine=checkpoint_engine) load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True) - layer.load_state_dict(checkpoint) + layer.load_state_dict(checkpoint, strict=strict) # if self._grid.data_parallel_id == 0: # logger.info( @@ -605,13 +662,37 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True): self._synchronize_tied_weights() def _is_checkpointable(self, funcs): - # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations. - # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things. - # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit. - if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'): - return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs) + + if self.activation_checkpoint_func is not checkpointing.non_reentrant_checkpoint: + # This hook excludes the embedding layer + # because only non_reentrant_checkpoint can accept inputs with requires_grad=False + # otherwise, the backward of the embedding layer won't receive gradients. + if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'): + # For GPT models, checkpoint both transformer layers and any additional + # layers specified in checkpointable_layers (if provided) + return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or ( + self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers) + for f in funcs) + if self.checkpointable_layers is not None: + # For non-GPT models, only checkpoint layers specified in checkpointable_layers return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs) + # Default behavior: checkpoint any layer that has parameters params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] return any(len(list(p)) > 0 for p in params) + + def get_additional_losses(self): + """ Returns model specific additional losses for reporting + + Return a dictionary of {"loss name": loss_value} or None if no additional losses. + """ + return None + + def compile(self, *args, **kwargs): + for idx, layer in enumerate(self.forward_funcs): + if isinstance(layer, nn.Module): + layer.compile(*args, **kwargs) + else: + new_layer = torch.compile(layer, *args, **kwargs) + self.forward_funcs[idx] = new_layer diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py index 31b9a14bf0f4..ed6d80b8d4fb 100644 --- a/deepspeed/runtime/pipe/p2p.py +++ b/deepspeed/runtime/pipe/p2p.py @@ -3,15 +3,13 @@ # DeepSpeed Team -import pickle +import msgpack import typing import torch from deepspeed import comm as dist -# To query whether we have send/recv support -from packaging.version import Version -from deepspeed.git_version_info import torch_info +from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator _groups = None @@ -21,9 +19,7 @@ def can_send_recv() -> bool: - torch_version = Version(torch_info['version']) - sendrecv_min = Version('1.8') - return torch_version >= sendrecv_min + return required_torch_version(min_version=1.8) #initializes adjacent process groups @@ -100,7 +96,7 @@ def wait(): def send_obj(msg: typing.Any, dest: int): """Send an arbitrary python object to ``dest``. - Note: ``msg`` must be pickleable. + Note: ``msg`` must be serializable by msgpack. WARN: This incurs a CPU -> GPU transfer and should be used sparingly for performance reasons. @@ -110,7 +106,7 @@ def send_obj(msg: typing.Any, dest: int): dest (int): Destination rank. """ # serialize the message - msg = pickle.dumps(msg) + msg = msgpack.packb(msg) # construct a tensor to send msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(get_accelerator().device_name()) @@ -137,7 +133,7 @@ def recv_obj(sender: int) -> typing.Any: msg = torch.empty(length.item(), dtype=torch.uint8).to(get_accelerator().device_name()) dist.recv(msg, src=sender) - msg = pickle.loads(msg.cpu().numpy().tobytes()) + msg = msgpack.unpackb(msg.cpu().numpy().tobytes()) def _to(x): """Recursively move to the current device.""" diff --git a/deepspeed/runtime/pipe/topology.py b/deepspeed/runtime/pipe/topology.py index 328c19907100..4b1077626364 100644 --- a/deepspeed/runtime/pipe/topology.py +++ b/deepspeed/runtime/pipe/topology.py @@ -411,10 +411,16 @@ def get_pipe_parallel_rank(self): """ The stage of the pipeline this rank resides in. """ return self.get_stage_id() + def get_pipeline_model_parallel_rank(self): + return self.get_pipe_parallel_rank() + def get_pipe_parallel_world_size(self): """ The number of stages in the pipeline. """ return self.pipe_parallel_size + def get_pipeline_model_parallel_world_size(self): + return self.get_pipe_parallel_world_size() + def get_pipe_parallel_group(self): """ The group of ranks within the same pipeline. """ return self.pp_proc_group @@ -431,6 +437,10 @@ def get_data_parallel_group(self): """ The group of ranks within the same stage of all pipelines. """ return self.dp_proc_group + def get_data_parallel_group_ranks(self): + """ List of ranks in the data parallel group. """ + return self.dp_group + # These are model parallel groups across all types of model parallelism. # Deepspeed uses them to detect overflow, etc. def get_model_parallel_rank(self): @@ -449,8 +459,14 @@ def get_slice_parallel_rank(self): else: return 0 + def get_tensor_model_parallel_rank(self): + return self.get_slice_parallel_rank() + def get_slice_parallel_world_size(self): return self.slice_parallel_size + def get_tensor_model_parallel_world_size(self): + return self.get_slice_parallel_world_size() + def get_slice_parallel_group(self): return self.slice_proc_group diff --git a/deepspeed/runtime/precision_config.py b/deepspeed/runtime/precision_config.py new file mode 100644 index 000000000000..efec5c9d00c8 --- /dev/null +++ b/deepspeed/runtime/precision_config.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import math +from pydantic import field_validator +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .fp16.loss_scaler import ( + INITIAL_LOSS_SCALE, + SCALE_WINDOW, + DELAYED_SHIFT, + CONSECUTIVE_HYSTERESIS, + MIN_LOSS_SCALE, +) + +######################################### +# BFLOAT16 support +######################################### +# BFLOAT16 feature. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +BFLOAT16_FORMAT = ''' +BFLOAT16 parameters should be of the format: +"bf16": { + "enabled": true, + "immediate_grad_update": false, + "check_grad_overflow": false +} +''' +BFLOAT16 = "bf16" +BFLOAT16_OLD = "bfloat16" # keeping for backwards compatibility + + +def get_bfloat16_config(param_dict): + bf16_config_dict = param_dict.get(BFLOAT16, None) + if bf16_config_dict is None: + bf16_config_dict = param_dict.get(BFLOAT16_OLD, {}) + return DeepSpeedBF16Config(**bf16_config_dict) + + +class DeepSpeedBF16Config(DeepSpeedConfigModel): + """ + For bfloat16 configuration + """ + + enabled: bool = False + """ + Enable bfloat16 mixed-precision training/inference + """ + + immediate_grad_update: bool = False + """ + Apply gradient updates immediately rather than delayed. + """ + + check_grad_overflow: bool = False + """ + Check for gradient overflows and underflows + """ + + bf16_master_weights_and_grads: bool = False + """ + Maintain master weights/gradients in bf16 precision for ZeRO optimizer. + """ + + bf16_optimizer_states: bool = False + """ + Keep optimizer states in bf16 (only valid when bf16_master_weights_and_grads is enabled). + """ + + +######################################### +# FP16 support +######################################### +# FP16 feature. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +FP16_FORMAT = ''' +FP16 parameters should be of the format: +"fp16": { + "enabled": true, + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "consecutive_hysteresis": false, + "min_loss_scale": 1 +} +''' +FP16 = "fp16" + + +def get_float16_config(param_dict): + fp16_config_dict = param_dict.get(FP16, {}) + return DeepSpeedFP16Config(**fp16_config_dict) + + +class DeepSpeedFP16Config(DeepSpeedConfigModel): + """ + For float16 configuration + """ + + enabled: bool = False + """ + Enable fp16 mixed-precision training/inference + """ + + auto_cast: bool = False + """ + Automatically cast inputs to fp16 + """ + + loss_scale: float = 0 + """ + Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale. + """ + + @field_validator("loss_scale", mode="before") + @classmethod + def _validate_loss_scale(cls, v): + if isinstance(v, bool): + raise ValueError("fp16.loss_scale must be a number, not bool") + try: + v = float(v) + except (TypeError, ValueError): + raise ValueError("fp16.loss_scale must be a number") + if not math.isfinite(v): + raise ValueError("fp16.loss_scale must be a finite number (not inf/-inf/nan)") + if v < 0: + raise ValueError("fp16.loss_scale must be >= 0 (0 enables dynamic loss scaling)") + return v + + initial_scale_power: int = 16 + """ + For dynamic loss scaling, set initial loss scale to 2^{initial_scale_power}. + """ + + loss_scale_window: int = 1000 + """ + Iteration intervals for raising/lowering dynamic loss scale value. + """ + + hysteresis: int = 2 + """ + Delay shift in dynamic loss scaling. + """ + + consecutive_hysteresis: bool = False + """ + Refill hysteresis if iteration does not overflow/underflow. + """ + + min_loss_scale: int = 1 + """ + Minimum dynamic loss scale value. + """ + + fp16_master_weights_and_grads: bool = False + """ + Maintain master weights in optimizer state as fp16 instead of fp32 (valid with DeepSpeedCPUAdam only). + """ + + def initial_dynamic_scale(self): + return 2**self.initial_scale_power + + def dynamic_loss_scale_args(self): + return { + INITIAL_LOSS_SCALE: 2**self.initial_scale_power, + SCALE_WINDOW: self.loss_scale_window, + DELAYED_SHIFT: self.hysteresis, + CONSECUTIVE_HYSTERESIS: self.consecutive_hysteresis, + MIN_LOSS_SCALE: self.min_loss_scale, + } diff --git a/deepspeed/runtime/sequence_parallel/__init__.py b/deepspeed/runtime/sequence_parallel/__init__.py new file mode 100644 index 000000000000..d8cb728da375 --- /dev/null +++ b/deepspeed/runtime/sequence_parallel/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) The DeepSpeed Contributors +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/runtime/sequence_parallel/parallel_state_sp.py b/deepspeed/runtime/sequence_parallel/parallel_state_sp.py new file mode 100644 index 000000000000..806a58a23ff4 --- /dev/null +++ b/deepspeed/runtime/sequence_parallel/parallel_state_sp.py @@ -0,0 +1,96 @@ +# Copyright (c) The DeepSpeed Contributors +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +This is a slimmed-down version of parallel_state.py (mpu) from Megatron-Deepspeed +""" + +from deepspeed import comm as dist + +# Sequence parallel groups to handle both data and sequence parallelisms. +# These groups are used to reduce gradients and shard parameters and optimizer stages for ZeRO. +_SEQUENCE_PARALLEL_GROUP = None +_SEQUENCE_DATA_PARALLEL_GROUP = None + + +def initialize_sequence_parallel(sequence_parallel_size: int) -> None: + """Initialize sequence parallel groups.""" + + assert dist.is_initialized() + world_size: int = dist.get_world_size() + + if world_size < sequence_parallel_size: + raise RuntimeError(f"world_size ({world_size}) is less than sequence_parallel_size {sequence_parallel_size}") + + if sequence_parallel_size <= 1: + raise ValueError(f"sequence_parallel_size must be greater than 1, got {sequence_parallel_size}") + + if world_size % sequence_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by sequence_parallel_size {sequence_parallel_size})") + + data_parallel_size: int = world_size // sequence_parallel_size + sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size + + rank = dist.get_rank() + + # Build the sequence parallel groups. + global _SEQUENCE_PARALLEL_GROUP + assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized" + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + group = dist.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + + # Build the sequence data parallel groups. + global _SEQUENCE_DATA_PARALLEL_GROUP + assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized" + all_data_sequence_parallel_group_ranks = [] + for i in range(num_sequence_data_parallel_groups): + ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size) + group = dist.new_group(ranks) + all_data_sequence_parallel_group_ranks.append(list(ranks)) + if rank in ranks: + _SEQUENCE_DATA_PARALLEL_GROUP = group + + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized" + return _SEQUENCE_PARALLEL_GROUP + + +def get_sequence_data_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized" + return _SEQUENCE_DATA_PARALLEL_GROUP + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return dist.get_world_size(group=get_sequence_parallel_group()) + + +def get_sequence_data_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return dist.get_world_size(group=get_sequence_data_parallel_group()) + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + return dist.get_rank(group=get_sequence_parallel_group()) + + +def get_sequence_data_parallel_rank(): + """Return my rank for the sequence data parallel group.""" + return dist.get_rank(group=get_sequence_data_parallel_group()) + + +# since we only have 1 additional dimension over DP, we can just alias MP with SP +get_model_parallel_rank = get_sequence_parallel_rank +get_model_parallel_world_size = get_sequence_parallel_world_size +get_model_parallel_group = get_sequence_parallel_group diff --git a/deepspeed/runtime/sequence_parallel/ulysses_sp.py b/deepspeed/runtime/sequence_parallel/ulysses_sp.py new file mode 100644 index 000000000000..413921c2090c --- /dev/null +++ b/deepspeed/runtime/sequence_parallel/ulysses_sp.py @@ -0,0 +1,1593 @@ +# Copyright (c) The DeepSpeed Contributors +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +*** Arctic Long Sequence Training (ALST) components *** + +1. Ulysses Sequence Parallelism for HF Transformers implements an efficient way of training on long sequences by employing sequence parallelism and attention head parallelism. +2. ALST enables even longer sequence lengths using a bag of tricks: +- Activation checkpoint offload to CPU +- Tiled MLP compute +- Liger-kernel +- PYTORCH_CUDA_ALLOC_CONF + +ALST features found in this module: + +- `UlyssesSPAttentionHF` - port of UlyssesAttention from Megatron-Deepspeed plus modern MHA-variations +- `UlyssesSPDataLoaderAdapter` - DL adapter to shard the normal DL batches to be used by `UlyssesSPAttentionHF` +- `SequenceTiledCompute` - generic autograd function to perform compute after tiling on the sequence dimension +- `TiledMLP` - a specific autograd function to perform tiled MLP (it's much easier to understand before trying to grok `SequenceTiledCompute`) +- `TiledFusedLogitsLoss` - a specific autograd function to perform loss computation without manifesting the full logits tensor and instead computing loss on shards of logits. + +This module implements Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences: https://arxiv.org/abs/2506.13996 + +For integration docs see: https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/ + +The other ALST features live inside +https://github.com/snowflakedb/ArcticTraining/blob/main/projects/sequence-parallelism/ + +""" + +from collections import defaultdict, deque +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.sequence.layer import _DimZeroAllToAll +from deepspeed.utils.logging import logger +from einops import rearrange +from packaging import version +from torch import Tensor +from torch.utils.data import DataLoader +from typing import Any +from typing import Tuple +import deepspeed.comm as dist +import importlib.metadata +import math +import re +import torch +import torch.distributed.nn + + +class UlyssesSPAttentionHF(torch.nn.Module): + """Re-Implementation of deepspeed.sequence.layer.DistributedAttention. This implementation enforces the input shape + to be standard [sl, bs, hc, hs] form. Any deviation from this shape will raise an error. + + The primary reason for the re-implementation is to make this less error prone, and remove what seemed like bugs in scenarios where batch size > 1 and when using different versions of + flash attention each of which takes different input shape. Those should be handled by + the actual attn implementation, and not by this module. + + This class then has been further adapted to work with HF Transformers' supported attention mechanism. + + Dimension annotation: + bs = bs + hc = head count + hc_l = head count local + hs = head_size + sl = seqlen + sl_l = seqlen local + ws = world_size + em = embedding (hidden size) + em_l = embedding (hidden size) local + + Arguments: + attn: normal attention implementation from transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS + seq_length_is_variable (bool): whether global seqlen may change between batches + local_seq_length (int): local sequence length per GPU or None if seq_length_is_variable is True + global_seq_length (int): actual sequence length or None if seq_length_is_variable is True + batch_size (int): batch size + attn_head_size (int): size of each attention head + attn_head_count (int): total number of attention heads + kv_head_count (int): total number of kv heads + num_hidden_layers (int): total number of layers + process_group (dist.ProcessGroup): Ulysses process group + disable_in_eval (bool): whether to disable sequence parallelism during evaluation (default: False). + When True, SP operations are bypassed during eval to avoid potential issues with frameworks + like HF Trainer that may run eval with different data distribution. + + + Extras: + - set self.skip_all_but_last_attention_debug_mode to True to enable fast debug which will skip calling all core attn layers but the last one, it will produce garbage of course quality-wise. + """ + + def __init__( + self, + attn, + batch_size: int, + attn_head_count: int, + attn_head_size: int, + kv_head_count: int, + num_hidden_layers: int, + process_group: dist.ProcessGroup, + seq_length_is_variable: bool = False, + local_seq_length: int = None, + global_seq_length: int = None, + disable_in_eval: bool = False, + ) -> None: + super().__init__() + self.attn = attn + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.sp_rank = dist.get_rank(process_group) + + self.batch_size = batch_size + self.seq_length_is_variable = seq_length_is_variable + self.local_seq_length = local_seq_length + self.global_seq_length = global_seq_length + self.disable_in_eval = disable_in_eval + + self.attn_head_size = attn_head_size + self.attn_head_count = attn_head_count + self.global_kv_head_count = kv_head_count + + self.num_hidden_layers = num_hidden_layers + self.skip_all_but_last_attention_debug_mode = False + self.rotating_layer_counter = 0 # used for dev work + + self.core_attn_implementation = None # set by register_with_transformers + self._flex_block_mask_cached = None # cached BlockMask for flex_attention + self._flex_block_mask_cache_key = None # (batch_size, seq_len) for cache invalidation + + self.local_q_head_count = attn_head_count // self.world_size + + # if we have 4 kv heads and sp 8, we need to replicate kv heads 2x + self.kv_replication_factor = self.world_size // kv_head_count + if self.kv_replication_factor > 1: + self.local_kv_head_count = 1 + else: + self.local_kv_head_count = kv_head_count // self.world_size + + transformers_version_min = "4.51.3" + transformers_version_have = importlib.metadata.version("transformers") + if version.parse(transformers_version_have) < version.parse(transformers_version_min): + raise ValueError( + f"transformers>={transformers_version_min} is required, but you have transformers=={transformers_version_have}" + ) + + if self.attn_head_count % self.world_size != 0: + raise ValueError(f"Attention head count {attn_head_count} is not divisible by SP size {self.world_size}") + if not (self.global_kv_head_count % self.world_size == 0 or self.world_size % self.global_kv_head_count == 0): + raise ValueError( + f"KV attention head count {self.global_kv_head_count} is not divisible by SP size {self.world_size} or" + " vice versa") + + if self.seq_length_is_variable: + # the self.required_*_shape depending on the following will get updated in `forward` + # use 1 as a placeholder for dim=0 to keep torch.Size happy + local_seq_length = 1 + global_seq_length = 1 + + # [sl_l bs hc hs] + self.required_query_shape = torch.Size([local_seq_length, batch_size, attn_head_count, attn_head_size]) + self.required_key_value_shape = torch.Size([local_seq_length, batch_size, kv_head_count, attn_head_size]) + + # [sl bs em_l] + self.required_context_shape = torch.Size( + [global_seq_length, batch_size, attn_head_size * attn_head_count // self.world_size]) + + def _combine_local_sequences(self, query, key, value) -> Tuple[Tensor, Tensor, Tensor]: + + def combine_sequence(input, head_type): + """ + expects inputs in shape: [sl_l bs hc hs] + returns output in shape: [sl bs hc_l hs] + + local_head_count could be different for k,v vs q if it's not an MHA situation + """ + if head_type == "q": + local_head_count = self.local_q_head_count + else: # kv + local_head_count = self.local_kv_head_count + + # MQA and some GQA cases: + if self.kv_replication_factor > 1: + # local_head_count *= self.kv_replication_factor + # replicate heads to the kv_replication_factor on hc dimension [sl_l bs hc hs] - so dim=2 + input = input.repeat_interleave(self.kv_replication_factor, dim=2) + + # [sl_l bs hc hs] -> [sl_l bs ws hc_l hs] + input = input.reshape( + [self.local_seq_length, self.batch_size, self.world_size, local_head_count, self.attn_head_size]) + + input = rearrange(input, "sl_l bs ws hc_l hs -> ws sl_l bs hc_l hs").contiguous() + + output = _DimZeroAllToAll.apply(self.process_group, input) + + # [ws sl_l bs hc_l hs] -> [sl bs hc_l hs] + output = output.reshape([self.global_seq_length, *output.shape[2:]]).contiguous() + + # [sl bs hc_l hs] + return output + + return ( + combine_sequence(query, head_type="q"), + combine_sequence(key, head_type="kv"), + combine_sequence(value, head_type="kv"), + ) + + def _partition_global_sequence(self, input) -> Tensor: + """ + expects input in shape: [sl bs em_l] + returns output in shape: [sl_l bs em] + """ + + # [sl bs em_l] -> [ws sl_l bs em_l] + input = input.reshape([ + self.world_size, + self.local_seq_length, + self.batch_size, + self.attn_head_size * self.attn_head_count // self.world_size, + ]).contiguous() + + output = _DimZeroAllToAll.apply(self.process_group, input) + output = rearrange(output, "ws sl_l bs em_l -> sl_l bs ws em_l") + + # [sl_l bs ws em_l] -> [sl_l bs em] + output = output.reshape([*output.shape[:2], -1]).contiguous() + + # [sl_l bs em] + return output + + def forward( + self, + module: torch.nn.Module, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Tensor, + *args: Any, + **kwargs: Any, + ) -> Tensor: + """forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + attention_mask (Tensor): Attention mask + args: other args + + Returns: + * output (Tensor): context output + """ + # HF incoming shapes are: + # [batch_size, num_heads, seqlen, head_size] + # UlyssesSPAttentionHF expects: + # [seqlen, batch_size, num_heads, head_size] + # print_rank0(f"{query.shape=}") + # print_rank0(f"{key.shape=}") + # print_rank0(f"{value.shape=}") + # print_rank0(f"{self.required_input_shape=}") + + # Skip SP operations during eval if disable_in_eval is True + # This avoids issues with frameworks like HF Trainer that may run eval with different data distribution + if not module.training and self.disable_in_eval: + return self.attn(module, query, key, value, attention_mask, *args, **kwargs) + + if self.seq_length_is_variable: + current_local_seq_length = query.shape[2] + self.local_seq_length = current_local_seq_length + self.global_seq_length = current_local_seq_length * self.world_size + # update the required seqlen shapes + self.required_query_shape = torch.Size([self.local_seq_length] + list(self.required_query_shape)[1:]) + self.required_key_value_shape = torch.Size([self.local_seq_length] + + list(self.required_key_value_shape)[1:]) + self.required_context_shape = torch.Size([self.global_seq_length] + list(self.required_context_shape)[1:]) + + # make the blocks contiguous as early as possible to minimize fragmentation + query = rearrange(query, "bs hc sl hs -> sl bs hc hs") # .contiguous() + key = rearrange(key, "bs hc sl hs -> sl bs hc hs") # .contiguous() + value = rearrange(value, "bs hc sl hs -> sl bs hc hs") # .contiguous() + + # All attention backends need unsharded position_ids after the all-to-all. + # FA2 uses them for packed-sequence detection (flash_varlen_fn), sdpa/flex_attention + # need them to be monotonically increasing so causal masking works correctly. + # UlyssesSPDataLoaderAdapter ensures position_ids are in the batch before sharding, + # so after gathering here they reconstruct to the correct global positions. + assert "position_ids" in kwargs, ( + "Ulysses SP requires position_ids in every forward() call so that after all_gather " + "causal masking works correctly. Without them each rank generates local [0..chunk_len-1] " + "positions which, after gathering, look like packed sequences and break attention. " + "For non-packed sequences: position_ids = torch.arange(seq_len) per sample. " + "For packed sequences: position_ids must reset at document boundaries. " + "Ensure your data collator or UlyssesSPDataLoaderAdapter includes position_ids.") + position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)] + dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group) + kwargs["position_ids"] = torch.cat(position_ids_list, dim=1) + + # please don't remove the white-space vertical alignment in the error message + assert query.shape == self.required_query_shape, ( + f"[{dist.get_rank()}]: query input tensor does not match the required shape\n " + f" {self.required_query_shape}:\n {query.shape=}\n {key.shape=}\n {value.shape=}") + assert key.shape == value.shape == self.required_key_value_shape, ( + f"[{dist.get_rank()}]: key or value input tensor does not match the required shape\n " + f" {self.required_key_value_shape}:\n {query.shape=}\n {key.shape=}\n {value.shape=}") + + # expects: [sl_l bs hc hs] + query_layer, key_layer, value_layer = self._combine_local_sequences(query, key, value) + # returns: [sl bs hc_l hs] + + query_layer = rearrange(query_layer, "sl bs hc_l hs -> bs hc_l sl hs").contiguous() + key_layer = rearrange(key_layer, "sl bs hc_l hs -> bs hc_l sl hs").contiguous() + value_layer = rearrange(value_layer, "sl bs hc_l hs -> bs hc_l sl hs").contiguous() + + # crucial in the case of MQA and some GQA cases we need to fix `module.num_key_value_groups` + # XXX: could move this somewhere to do it only once per run + if self.kv_replication_factor > 1: + module.num_key_value_groups = query_layer.size(-3) // key_layer.size(-3) + + # For flex_attention: the wrapper preserved the BlockMask from the model, but it + # was built for the local shard's sequence length. Rebuild it for the full gathered + # sequence length after the all-to-all. + # XXX: currently hardcodes a causal mask_mod — models with sliding window or other + # non-standard patterns would need the mask_mod extracted from the original BlockMask. + if self.core_attn_implementation == "flex_attention": + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + if isinstance(attention_mask, BlockMask): + seq_len = query_layer.shape[2] + batch_size = query_layer.shape[0] + cache_key = (batch_size, seq_len) + + # Cache the BlockMask — create_block_mask is expensive and the mask is the + # same for all layers within a forward pass. Only rebuild when dimensions change. + if self._flex_block_mask_cache_key != cache_key: + + def causal_mask(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + # Don't compile create_block_mask here — it runs inside the model's + # forward pass where flex_attention already uses torch.compile, and + # nesting compiled contexts causes gradient explosion in the backward + # pass. The BlockMask is cached so creation cost is negligible. + self._flex_block_mask_cached = create_block_mask( + mask_mod=causal_mask, + B=batch_size, + H=None, + Q_LEN=seq_len, + KV_LEN=seq_len, + device=query_layer.device, + ) + self._flex_block_mask_cache_key = cache_key + + attention_mask = self._flex_block_mask_cached + + if not self.skip_all_but_last_attention_debug_mode: + # expects: [bs hc_l sl hs] + context_layer, attn_weights = self.attn(module, query_layer, key_layer, value_layer, attention_mask, *args, + **kwargs) + # returns [bs sl hc_l hs] + else: + # we need this hack during development in order to be able to check memory fitting w/o + # waiting for 3h to compute 1.5M seqlen attention, because it's quadratic in dense + # attention, so we skip all but the last core attention call - we want the last one to + # still get the memory usage approximately close to the real memory usage. of course + # the loss will be wrong when we do that. + self.rotating_layer_counter = (self.rotating_layer_counter + 1) % self.num_hidden_layers + # we detect the last layer by module counting since we know how many layers there are + if self.rotating_layer_counter % self.num_hidden_layers == 0: + # do the real pass + context_layer, attn_weights = self.attn(module, query_layer, key_layer, value_layer, attention_mask, + *args, **kwargs) + else: + # this feeds bogus data of the right shape - good enough for quick debug + context_layer = rearrange(query_layer, "bs hc_l sl ... -> bs sl hc_l ...") + attn_weights = None + + # [bs sl hc_l hs] -> [sl bs hc_l hs]' + context_layer = rearrange(context_layer, "bs sl ... -> sl bs ...") + context_layer = context_layer.reshape([*context_layer.shape[:2], -1]) + + assert ( + context_layer.shape == self.required_context_shape + ), f"The context shape {context_layer.shape} is not of the expected shape {self.required_context_shape}" + + # expects: [sl bs em_l] + output = self._partition_global_sequence(context_layer) + # returns: [sl_l bs em] + + output = rearrange(output, "sl_l bs ... -> bs sl_l ...") + + output = output.reshape([*output.shape[:2], -1]) + + # expects [bs sl em] + return output, attn_weights + + @classmethod + def register_with_transformers( + cls, + model_name_or_path, + core_attn_implementation, + sequence_parallel_size, + micro_batch_size, + seq_length=None, + seq_length_is_variable=True, + disable_in_eval=False, + # deprecated + max_length=None, + ): + """ + Register "ulysses" attn_implementation with HF transformers and return mpu (Megatron-LM-style parallel state groups object). + If sequence_parallel_size==1 do nothing and return None. + + Args: + - model_name_or_path (object or str): model object, or HF hub model name, or model's local path + - core_attn_implementation (str): which attention to use: flash_attention_2 or flash_attention_3 or sdpa + - sequence_parallel_size (int): sequence parallelism dimension (if 1 it's disabled) + - micro_batch_size (int): micro batch size + - seq_length (int): set this argument if the sequence length is fixed in all batches + - seq_length_is_variable (bool): whether global seqlen may change between batches an optimization flag - the default is `True` + - disable_in_eval (bool): whether to disable sequence parallelism during evaluation (default: False). + When True, SP operations are bypassed during eval to avoid issues with frameworks + like HF Trainer that may run eval with different data distribution. + - max_length (int): actual global sequence length - this argument is deprecated - use `seq_length` instead + + """ + if sequence_parallel_size == 1: + return None + + if max_length is not None: + logger.warning( + "The 'max_length` argument is deprecated and will be eventually removed, please use `seq_length` instead" + ) + if seq_length is None and max_length is not None: + seq_length = max_length + if not seq_length_is_variable and seq_length is None: + raise ValueError( + "Either `seq_length_is_variable` needs to be `True` or `seq_length` needs to be set to an integer value of the fixed batch size length." + ) + + from transformers import AutoConfig + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + import deepspeed.runtime.sequence_parallel.parallel_state_sp as mpu + + mpu.initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size) + + from transformers import PreTrainedModel + if hasattr(model_name_or_path, "config") or isinstance(model_name_or_path, PreTrainedModel): + # we already have the model (or a PEFT wrapper with config attribute) + hf_model_config = model_name_or_path.config + else: + # if we don't have the model yet at this stage + hf_model_config = AutoConfig.from_pretrained(model_name_or_path) + + model_attn_implementation = getattr(hf_model_config, "_attn_implementation", None) + if model_attn_implementation is not None and model_attn_implementation != core_attn_implementation: + raise ValueError( + f"core_attn_implementation='{core_attn_implementation}' does not match " + f"model config attn_implementation='{model_attn_implementation}'. " + "Set both to the same value so sequence-parallel wrapper can intercept the active attention path.") + + # eager always materializes a 4D attention_mask (O(n²) memory) and cannot fall back + # to is_causal=True like sdpa — so it's incompatible with SP which discards masks. + unsupported_attn_implementation = ["eager", "paged|eager"] + if core_attn_implementation in unsupported_attn_implementation: + raise ValueError( + f"{core_attn_implementation} attn_implementation isn't currently supported by Ulysses sequence" + f" parallelism because it requires a 4D attention_mask (O(n²) memory)." + f" Use any flash attention variant, 'flex_attention', 'sdpa'," + f" or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2').") + + # Hub kernels (e.g. kernels-community/flash-attn2) are registered lazily in transformers. + # Ensure registration happens before validating against ALL_ATTENTION_FUNCTIONS. + is_hub_kernel_attn = (isinstance(core_attn_implementation, str) and re.search( + r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", core_attn_implementation) is not None) + if is_hub_kernel_attn: + try: + from transformers.modeling_flash_attention_utils import lazy_import_flash_attention + except ImportError as e: + raise ImportError("Hub kernel attention requires a transformers version exposing " + "`transformers.modeling_flash_attention_utils.lazy_import_flash_attention`.") from e + lazy_import_flash_attention(core_attn_implementation) + + if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS: + raise ValueError( + f"{core_attn_implementation} is not a valid attn_implementation. The choices are {ALL_ATTENTION_FUNCTIONS.valid_keys()}" + ) + core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation] + + if seq_length_is_variable: + local_seq_length = None + global_seq_length = None + else: + local_seq_length = seq_length // mpu.get_sequence_parallel_world_size() + global_seq_length = seq_length + + arch_cfg = hf_model_config.get_text_config() + + uattn = UlyssesSPAttentionHF( + attn=core_attn_function, + batch_size=micro_batch_size, + attn_head_count=arch_cfg.num_attention_heads, + attn_head_size=getattr( + arch_cfg, + "head_dim", + arch_cfg.hidden_size // arch_cfg.num_attention_heads, + ), + kv_head_count=arch_cfg.num_key_value_heads, + num_hidden_layers=arch_cfg.num_hidden_layers, + process_group=mpu.get_sequence_parallel_group(), + seq_length_is_variable=seq_length_is_variable, + local_seq_length=local_seq_length, + global_seq_length=global_seq_length, + disable_in_eval=disable_in_eval, + ) + uattn.core_attn_implementation = core_attn_implementation + + def uattn_wrapper( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + # SP relies on position_ids (not attention_mask) for causal masking. + # HF doesn't know about the SP wrapper, so it creates an attention_mask for + # the local shard's sequence length — which is invalid after the SP all-to-all + # gathers the full sequence. A 4D mask at full sequence length would also be + # O(n²) memory. So we discard 4D tensor masks. + # + # Keep BlockMask (flex_attention) — it's a compressed sparse representation. + # It will be rebuilt for the full gathered sequence in forward(). + _is_block_mask = False + if core_attn_implementation == "flex_attention": + from torch.nn.attention.flex_attention import BlockMask + _is_block_mask = isinstance(attention_mask, BlockMask) + + if not _is_block_mask: + attention_mask = None + + attn_output, attn_weights = uattn( + module, + query, + key, + value, + attention_mask, + *args, + **kwargs, + ) + return attn_output, attn_weights + + # We don't do: ALL_ATTENTION_FUNCTIONS.register("ulysses", uattn_wrapper) + # The problem with that approach is that we'd miss all the special-case branches in + # HF Transformers that check `if self.config._attn_implementation == "flash_attention_2": ...` + # So instead we override the requested core implementation key in ALL_ATTENTION_FUNCTIONS + # with our wrapper. All other code paths relying on the original core attn_implementation + # will still be executed — we only intercept at the point of calling attention. + # This is what we called "Being John Malkovich". + ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper + + return mpu + + +class UlyssesSPDataLoaderAdapter: + + def __init__( + self, + dl: DataLoader, + sp_rank: int, + sp_group, + sp_world_size, + device, + ): + """ + This a DataLoader adapter which wraps around any existing DataLoader. It is used in conjunction with Ulysses to perform batch sharding on the sequence dimension. + + It gathers 1 sample from each participating rank, using the DL it wraps, then shards each of them and sends back to the ranks. So that when dl->iter->next is called, we end up with: + - rank 0: getting batch 0 shard 0 + - rank 1: getting batch 0 shard 1 + ... + - rank n: getting batch 0 shard n + which is used to compute the batch (from rank0) using all SP ranks. + + When the next iteration starts and dl->iter->next is called, we end up with: + - rank 0: getting batch 1 shard 0 + - rank 1: getting batch 1 shard 1 + ... + - rank n: getting batch 1 shard n + which is used to compute a second batch (from rank1) using all SP ranks. + + This continues until SP iterations are performed. At this point we need to get more data and so the above repeats. + + The key thing to understand is that all SP ranks participate in processing a single DL sample. So instead of normal DataParallel we perform a sort of SP over DP. + + When SP number of iterations is completed it's an equivalent of performing a single iteration with normal DP. + + If more tokens need to be consumed per step use the gradient accumulation feature. + + Ulysses expects the following dict keys in each DL batch (`dl->iter->next`): + - `input_ids` + - `position_ids` + - `labels` + + Additional entries can be present. + + The tensors are expected to be of shape: `[batch_size, seqlen, ...]` + + The sharding happens on the seqlen (1st) dimension for all tensors in the batch, any non-tensor entries get copied to all ranks. + + `attention_mask` isn't used by Ulysses, because it's typically too large when it's 4D, and position_ids is just 1D, therefore it's much much smaller and consumes little GPU memory. + + Arguments: + - `dl`: an existing DataLoader object to wrap + - `sp_rank`: SP rank + - `sp_group`: SP group + - `sp_world_size`: SP world size + - `device`: cuda device + + Returns: + Another DataLoader object + """ + + self.dl = dl + self.sp_rank = sp_rank + self.sp_group = sp_group + self.sp_world_size = sp_world_size + self.device = device + + self.iter = iter(dl) + self.micro_batches: deque[Any] = deque() + + def __len__(self): + return len(self.dl) * self.sp_world_size + + def __iter__(self): + return self + + def __next__(self): + if len(self.micro_batches) == 0: + self.refill() + + return self.micro_batches.popleft() + + def refill(self): + # reset the iterator if StopIteration arrives, and re-raise it to allow multiple epochs to run + try: + batch = next(self.iter) + except StopIteration: + self.iter = iter(self.dl) + raise StopIteration + micro_batches = defaultdict(dict) + # XXX: replace with more efficient all-to-all? + + # position_ids must exist before sharding so that after all_gather in + # UlyssesSPAttentionHF.forward() they reconstruct to correct global positions. + # Without them, the Trainer generates local [0,...,chunk_len-1] per rank AFTER + # sharding, which after all_gather looks like packed sequences and breaks + # sdpa/flex_attention causal masking. + if "position_ids" not in batch: + raise ValueError("Ulysses SP requires `position_ids` in every dataloader batch so that " + "each token retains its correct global position after sequence sharding. " + "For non-packed sequences: position_ids = torch.arange(seq_len) per sample. " + "For packed sequences: position_ids must reset at document boundaries. " + "Ensure your data collator includes position_ids in its output.") + + # we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank + seqlen = torch.tensor(batch["input_ids"].shape[1], dtype=torch.int64, device=self.device) + seqlens = [torch.zeros(1, dtype=torch.int64, device=self.device) for _ in range(self.sp_world_size)] + dist.all_gather(seqlens, seqlen, group=self.sp_group) + seqlens = [x[0].item() for x in seqlens] + + for k in batch.keys(): + if torch.is_tensor(batch[k]): + batch[k] = batch[k].to(self.device) + if seqlen != batch[k].shape[1]: + raise ValueError( + f"{k}'s shape {batch[k].shape} must match input_ids's shape {batch['input_ids'].shape}") + with torch.no_grad(): + tensor_list = [ + torch.zeros((batch[k].shape[0], seqlens[i]), dtype=batch[k].dtype, device=batch[k].device) + for i in range(self.sp_world_size) + ] + dist.all_gather(tensor_list, batch[k], group=self.sp_group) + else: + tensor_list = [None for _ in range(self.sp_world_size)] + dist.all_gather_object(tensor_list, batch[k], group=self.sp_group) + + for rank, tensor in enumerate(tensor_list): + micro_batches[rank][k] = tensor + + del tensor_list + del batch + + for batch in micro_batches.values(): + seq_length = len(batch["input_ids"][0]) + + if seq_length % self.sp_world_size != 0: + raise ValueError(f"batch's seqlen={seq_length} isn't divisible by sp-size={self.sp_world_size}") + chunk_len = seq_length // self.sp_world_size + + # because we have to gather logits from all sp ranks we have to do the loss function ourselves + # therefore remove labels to avoid an attempt to calculate loss by transformers + labels = batch.pop("labels") + labels = torch.nn.functional.pad(labels, (0, 1), value=-100) + batch["shift_labels"] = labels[..., 1:].contiguous() + # free up temp memory + del labels + + # batch sharding + for k in batch.keys(): + # leave non-tensors alone + if not torch.is_tensor(batch[k]): + continue + # at seqlen>10M and 32+ gpus this can take GBs of memory so keep the prefill buffer on cpu + batch[k] = batch[k][:, chunk_len * self.sp_rank:chunk_len * (self.sp_rank + 1)].cpu() + + self.micro_batches.append(batch) + + +def sequence_tiled_compute( + fn, + seqlen, + shards, + kwargs_to_shard, + kwargs_to_pass, + grad_requiring_tensor_key, + compute_params=None, + output_unshard_dimension=1, + output_reduction="mean", +): + """ + This is a wrapper for SequenceTiledCompute which we need since torch.autograd.Function can't work with dicts of tensors (in backward it has to return a grad value and not a dict that may have a non-None grad value). It's also useful for setting default values which we can't do either in torch.autograd.Function. + + Args: + - `fn`: the function to call on sharded inputs + - `seqlen`: total seqlen of the seqlen dimension + - `shards`: how many shards to use + - `kwargs_to_shard`: this dict will be passed to `fn` as `**kwargs` after sharding on seqlen dimension + - `kwargs_to_pass`: this dict will be passed to `fn` as is, as `**kwargs` + - `grad_requiring_tensor_key`: which main key requires grads + - `compute_params`: a list of weights engaged in the compute. Default: `None` (only needed when using DeepSpeed ZeRO) + - `output_reduction`: None, "mean" or "sum": Default: "mean" + - `output_unshard_dimension`: the dimension to concat the outputs on: Default: 1 (seqlen dim) + + Returns: + - unsharded output with an optional reduction applied, depending on the `output_reduction` value: + `None` - return the unsharded output tensor + `"mean"` - apply mean + `"sum"` - apply sum + + Please note that this implementation doesn't require DeepSpeed and can work without it. `compute_params` can remain `None` in such a case. + + """ + args_to_shard = kwargs_to_shard.values() + keys_to_shard = list(kwargs_to_shard.keys()) + args_to_pass = kwargs_to_pass.values() + keys_to_pass = list(kwargs_to_pass.keys()) + + return SequenceTiledCompute.apply( + fn, + seqlen, + shards, + keys_to_shard, + keys_to_pass, + grad_requiring_tensor_key, + compute_params, + output_unshard_dimension, + output_reduction, + *args_to_shard, + *args_to_pass, + ) + + +class SequenceTiledCompute(torch.autograd.Function): + """ + A generic autograd function to perform a tiled compute. + + Please note this module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration. And if you're using activation checkpointing it then occurs trice. + + Please note that this implementation doesn't require DeepSpeed and can work without it. `compute_params` can remain `None` in such a case. + + For an easier to understand example see TiledMLP - which is the same as this autograd function but without the generalization code. + """ + + @staticmethod + def forward( + ctx, + fn, + seqlen, + shards, + keys_to_shard, + keys_to_pass, + grad_requiring_tensor_key, + compute_params, + output_unshard_dimension, + output_reduction, + *args, + ) -> torch.Tensor: + """ + for args and return values see `sequence_tiled_compute`'s doc + + Currently we assume that all kwargs_to_shard values have a shape of `[bs, seqlen, ...]` and we shard on seqlen dimension + """ + ctx.fn = fn + ctx.seqlen = seqlen + ctx.shards = shards + ctx.grad_requiring_tensor_key = grad_requiring_tensor_key + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.output_unshard_dimension = output_unshard_dimension + ctx.output_reduction = output_reduction + + with torch.no_grad(): + args = list(args) + ctx.total_args = len(args) + ctx.grad_requiring_tensor_key_index = (keys_to_shard + keys_to_pass).index(grad_requiring_tensor_key) + + kwargs_to_shard = {k: args.pop(0) for k in keys_to_shard} + kwargs_to_pass = {k: args.pop(0) for k in keys_to_pass} + ctx.kwargs_to_shard = kwargs_to_shard + ctx.kwargs_to_pass = kwargs_to_pass + + with torch.no_grad(): + shard_step = math.ceil(seqlen / shards) + output_shards = [] + + for i in range(shards): + output = fn( + **{ + k: v[:, i * shard_step:(i + 1) * shard_step] + for k, v in kwargs_to_shard.items() + }, + **kwargs_to_pass, + ) + output_shards.append(output) + + if output_unshard_dimension == 0: + # this is just the shape=[1] loss use-case, not sure if it's generic enough + output_unsharded = torch.cat([l.unsqueeze(0) for l in output_shards], dim=output_unshard_dimension) + else: + output_unsharded = torch.cat(output_shards, dim=output_unshard_dimension) # .clone().detach() + + if output_reduction is None: + return output_unsharded + elif output_reduction == "mean": + return output_unsharded.mean() + elif output_reduction == "sum": + return output_unsharded.sum() + else: + raise ValueError(f"unknown value {output_reduction}: valid values are: none/mean/sum") + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + fn = ctx.fn + shards = ctx.shards + kwargs_to_shard = ctx.kwargs_to_shard + kwargs_to_pass = ctx.kwargs_to_pass + output_reduction = ctx.output_reduction + + grad_requiring_tensor_key = ctx.grad_requiring_tensor_key + grad_requiring_tensor_key_index = ctx.grad_requiring_tensor_key_index + compute_params = ctx.compute_params + output_unshard_dimension = ctx.output_unshard_dimension + grad_requiring_tensor = kwargs_to_shard[grad_requiring_tensor_key] + + grad_requiring_tensor_requires_grad = grad_requiring_tensor.requires_grad + grad_requiring_tensor = grad_requiring_tensor.detach() + # detach() unsets `grad_requiring_tensor.requires_grad`, so restore it + grad_requiring_tensor.requires_grad_(grad_requiring_tensor_requires_grad) + + incoming_grad = grads[0] + # since we perform a reduction of outputs that doesn't get included in `autograd.backward` below we need to pre-adjust the incoming gradient. in the case of "sum" the gradient is 1.0, in the case of "mean" it's 1.0/num_elements, which in this case is 1/shards. + if output_reduction == "mean": + incoming_grad /= shards + + if grad_requiring_tensor.shape[0] == 1: + grad_requiring_tensor_grad = torch.zeros_like(grad_requiring_tensor) + else: + grad_requiring_tensor_grad = torch.empty_like(grad_requiring_tensor) + + kwargs_to_shard_shards = {k: list(torch.chunk(v, chunks=shards, dim=1)) for k, v in kwargs_to_shard.items()} + + for i in range(shards): + # when fn involves one or more model weights deepspeed will normally push a grad to + # reduce per sub-module call, so since we only want it to add a grad for the last + # shard's call, we signal to ZeRO not to add new gradients to reduce until the last + # shard when all gradients have been accumulated. An example for such a call is + # `model.lm_head(hidden_states)` + if compute_params is not None: + if i + 1 < shards: + for param in compute_params: + param.ds_grad_is_ready = False + else: + # last shard, can add the grad + for param in compute_params: + param.ds_grad_is_ready = True + + kwargs_to_shard_shard = {k: v[i] for k, v in kwargs_to_shard_shards.items()} + grad_requiring_tensor_shard = kwargs_to_shard_shard[grad_requiring_tensor_key] + + grad_requiring_tensor_shard.requires_grad_(grad_requiring_tensor_requires_grad) + + # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step + shard_step = kwargs_to_shard_shards[grad_requiring_tensor_key][i].shape[1] + shard_offset = i * kwargs_to_shard_shards[grad_requiring_tensor_key][0].shape[1] + + if grad_requiring_tensor.shape[0] == 1: + # on narrow the shard's stride is unaffected with dim0==1 (bs) so we use the most efficient `narrow` alias: + # this will enable gradual population of the pre-allocated + # `grad_requiring_tensor_shard.grad` during `torch.autograd.backward` calls + grad_requiring_tensor_shard.grad = grad_requiring_tensor_grad.narrow( + 1, shard_offset, shard_step).view_as(grad_requiring_tensor_shard) + + with torch.enable_grad(): + output = fn(**kwargs_to_shard_shard, **kwargs_to_pass) + + if output_unshard_dimension == 0: + # loss use-case + torch.autograd.backward(output, incoming_grad) + else: + incoming_grad_shard = (incoming_grad.narrow(1, shard_offset, + shard_step).view_as(grad_requiring_tensor_shard)) + torch.autograd.backward(output, incoming_grad_shard) + + if grad_requiring_tensor.shape[0] > 1: + # this is less efficient than dim0==1 (bs) use case, due to a required copy to fix + # the stride and needing a bit more memory for one shard's grad, since + # narrow(dim=1, ...) while dim0>1 will lead to: + # UserWarning: grad and param do not obey the gradient layout contract. This is not an error, but may impair performance. + # when backward is called. + grad_requiring_tensor_grad.narrow(1, shard_offset, + shard_step).view_as(grad_requiring_tensor_shard).copy_( + grad_requiring_tensor_shard.grad) + + # positional args + grad_outputs = [None] * 9 + # inject the grad for the position of forward input that is grad-requiring + arg_outputs = [None] * ctx.total_args + arg_outputs[grad_requiring_tensor_key_index] = grad_requiring_tensor_grad + + return tuple(grad_outputs + arg_outputs) + + +class TiledMLP(torch.autograd.Function): + """ + Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP when using very long sequence lengths. + + Please note this module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration. And if you're using activation checkpointing it then occurs trice. + + For a general tiled compute implementation that can handle any `forward` see `SequenceTiledCompute`. + + Args: + - fn: the function to call on sharded inputs + - `self`: the MLP nn.Module object + - `x`: the input to MLP.forward (`hidden_states`) + - `shards`: how many shards to use + - compute_params: a list of weights engaged in the compute Default: `None` (only needed when using DeepSpeed ZeRO) + + Returns: + - the computed `hidden_states` + + Here is an example that monkey patches HF Transformers' LLamaMLP: + + def tiled_mlp_forward(self, x): + bs, seqlen, hidden = x.shape + num_shards = math.ceil(seqlen / hidden) + # to avoid deadlocks get all ranks to agree on the same num_shards by using the max value + tensor = torch.tensor(num_shards, device=x.device) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + num_shards = tensor.item() + compute_params = [self.down_proj.weight, self.gate_proj.weight, self.up_proj.weight] + + def mlp_forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return TiledMLP.apply( + mlp_forward, + self, + x, + num_shards, + compute_params, + ) + + # this needs to be done before the model is instantiated + from transformers.models.llama import modeling_llama + modeling_llama.LlamaMLP.forward = tiled_mlp_forward + """ + + @staticmethod + def forward( + ctx, + fn, + self, + x, + shards, + compute_params, + ) -> torch.Tensor: + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=-2) + + return output_unsharded + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + fn = ctx.fn + (x, ) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + + x_requires_grad = x.requires_grad + x = x.detach() + # detach() unsets `x.requires_grad`, so restore it + x.requires_grad_(x_requires_grad) + + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) + hidden_size = x.shape[-1] + x_shape_orig = x.shape + + # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1 + x = x.view(-1, hidden_size) + incoming_grad = grads[0].view(-1, hidden_size) + x_grad = torch.zeros_like(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=0)) + + for i, x_shard in enumerate(x_shards): + # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run + # XXX: DDP, FSDP will need something similar to make it work + if compute_params is not None: + if i + 1 < shards: + for param in compute_params: + param.ds_grad_is_ready = False + else: + # last shard, can add the grad + for param in compute_params: + param.ds_grad_is_ready = True + + x_shard.requires_grad_(x_requires_grad) + + # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + with torch.enable_grad(): + output = fn(self, x_shard) + torch.autograd.backward(output, incoming_grad_shard) + + # unflatten + x_grad = x_grad.view(x_shape_orig) + + return (None, None, x_grad, None, None) + + +class TiledFusedLogitsLoss(torch.autograd.Function): + """ + Perform a tiled loss computation while not manifesting a full logits tensor to massively reduce memory usage. + + Args: + - fn: the function to call on sharded inputs + - `self`: the lm_head module object, often it will be `unwrapped_model.model.lm_head` + - `x`: the input (typically `hidden_states`) - which gets sharded + - `y`: the target (typically `labels` or `shift_labels`) - which gets sharded. + - `mask`: an optional mask. It will be not passed to the `fn` if set to `None`. If not-`None` it'll be sharded with `x` and `y` + - `shards`: how many shards to use + - compute_params: a list of weights engaged in the compute Default: `None` (only needed when using DeepSpeed ZeRO) + - output_reduction: "mean" or "sum". If the unmasked elements in `x` are of different sizes in different shards, it's recommended to use "sum" instead of "mean" and perform the balanced mean to the output. This would be the case if `x` is not evenly divisible by `shards` or if the mask may lead to a different number of unmasked elements. + + Returns: + - the computed `loss` + + Note, that since this autograd function is typically the last one in the call stack, it performs `backward` inside `forward` and compensates for `output_reduction` artificially. This removes the need to re-run `forward` a second time inside `backward` + + For a generic tiled compute implementation that can handle many other types of `forward` see `SequenceTiledCompute`. + + An example: + + def loss_fn(self, x, y): + logits = self.lm_head(x) + return self.cross_entropy_loss(logits.view(-1, self.vocab_size), y.view(-1)) + + x = hidden_states + y = shift_labels + mask = None + shards = 2 + compute_params = [self.lm_head.weight] + output_reduction = "mean" + loss = TiledFusedLogitsLoss.apply( + loss_fn, + self, + x, + y, + mask, + shards, + compute_params, + output_reduction, + ) + + """ + + @staticmethod + def forward( + ctx, + fn, + self, + x, + y, + mask, + shards, + compute_params, + output_reduction, + ) -> torch.Tensor: + + if output_reduction not in ["mean", "sum"]: + raise ValueError(f'unknown reduction {output_reduction}: valid values are: "mean"/"sum"') + if x.dim() < 2: + raise ValueError("x must be at least 2D [batch_size, seq_len, ...]") + if y.dim() < 2: + raise ValueError("y must be at least 2D [batch_size, seq_len, ...]") + if x.shape[:2] != y.shape[:2]: + raise ValueError("x and y batch/seq dims must match") + if mask is not None: + if mask.dim() != 2: + raise ValueError(f"mask must be 2D [batch_size, seq_len], but got {mask.dim()}") + if mask.shape != x.shape[:2]: + raise ValueError(f"mask shape must match x and y batch/seq") + + compute_params = [p for p in compute_params if p.requires_grad] + + x_requires_grad = x.requires_grad + x = x.detach().requires_grad_(x_requires_grad) + + bs, seqlen = x.shape[:2] + + # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1 + x = x.view(-1, *x.shape[2:]) + y = y.view(-1, *y.shape[2:]) + if mask is not None: + mask = mask.view(-1) + incoming_grad = torch.tensor(1.0, dtype=x.dtype, device=x.device) + + # we are faking the incoming gradient, and since we perform a reduction outside of `autograd.backward` below we need to pre-adjust the incoming gradient. in the case of "sum" the gradient is 1.0, in the case of "mean" it's 1.0/num_elements, which in this case is 1/shards. + if output_reduction == "mean": + incoming_grad /= shards + + # XXX: deal with the use case of running in inference mode, where we don't need backward + x_grad = torch.zeros_like(x) if x_requires_grad else None + x_shards = list(torch.chunk(x, chunks=shards, dim=0)) + y_shards = list(torch.chunk(y, chunks=shards, dim=0)) + if mask is not None: + mask_shards = list(torch.chunk(mask, chunks=shards, dim=0)) + + output_shards = [] + for i, (x_shard, y_shard) in enumerate(zip(x_shards, y_shards)): + # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run + # XXX: DDP, FSDP will need something similar to make it work + if compute_params is not None: + if i + 1 < shards: + for param in compute_params: + param.ds_grad_is_ready = False + else: + # last shard, can add the grad + for param in compute_params: + param.ds_grad_is_ready = True + + x_shard.requires_grad_(x_requires_grad) + + # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + args = (self, x_shard, y_shard) + if mask is not None: + args += (mask_shards[i], ) + if x_grad is not None: + x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + with torch.enable_grad(): + output = fn(*args) + output_shards.append(output) + torch.autograd.backward(output, incoming_grad) + else: + output = fn(*args) + output_shards.append(output) + + output_unsharded = torch.cat([l.unsqueeze(0) for l in output_shards], dim=0) + + if output_reduction == "mean": + output = output_unsharded.mean() + elif output_reduction == "sum": + output = output_unsharded.sum() + + # unflatten + if x_grad is not None: + x_grad = x_grad.view(bs, seqlen, *x_grad.shape[1:]) + ctx.save_for_backward(x_grad.detach()) + + return output + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + (x_grad, ) = ctx.saved_tensors + # grads[0] should normally be 1.0 as it should be coming from loss.backward() + if grads[0] != 1.0: + x_grad *= grads[0] + return (None, None, x_grad, None, None, None, None, None, None) + + +class AutogradComputeMLP(torch.autograd.Function): + """ + This is a simplified example to override the normal MLP via an autograd function - then tiling can be added - this simplified version was useful to detect a leak in Deepspeed, so let's keep it. + + Here is an example of performing the monkey patching on LlamaMLP + + def mlp_forward_new(self, x): + + def mlp_forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return AutogradComputeMLP.apply(mlp_forward, self, x) + + from transformers.models.llama import modeling_llama + modeling_llama.LlamaMLP.forward = mlp_forward_new + """ + + @staticmethod + def forward( + ctx, + fn, + self, + x, + ) -> torch.Tensor: + ctx.fn = fn + ctx.self = self + ctx.save_for_backward(x) + + with torch.no_grad(): + return fn(self, x) + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + fn = ctx.fn + (x, ) = ctx.saved_tensors + self = ctx.self + + x1 = x.detach() + x1.requires_grad = x.requires_grad + with torch.enable_grad(): + output = fn(self, x1) + + torch.autograd.backward(output, grads[0]) + return (None, None, x1.grad, None) + + +########################################################### +### below are older versions that some might still want ### +########################################################### + + +class TiledLoss(torch.autograd.Function): + + @staticmethod + def forward(ctx, loss_fn, logits, vocab_size, shift_labels, shards) -> torch.Tensor: + """ + + This is a memory efficient loss autograd function that takes the existing logits and performs loss calculation in shards. + + This one is an SFT-aware version, therefore it takes care of special cases where the whole shard is made of -100 labels and which requires then a special care. + + Note: logits seqlen dimension doesn't have to be divisible by shards, the last shard will be shorter than the rest. The calculating of the number of shards is in the example. + + Here is an example of using it: + + def loss(self, batch) -> torch.Tensor: + batch = to_device(batch, self.device) + shift_labels = batch.pop("shift_labels") + outputs = self.model(**batch, use_cache=False) + logits = outputs.logits + + if all((shift_labels == -100).squeeze()): + # this is the case where all labels in a micro-batch are -100 (very common for SFT if the seqlen is short) - CE returns `nan` in this case, so we don't want to call loss and instead create a differentiable loss `0` which will also set all the grads to `0` in `backward` - the effect of this is akin to a perfect score where the model needs no adjustment since grads will be all zeros. + loss = (logits.sum() * 0.0).float() + + num_shards: Any = "auto" + if num_shards == "auto": + # parameterize to about 1GB fp32 logits shards + slice_size_in_gb = 1 + size_in_gb = logits.numel() * 4 / 2**30 # fp32 + # the sp shard's seqlen sp shard can be easily not divisible by the derived number of chunked loss shards, so we use the uppper ceiling and allow the last chunk to be shorter than the rest + num_shards = math.ceil(size_in_gb / slice_size_in_gb) + # print(f"derived {num_shards} shards for size {size_in_gb}GB") + if num_shards > 1: + # if shards == 1 this will lead to a higher memory usage then calling the normal loss function, so don't do that. + loss = TiledLoss.apply( + self.model_unwrapped.loss_function, + logits, + self.model_unwrapped.config.vocab_size, + shift_labels, + num_shards, + ) + else: + loss = self.model_unwrapped.loss_function( + logits=logits, + labels=None, + vocab_size=self.model_unwrapped.config.vocab_size, + shift_labels=shift_labels, + ) + + return loss + + + """ + ctx.save_for_backward(logits, shift_labels) + ctx.loss_fn = loss_fn + ctx.vocab_size = vocab_size + ctx.shards = shards + + with torch.no_grad(): + seqlen = shift_labels.shape[1] + shard_step = math.ceil(seqlen / shards) + loss_shards = [] + total_good_items = 0 + + # since -100s are ignored we have to perform a weighted average on each loss slice as each slice may contribute a different number of non- -100 labels + # if seqlen / shards != 0 - the last chunk is just shorter than the rest but no data is ignored + for i in range(shards): + # XXX: here and everywhere don't make a copy, pass the slice or perhaps narrow/view? + shift_labels_shard = shift_labels[:, i * shard_step:(i + 1) * shard_step] + if all((shift_labels_shard == -100).squeeze()): + continue # ignore this shard + loss_shard = loss_fn( + logits=logits[:, i * shard_step:(i + 1) * shard_step, :], + labels=None, + vocab_size=vocab_size, + shift_labels=shift_labels_shard, + ) + good_items = sum((shift_labels_shard != -100).squeeze()) + loss_shards.append(loss_shard * good_items) + total_good_items += good_items + total_loss = torch.cat([l.unsqueeze(0) for l in loss_shards], dim=0).sum() + weighted_loss = total_loss / total_good_items + + return weighted_loss + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + logits, shift_labels = ctx.saved_tensors + loss_fn = ctx.loss_fn + vocab_size = ctx.vocab_size + shards = ctx.shards + + grad = grads[0] + logits_grad = torch.zeros_like(logits) + logits_shards = list(torch.chunk(logits, chunks=shards, dim=1)) + shift_labels_shards = list(torch.chunk(shift_labels, chunks=shards, dim=1)) + + # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step + shard_step = logits_shards[0].shape[1] + for i in range(shards): + logits_shard = logits_shards.pop(0) + shift_labels_shard = shift_labels_shards.pop(0) + + shard_offset = i * shard_step + # this will enable gradual population of the pre-allocated `logits_shard.grad` during `torch.autograd.backward` calls + logits_shard.grad = (logits_grad.narrow(1, shard_offset, shard_step).view_as(logits_shard)) + + with torch.enable_grad(): + if all((shift_labels_shard == -100).squeeze()): + # fake loss calculation, since CE will return nan, but grads will be set + # a normal loss_fn upcasts logits to float so match it + loss_shard = (logits_shard.sum() * 0.0).float() + else: + loss_shard = loss_fn( + logits=logits_shard.requires_grad_(), + labels=None, + vocab_size=vocab_size, + shift_labels=shift_labels_shard, + ) + + torch.autograd.backward(loss_shard, grad) + + logits_grad /= shards + + # only logits (2nd arg) needs grads + return None, logits_grad, None, None, None + + +# This is the original implementation/integration of UlyssesSP into the training loop, which was superseded by using UlyssesSPDataLoaderAdapter which did all the sharding and pull the shards from the DL +# +# There are 2 issues with this implementation: +# - it's complex and difficult to integrate into various training scenarios +# - it could lead to a huge number of tokens per step - e.g. 32 ranks of 15M seqlen -> 0.5B token step - which is very wasteful +# +# Therefore if you want to use UlyssesSP via UlyssesSPFwdLossBwdWithLogits with its fwd/loss/bwd for those don't want to use UlyssesSPDataLoaderAdapter - here is how it should be installed into the sub-trainer class: +# class SFTTrainer(Trainer): +# def sp_fwd_loss_bwd(self, batch) -> torch.Tensor: +# batch = to_device(batch, self.device) +# +# from arctic_training.trainer.trainer import UlyssesAttentionHFFwdLossBwdWithLogits +# ulysses = UlyssesAttentionHFFwdLossBwdWithLogits( +# model=self.model, +# model_unwrapped=self.model_unwrapped, +# device=self.device, +# num_loss_logit_shards="auto", +# ) +# return ulysses.sp_fwd_loss_bwd(batch) + + +class UlyssesSPFwdLossBwdWithLogits: + + def __init__(self, model, model_unwrapped, device, num_loss_logit_shards="auto", **kwargs): + + self.model = model + self.model_unwrapped = model_unwrapped + self.device = device + self.num_loss_logit_shards = num_loss_logit_shards + self.kwargs = kwargs + + from deepspeed.utils import groups + + self.sp_group = groups._get_sequence_parallel_group() + self.sp_world_size = groups._get_sequence_parallel_world_size() + self.sp_rank = groups._get_sequence_parallel_rank() + + def sp_fwd_loss_bwd(self, batch) -> torch.Tensor: + + see_memory_usage("entered sp_fwd_loss_bwd", force=True) + + # ensure shapes are correct + if not (batch["input_ids"].shape == batch["position_ids"].shape == batch["labels"].shape): + raise ValueError( + f'Borked batch {batch["input_ids"].shape=} != {batch["position_ids"].shape=} !=' + f' {batch["labels"].shape=}) in DataLoader->iter->next, cannot continue with Ulysses Sequence' + " parallelism") + + # gather DL batches into super-batches + # Important: DL doesn't always yield max_length batches. Different ranks may have different seqlen and each could be <= max_length (but always divisible by 256) + + micro_batches: list[Any] = defaultdict(dict) + # Efficient gathering of batch inputs across ranks: + # The problem is that our DL doesn't guarantee the same seqlen on all ranks and may give, 3x 1024 and 1x 768 on 4 gpus for max_length 1024. so 3 options we have to be able to gather batches are: + # 1. use all_gather_object - which allows different shapes - but potentially introducing an undesired overhead - 2x pickle calls + # 2. use all_gather and change DL pad to make sure that all ranks always get the same input shape - this creates its own overhead since if we say have ranks with seqlen 512, 768, 1024, 1024 - now we will need to process 4x 1024 seqlens + # 3. use all_gather and post gathering truncate tensors to their intended length - another overhead of allocating and truncating tensors + # using approach (1) for now but might want to benchmark later the other 2 approaches + + # XXX: if using all_gather_object we can gather the whole batch at once and not per-key! so can drop the loop for that approach + + # we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank + seqlen = torch.tensor(batch["input_ids"].shape[1], dtype=torch.int64, device=self.device) + # print(seqlen) + seqlens = [torch.zeros(1, dtype=torch.int64, device=self.device) for _ in range(self.sp_world_size)] + dist.all_gather(seqlens, seqlen, group=self.sp_group) + seqlens = [x[0].item() for x in seqlens] + + for k in batch.keys(): + batch[k] = batch[k].to(self.device) + with torch.no_grad(): + tensor_list = [ + torch.zeros((batch[k].shape[0], seqlens[i]), dtype=batch[k].dtype, device=batch[k].device) + for i in range(self.sp_world_size) + ] + dist.all_gather(tensor_list, batch[k], group=self.sp_group) + + # gathering on the data dimension + # will be concatenating and later splitting again for the more general case + # batch[k] = torch.cat(tensor_list, dim=1) + for rank, tensor in enumerate(tensor_list): + micro_batches[rank][k] = tensor + + del tensor_list + del batch + + # we need to chunk twice - each time on SP size level + # - the first time is because we artificially made the seqlen SP-times longer + # - the second time is because of the Ulysses algorithm + + see_memory_usage("after gathering", force=False) + + self.model.set_gradient_accumulation_boundary(False) + + losses = [] + for sub_step_id in range(self.sp_world_size): + batch = micro_batches[sub_step_id] + seq_length = len(batch["input_ids"][0]) + + if seq_length % self.sp_world_size != 0: + raise ValueError( + f"{sub_step_id=}: batch's seqlen={seq_length} isn't divisible by sp-size={self.sp_world_size}") + chunk_len = int(seq_length / self.sp_world_size) + + # to enable the correct mean calculation across shards before sharding the micro batch: + # 1. count the number of non- `-100`` elements per shard + # 2. and subtract one more element because of label shifting + non_skipped_items = {} + for rank in range(self.sp_world_size): + non_skipped = (batch["labels"][:, chunk_len * rank:chunk_len * (rank + 1)] != -100).sum().item() + if non_skipped > 1: + non_skipped -= 1 + non_skipped_items[rank] = non_skipped + + # because we have to gather logits from all sp ranks we have to do the loss function ourselves + # therefore remove labels to avoid an attempt to calculate loss by transformers + labels = batch.pop("labels") + labels = torch.nn.functional.pad(labels, (0, 1), value=-100) + batch["shift_labels"] = labels[..., 1:].contiguous() + # free up temp memory + del labels + + # batch sharding + for k in batch.keys(): + batch[k] = batch[k][:, chunk_len * self.sp_rank:chunk_len * (self.sp_rank + 1)].to(self.device) + + shift_labels = batch.pop("shift_labels") + + outputs = self.forward(batch) + loss = self.compute_loss(labels=None, shift_labels=shift_labels) + + # free up temp mem (e.g. outputs.logits are huge) + del outputs + + # differentiable loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group) + + # since each shard may have a variable number of skipped elemented - need to calculate a weighted mean depending on each rank's contribution - this will also take care of loss=0 when all elements are -100 in a shard + # XXX: not expecting a total of 0-non-skipped items for div + loss = sum(losses_per_rank[rank] * non_skipped_items[rank] + for rank in range(self.sp_world_size)) / sum(non_skipped_items.values()) + + self.backward() + + losses.append(loss.detach().item()) + + self.model.set_gradient_accumulation_boundary(True) + + # for per-iteration reporting + if len(losses) == 0: + loss = float("nan") + else: + loss = sum(losses) / len(losses) + + return loss + + def forward(self, batch): + # critical: the labels shouldn't be in batch + outputs = self.model(**batch, use_cache=False) + self.logits = outputs.logits + return outputs + + def compute_loss(self, labels, shift_labels): + if all((shift_labels == -100).squeeze()): + # this is the case where all labels in a micro-batch are -100 (very common for SFT) - CE returns `nan` in this case, so we don't want to call loss and instead create a differentiable loss `0` which will also set all the grads to `0` in `backward` - the effect of this is akin to a perfect score where the model needs no adjustment since grads will be all zeros. + # XXX: should this be float and not the original dtype? + loss = (self.logits.sum() * 0.0).float() + else: + if self.num_loss_logit_shards == "auto": + # parameterize to about 1GB fp32 logits shards + slice_size_in_gb = 1 # XXX: make configurable? + size_in_gb = self.logits.numel() * 4 / 2**30 # fp32 + # the sp shard's seqlen sp shard can be easily not divisible by the derived number of chunked loss shards, so we use the uppper ceiling and allow the last chunk to be shorter than the rest + self.num_loss_logit_shards = math.ceil(size_in_gb / slice_size_in_gb) + # print(f"derived {self.num_loss_logit_shards} shards for size {size_in_gb}GB") + if self.num_loss_logit_shards > 1: + loss = TiledLoss.apply( + self.model_unwrapped.loss_function, + self.logits, + self.model_unwrapped.config.vocab_size, + shift_labels, + self.num_loss_logit_shards, + ) + else: + # XXX: for some reason this fails with zero1 + loss = self.model_unwrapped.loss_function( + logits=self.logits, + labels=None, + vocab_size=self.model_unwrapped.config.vocab_size, + shift_labels=shift_labels, + ) + + self.loss = loss + return loss + + def backward(self): + self.model.backward(self.loss) diff --git a/deepspeed/runtime/sparse_tensor.py b/deepspeed/runtime/sparse_tensor.py index f0bb5c75530e..f87c4e418e8b 100644 --- a/deepspeed/runtime/sparse_tensor.py +++ b/deepspeed/runtime/sparse_tensor.py @@ -15,9 +15,10 @@ class SparseTensor(object): def __init__(self, dense_tensor=None): self.orig_dense_tensor = dense_tensor - self.is_sparse = dense_tensor.is_sparse if dense_tensor is not None: - if dense_tensor.is_sparse: + self.is_sparse = dense_tensor.is_sparse + self.dtype = self.orig_dense_tensor.dtype + if self.is_sparse: dense_tensor = dense_tensor.coalesce() self.indices = dense_tensor.indices().flatten() self.values = dense_tensor.values() diff --git a/deepspeed/runtime/superoffload/__init__.py b/deepspeed/runtime/superoffload/__init__.py new file mode 100644 index 000000000000..6f5f5619004b --- /dev/null +++ b/deepspeed/runtime/superoffload/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/runtime/superoffload/superoffload_stage3.py b/deepspeed/runtime/superoffload/superoffload_stage3.py new file mode 100644 index 000000000000..7c496a3dda37 --- /dev/null +++ b/deepspeed/runtime/superoffload/superoffload_stage3.py @@ -0,0 +1,375 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +import torch +from typing import List + +from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes +from deepspeed.runtime.zero.partition_parameters import Parameter, Tensor +from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 +from deepspeed.utils.nvtx import instrument_w_nvtx +from deepspeed.utils import logger +from deepspeed.accelerator import get_accelerator + +OPTIMIZER_STEP_TIMER = 'optimizer_step' + + +def _validate_superoffload_accelerator(): + """Validate that the current accelerator is compatible with SuperOffload.""" + accelerator = get_accelerator() + assert accelerator.device_name() == 'cuda', ( + f"SuperOffload only supports NVIDIA CUDA GPUs, but found accelerator '{accelerator.device_name()}'.") + + +class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3): + + def __init__( + self, + module, + init_optimizer, + param_names, + timers, + ds_config, + **kwargs, + ): + _validate_superoffload_accelerator() + + self.sub_group_to_param_num = {} + self.sub_group_grad_partition_counts = {} + self.async_cpuadam_num = 0 + self.max_grad_numel = 0 + + super().__init__(module, init_optimizer, param_names, timers, ds_config, **kwargs) + + optimizer_configs = [] + for pg in self.optimizer.param_groups: + optimizer_configs.append({ + "lr": pg["lr"], + "betas": pg["betas"], + "eps": pg["eps"], + "weight_decay": pg["weight_decay"], + "amsgrad": pg["amsgrad"], + }) + cpuadam_cores_perc = kwargs.get("cpuadam_cores_perc", 0.8) + self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_configs, + cpuadam_cores_perc=cpuadam_cores_perc, + max_grad_numel=self.max_grad_numel) + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partition_numel() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + global_idx = len(self.sub_group_to_param_num) + self.sub_group_to_param_num[global_idx] = len(params_group) + self.max_grad_numel = max(self.max_grad_numel, params_group_numel) + return [params_group] + + sub_groups = [] + sub_group = [] + local_sub_group_size = 0 + + for param in params_group: + sub_group.append(param) + local_sub_group_size += param.partition_numel() + + if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]): + self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size) + sub_groups.append(sub_group) + global_idx = len(self.sub_group_to_param_num) + self.sub_group_to_param_num[global_idx] = len(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + + def step_with_gradscaler(optimizer): + if self.torch_autocast_gradscaler: + self.torch_autocast_gradscaler.step(optimizer) + self.torch_autocast_gradscaler.update() + else: + optimizer.step() + + cur_device = self.subgroup_to_device[sub_group_id] + if cur_device != 'cpu': + self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param] + step_with_gradscaler(self.backup_optimizer) + self.backup_optimizer.param_groups[param_group_id]['params'] = [] + + @instrument_w_nvtx + def independent_gradient_partition_epilogue(self): + super().independent_gradient_partition_epilogue() + self.sub_group_grad_partition_counts.clear() + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.subgroup_to_device[sub_group_id] == 'cpu': + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + self._unflatten_partitioned_parameters(sub_group_id) + return + + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updated_param): + """Asynchronously update partitioned parameters with optimized values.""" + self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(updated_param, non_blocking=True) + + @instrument_w_nvtx + def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + completed_sub_groups = [] + + for param, grad_partition in zip(params_to_release, grad_partitions): + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + + # Accumulate gradient into the grad_buffer, mirroring base class logic + grad_buffer = self._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id].narrow( + 0, 0, grad_partition.numel()) + if self.micro_step_id == 0: + grad_buffer.copy_(grad_partition, non_blocking=True) + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif get_accelerator().on_accelerator(grad_buffer): + grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(grad_buffer.shape)) + else: + cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + cuda_grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(cuda_grad_buffer.shape)) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + grad_buffer = cuda_grad_buffer + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_buffer) + + fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow( + 0, dest_offset, grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer.to(dtype=self.master_weights_and_grads_dtype), non_blocking=True) + + self.sub_group_grad_partition_counts[i] = self.sub_group_grad_partition_counts.get(i, 0) + 1 + if self.sub_group_grad_partition_counts[i] == self.sub_group_to_param_num[i]: + completed_sub_groups.append(i) + + if self.is_gradient_accumulation_boundary and completed_sub_groups: + get_accelerator().current_stream().synchronize() + for i in completed_sub_groups: + if self.subgroup_to_device[i] == 'cpu' and not self.clip_grad: + param_group_id = self.sub_group_to_group_id[i] + fp32_param = self.fp32_partitioned_groups_flat[i] + current_lr = self.optimizer.param_groups[param_group_id]['lr'] + + self.superoffload_cpu_optimizer.async_step(param_group_id, + i, + fp32_param.data, + fp32_param.grad.data, + lr=current_lr) + self.async_cpuadam_num += 1 + + result = self.superoffload_cpu_optimizer.get_result() + if result is not None: + self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], + result[ResultKeys.UPDATED_PARAM]) + self.async_cpuadam_num -= 1 + + for param in params_to_release: + if not get_accelerator().is_synchronized_device(): + if param.grad is not None: + param.grad.record_stream(get_accelerator().current_stream()) + param.grad = None + + @instrument_w_nvtx + def step(self, closure=None): + """ + Not supporting closure. + """ + self._wait_for_async_operations() + + self._pre_step() + self._partition_all_parameters() + + if self._overflow_check_and_loss_scale_update(): + if not self.clip_grad: + self._handle_overflow_rollback() + return + + norm_groups = self._get_norm_groups() + scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups)) + self._global_grad_norm = scaled_global_grad_norm / self.loss_scale + + timer_names = set() + timer_names.add(OPTIMIZER_STEP_TIMER) + self.timers(OPTIMIZER_STEP_TIMER).start() + + if self.clip_grad: + self._step_with_clipping(scaled_global_grad_norm, timer_names) + else: + self._step_without_clipping(scaled_global_grad_norm, timer_names) + + self.timers(OPTIMIZER_STEP_TIMER).stop() + self._post_step(timer_names) + + def _step_without_clipping(self, scaled_global_grad_norm, timer_names): + """Fast path: async CPU steps already completed during backward.""" + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + self._optimizer_step(sub_group_id) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + + def _step_with_clipping(self, scaled_global_grad_norm, timer_names): + """Clipping path: no async steps were done during backward, + so we unscale+clip first, then step all sub-groups.""" + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + + if self.subgroup_to_device[sub_group_id] == 'cpu': + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + current_lr = self.optimizer.param_groups[param_group_id]['lr'] + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + lr=current_lr) + else: + self._optimizer_step(sub_group_id) + + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + + def _wait_for_async_operations(self, timeout_seconds=60): + """Wait for all pending asynchronous CPU optimizer operations to complete with timeout error. + + Args: + timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds. + """ + if self.async_cpuadam_num > 0: + logger.info(f"[INFO] {self.async_cpuadam_num} asynchronous CPU optimizer operations pending...") + if self.async_cpuadam_num == 0: + return + + start_time = time.time() + initial_pending_ops = self.async_cpuadam_num + + while self.async_cpuadam_num > 0: + result = self.superoffload_cpu_optimizer.get_result() + if result is None: + current_time = time.time() + elapsed_time = current_time - start_time + + # Throw error if we've been waiting longer than the timeout + if elapsed_time >= timeout_seconds: + raise RuntimeError( + f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. " + f"Still waiting for {self.async_cpuadam_num}/{initial_pending_ops} async operations to complete. " + f"This indicates a deadlock or critical performance issue in the CPU optimizer.") + + time.sleep(0.001) # 1ms sleep + continue + + self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], + result[ResultKeys.UPDATED_PARAM]) + self.async_cpuadam_num -= 1 + + def _wait_for_single_async_result(self, event_type: str, timeout_seconds=60): + """Wait for a single asynchronous CPU-Adam optimizer operation with timeout. + + Args: + event_type (str): Type of operation expected ('adam_step' or 'rollback'). + timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds. + """ + start_time = time.time() + + while True: + result = self.superoffload_cpu_optimizer.get_result(expected_event_type=event_type) + if result is not None: + self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], + result[ResultKeys.UPDATED_PARAM]) + break + + current_time = time.time() + elapsed_time = current_time - start_time + + # Throw error if we've been waiting longer than the timeout + if elapsed_time >= timeout_seconds: + raise RuntimeError(f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. " + f"This indicates a deadlock or critical performance issue in the CPU optimizer.") + + time.sleep(0.001) # 1ms sleep + + def _sync_cpu_optimizer_step(self, + param_group_id: int, + sub_group_id: int, + fp32_param_data, + fp32_grad_data, + rollback: bool = False, + lr: float = None, + timeout_seconds: int = 60): + event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP + self.superoffload_cpu_optimizer.async_step(param_group_id, + sub_group_id, + fp32_param_data, + fp32_grad_data, + rollback=rollback, + lr=lr) + # Wait for completion + self._wait_for_single_async_result(event_type, timeout_seconds) + + def _handle_overflow_rollback(self): + """Handle gradient overflow by rolling back CPU optimizer states.""" + for sub_group_id, _ in enumerate(self.fp16_groups): + if self.subgroup_to_device[sub_group_id] == 'cpu': + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + + # Trigger rollback + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + rollback=True) + + def _handle_gradient_clipping(self, scaled_global_grad_norm): + """Handle gradient clipping with CPU optimizer rollback and re-optimization.""" + for sub_group_id, _ in enumerate(self.fp16_groups): + if self.subgroup_to_device[sub_group_id] == 'cpu': + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + + # Rollback CPU optimizer states + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + rollback=True) + + # Clip gradients and re-optimize + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + + current_lr = self.optimizer.param_groups[param_group_id]['lr'] + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + rollback=False, + lr=current_lr) + + @instrument_w_nvtx + def check_clip_grads(self, total_norm): + """Check if gradients need to be clipped based on the global norm.""" + unscaled_norm = total_norm / self.loss_scale + return self.clip_grad and unscaled_norm > self.clip_grad diff --git a/deepspeed/runtime/superoffload/superoffload_utils.py b/deepspeed/runtime/superoffload/superoffload_utils.py new file mode 100644 index 000000000000..c8a734b7c48c --- /dev/null +++ b/deepspeed/runtime/superoffload/superoffload_utils.py @@ -0,0 +1,297 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +SuperOffload utilities for 1) running CPU optimizers in separate processes. + +""" + +from typing import Dict, Optional, Any +import torch +import torch.multiprocessing as mp +import psutil + +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.utils import logger + + +class TaskKeys: + PARAM_DATA = "param_data" + PARAM_GRAD = "param_grad" + PARAM_GROUP_ID = "param_group_id" + SUB_GROUP_ID = "sub_group_id" + ROLLBACK = "rollback" + LR = "lr" + + +class ResultKeys: + UPDATED_PARAM = "updated_param" + EVENT_TYPE = "event_type" + + +class EventTypes: + ADAM_STEP = "adam_step" + ROLLBACK = "rollback" + + +def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp.SimpleQueue, + optimizer_config: Dict[str, Any], max_grad_numel: int) -> None: + """ + This function runs in a separate process and continuously processes optimization + tasks from the parameter queue. It creates a DeepSpeedCPUAdam optimizer and + applies optimization steps to parameters received from the main process. + + Args: + param_queue: Queue for receiving optimization tasks + result_queue: Queue for sending back optimization results + optimizer_config: Configuration dictionary for the optimizer containing + lr, betas, eps, weight_decay, and amsgrad parameters + max_grad_numel: Maximum number of elements expected in gradient tensors + """ + cpu_tensor = torch.randn(1, device="cpu") + cpu_param = torch.nn.Parameter(cpu_tensor) + + try: + if isinstance(optimizer_config, list): + pg_configs = optimizer_config + else: + pg_configs = [optimizer_config] + + first_cfg = pg_configs[0] + optimizer = DeepSpeedCPUAdam([cpu_param], + lr=first_cfg["lr"], + betas=first_cfg["betas"], + eps=first_cfg["eps"], + weight_decay=first_cfg["weight_decay"], + amsgrad=first_cfg["amsgrad"]) + for cfg in pg_configs[1:]: + dummy = torch.nn.Parameter(torch.randn(1, device="cpu")) + optimizer.add_param_group({ + "params": [dummy], + "lr": cfg["lr"], + "betas": cfg["betas"], + "eps": cfg["eps"], + "weight_decay": cfg["weight_decay"], + "amsgrad": cfg["amsgrad"], + }) + except KeyError as e: + error_msg = f"Missing required optimizer config key: {e}" + logger.error(error_msg) + result_queue.put({"error": error_msg}) + return + + # Pre-allocate reusable pinned memory buffer for gradients + pinned_grad_buffer = torch.empty(max_grad_numel, dtype=torch.float32, device='cpu', pin_memory=True) + + while True: + try: + task = param_queue.get() + + if task is None: + logger.debug("Received termination signal, shutting down worker") + break + + param_data = task[TaskKeys.PARAM_DATA] + param_grad = task[TaskKeys.PARAM_GRAD] + param_group_id = task[TaskKeys.PARAM_GROUP_ID] + sub_group_id = task[TaskKeys.SUB_GROUP_ID] + rollback = task.get(TaskKeys.ROLLBACK, False) + task_lr = task.get(TaskKeys.LR, None) + + logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}") + + del task[TaskKeys.PARAM_DATA] + del task[TaskKeys.PARAM_GRAD] + task.clear() + + if task_lr is not None: + optimizer.param_groups[param_group_id]['lr'] = task_lr + + grad_numel = param_grad.numel() + if grad_numel > max_grad_numel: + error_msg = ( + f"Gradient size {grad_numel} exceeds pre-allocated buffer size {max_grad_numel}. " + f"This indicates insufficient buffer allocation. Please increase max_grad_numel parameter.") + result_queue.put({"error": error_msg}) + break + + param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad) + param_grad_cpu.copy_(param_grad, non_blocking=False) + + fp32_param = torch.nn.Parameter(param_data) + fp32_param.grad = param_grad_cpu + + optimizer.param_groups[param_group_id]['params'] = [fp32_param] + + if rollback: + logger.debug(f"Rolling back optimizer state for sub_group_id: {sub_group_id}") + optimizer.rollback_subgroup(sub_group_id) + else: + optimizer.step_subgroup(sub_group_id) + + # Send result back to main process + event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP + result_queue.put({ + TaskKeys.PARAM_GROUP_ID: param_group_id, + TaskKeys.SUB_GROUP_ID: sub_group_id, + ResultKeys.UPDATED_PARAM: fp32_param.data, + ResultKeys.EVENT_TYPE: event_type, + }) + + # Clean up references to free memory + optimizer.param_groups[param_group_id]['params'] = [] + del param_grad_cpu, fp32_param.grad, fp32_param, param_grad, param_data + + except KeyError as e: + error_msg = f"Missing required task key: {e}" + logger.error(error_msg) + result_queue.put({"error": error_msg}) + break + except Exception as e: + error_msg = f"Unexpected error in worker process: {e}" + logger.error(error_msg) + result_queue.put({"error": error_msg}) + break + + # Clean up pinned memory buffer + if 'pinned_grad_buffer' in locals(): + del pinned_grad_buffer + logger.debug("Cleaned up pinned memory buffer") + + logger.debug("Worker process terminated") + + +class SuperOffloadCPUOptimizer: + + def __init__(self, + optimizer_config: Dict[str, Any], + cpuadam_cores_perc: float = 0.8, + max_grad_numel: int = 1000000) -> None: + if not 0 < cpuadam_cores_perc <= 1: + raise ValueError("cpuadam_cores_perc must be between 0 and 1") + + self.max_grad_numel = max_grad_numel + self.mp_context = mp.get_context('spawn') + self.param_queue = self.mp_context.SimpleQueue() + self.result_queue = self.mp_context.SimpleQueue() + + self.cpuadam_process = self.mp_context.Process( + target=superoffload_optimizer_worker, + args=(self.param_queue, self.result_queue, optimizer_config, max_grad_numel), + daemon=True, + ) + self.cpuadam_process.start() + + # Set CPU affinity for better performance isolation + self._set_cpu_affinity(cpuadam_cores_perc) + + def _set_cpu_affinity(self, cpuadam_cores_perc: float) -> None: + """ + Set CPU affinity for the main (Pytorch) process and worker (CPU Adam) process. + + Args: + cpuadam_cores_perc: Percentage of cores to allocate to the worker (CPU Adam) process + """ + try: + current_process = psutil.Process() + all_cores = current_process.cpu_affinity() + num_cores = len(all_cores) + + split_idx = int((1 - cpuadam_cores_perc) * num_cores) + pt_cores = all_cores[:split_idx] + cpuadam_cores = all_cores[split_idx:] + + # Set affinity for main process (PyTorch) + current_process.cpu_affinity(pt_cores) + + # Set affinity for optimizer process (CPU Adam) + optimizer_process = psutil.Process(self.cpuadam_process.pid) + optimizer_process.cpu_affinity(cpuadam_cores) + + logger.debug(f"Set CPU affinity - PyTorch cores: {pt_cores}, " + f"Optimizer cores: {cpuadam_cores}") + + except (psutil.AccessDenied, psutil.NoSuchProcess, AttributeError) as e: + logger.debug(f"Could not set CPU affinities for superoffload optimizer process: {e}") + except Exception as e: + logger.warning(f"Unexpected error setting CPU affinity: {e}") + + def async_step(self, + param_group_id: int, + sub_group_id: int, + fp32_param: torch.Tensor, + fp32_grad: torch.Tensor, + rollback: bool = False, + lr: float = None) -> None: + """ + Queue parameter for optimization in the worker process. + """ + if not self.cpuadam_process.is_alive(): + raise RuntimeError("Worker process is not alive") + + task = { + TaskKeys.PARAM_DATA: fp32_param, + TaskKeys.PARAM_GRAD: fp32_grad, + TaskKeys.PARAM_GROUP_ID: param_group_id, + TaskKeys.SUB_GROUP_ID: sub_group_id, + TaskKeys.ROLLBACK: rollback, + } + if lr is not None: + task[TaskKeys.LR] = lr + self.param_queue.put(task) + + def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]: + """ + Get result from worker process with optional event type validation. + + Args: + expected_event_type (str, optional): Expected event type ('adam_step' or 'rollback'). + If provided, validates that the result matches. + """ + if self.result_queue.empty(): + return None + + result = self.result_queue.get() + + if "error" in result: + raise RuntimeError(f"Error in worker process: {result['error']}") + + # Validate event type if expected_event_type is provided + if expected_event_type is not None: + result_event_type = result.get(ResultKeys.EVENT_TYPE) + if result_event_type != expected_event_type: + raise RuntimeError(f"Event type mismatch: expected '{expected_event_type}', got '{result_event_type}'") + + return result + + def close(self) -> None: + """ + Shutdown the worker process gracefully. + + Sends termination signal to worker and waits for clean shutdown. + If the process doesn't terminate within the timeout, it will be forcefully killed. + """ + if not self.cpuadam_process.is_alive(): + logger.debug("Worker process already terminated") + return + + # Send termination signal + self.param_queue.put(None) + + # Wait for graceful shutdown + self.cpuadam_process.join(timeout=5) + + if self.cpuadam_process.is_alive(): + logger.warning("Optimizer process did not terminate cleanly within timeout, " + "forcefully terminating") + self.cpuadam_process.terminate() + self.cpuadam_process.join(timeout=2) + + # Last resort: kill the process + if self.cpuadam_process.is_alive(): + logger.error("Failed to terminate optimizer process, killing it") + self.cpuadam_process.kill() + self.cpuadam_process.join() + + logger.debug("SuperOffload CPU optimizer closed successfully") diff --git a/deepspeed/runtime/swap_tensor/__init__.py b/deepspeed/runtime/swap_tensor/__init__.py index 208299fb8c50..006dfd6dcbc6 100644 --- a/deepspeed/runtime/swap_tensor/__init__.py +++ b/deepspeed/runtime/swap_tensor/__init__.py @@ -2,3 +2,4 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team +from .utils import MIN_SWAPPABLE_BYTES diff --git a/deepspeed/runtime/swap_tensor/aio_config.py b/deepspeed/runtime/swap_tensor/aio_config.py index df4a38380089..be6c7d93c86a 100644 --- a/deepspeed/runtime/swap_tensor/aio_config.py +++ b/deepspeed/runtime/swap_tensor/aio_config.py @@ -5,25 +5,39 @@ from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.runtime.swap_tensor.constants import * +from deepspeed.accelerator import get_accelerator AIO_DEFAULT_DICT = { AIO_BLOCK_SIZE: AIO_BLOCK_SIZE_DEFAULT, AIO_QUEUE_DEPTH: AIO_QUEUE_DEPTH_DEFAULT, - AIO_THREAD_COUNT: AIO_THREAD_COUNT_DEFAULT, + AIO_INTRA_OP_PARALLELISM: AIO_INTRA_OP_PARALLELISM_DEFAULT, AIO_SINGLE_SUBMIT: AIO_SINGLE_SUBMIT_DEFAULT, - AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT + AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT, + AIO_USE_GDS: AIO_USE_GDS_DEFAULT } def get_aio_config(param_dict): if AIO in param_dict.keys() and param_dict[AIO] is not None: aio_dict = param_dict[AIO] - return { - AIO_BLOCK_SIZE: get_scalar_param(aio_dict, AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT), - AIO_QUEUE_DEPTH: get_scalar_param(aio_dict, AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT), - AIO_THREAD_COUNT: get_scalar_param(aio_dict, AIO_THREAD_COUNT, AIO_THREAD_COUNT_DEFAULT), - AIO_SINGLE_SUBMIT: get_scalar_param(aio_dict, AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT), - AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT) + aio_config = { + AIO_BLOCK_SIZE: + get_scalar_param(aio_dict, AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT), + AIO_QUEUE_DEPTH: + get_scalar_param(aio_dict, AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT), + AIO_INTRA_OP_PARALLELISM: + get_scalar_param(aio_dict, AIO_INTRA_OP_PARALLELISM, AIO_INTRA_OP_PARALLELISM_DEFAULT), + AIO_SINGLE_SUBMIT: + get_scalar_param(aio_dict, AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT), + AIO_OVERLAP_EVENTS: + get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT), + AIO_USE_GDS: + get_scalar_param(aio_dict, AIO_USE_GDS, AIO_USE_GDS_DEFAULT) } + if aio_config[AIO_USE_GDS]: + assert get_accelerator().device_name() == 'cuda', 'GDS currently only supported for CUDA accelerator' + + return aio_config + return AIO_DEFAULT_DICT diff --git a/deepspeed/runtime/swap_tensor/async_swapper.py b/deepspeed/runtime/swap_tensor/async_swapper.py index 152b5b70a515..b808721537fe 100644 --- a/deepspeed/runtime/swap_tensor/async_swapper.py +++ b/deepspeed/runtime/swap_tensor/async_swapper.py @@ -10,6 +10,7 @@ from deepspeed import comm as dist from deepspeed.utils.logging import logger from deepspeed.runtime.swap_tensor.utils import swap_out_tensors, SwapBuffer +from deepspeed.accelerator import get_accelerator INVALID_BUFFER_INDEX = -1 ASYNC_SWAPPER_WAIT_TIMER = 'async_swap_gradient_wait' @@ -37,7 +38,7 @@ def has_buffers(self): def add_buffers(self, buffer_list): assert len(self.all_buffers) == 0 - assert all([buffer.is_pinned() for buffer in buffer_list]) + assert all([get_accelerator().is_pinned(buffer) for buffer in buffer_list]) dtype = buffer_list[0].dtype assert all([buffer.dtype == dtype for buffer in buffer_list]) diff --git a/deepspeed/runtime/swap_tensor/constants.py b/deepspeed/runtime/swap_tensor/constants.py index 4c9722bc4e4f..c1207749eac6 100644 --- a/deepspeed/runtime/swap_tensor/constants.py +++ b/deepspeed/runtime/swap_tensor/constants.py @@ -9,9 +9,10 @@ "aio": { "block_size": 1048576, "queue_depth": 8, - "thread_count": 1, + "intra_op_parallelism": 1, "single_submit": false, - "overlap_events": true + "overlap_events": true, + "use_gds": false } ''' AIO = "aio" @@ -19,9 +20,11 @@ AIO_BLOCK_SIZE_DEFAULT = 1048576 AIO_QUEUE_DEPTH = "queue_depth" AIO_QUEUE_DEPTH_DEFAULT = 8 -AIO_THREAD_COUNT = "thread_count" -AIO_THREAD_COUNT_DEFAULT = 1 +AIO_INTRA_OP_PARALLELISM = "intra_op_parallelism" +AIO_INTRA_OP_PARALLELISM_DEFAULT = 1 AIO_SINGLE_SUBMIT = "single_submit" AIO_SINGLE_SUBMIT_DEFAULT = False AIO_OVERLAP_EVENTS = "overlap_events" AIO_OVERLAP_EVENTS_DEFAULT = True +AIO_USE_GDS = "use_gds" +AIO_USE_GDS_DEFAULT = False diff --git a/deepspeed/runtime/swap_tensor/optimizer_utils.py b/deepspeed/runtime/swap_tensor/optimizer_utils.py index 12be256f8055..4c858b8cf049 100644 --- a/deepspeed/runtime/swap_tensor/optimizer_utils.py +++ b/deepspeed/runtime/swap_tensor/optimizer_utils.py @@ -15,6 +15,7 @@ from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, \ MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers from deepspeed.runtime.swap_tensor.utils import SwapBufferManager, SwapBufferPool +from deepspeed.accelerator import get_accelerator class FlattenedTensorSwapInfo(object): @@ -25,36 +26,54 @@ def __init__(self, path, length, offset): self.length = length +class SwapTensorContext(object): + + def __init__(self, tensor, swap_folder): + self.compute_tensor = tensor + self.swap_tensor = torch.Tensor() + self.swap_path = os.path.join(swap_folder, f'{OptimizerSwapper.parameter_id(tensor)}.tensor.swp') + + def release_memory(self): + self.compute_tensor.data = torch.Tensor() + self.swap_tensor.data = torch.Tensor() + + def set_buffers(self, compute_buffer, swap_buffer): + self.compute_tensor.data = compute_buffer.data + self.swap_tensor.data = swap_buffer.data + + class OptimizerStateSwapInfo(object): def __init__(self, parameter, numel, base_folder): self.tensors = [] - self.param_id = id(parameter) + self.param_id = OptimizerSwapper.parameter_id(parameter) self.swap_folder = base_folder - self.swap_paths = [] self.swapped_gradients = {} self.unswapped_gradients = {} self.tensor_numel = numel self.tensor_dtype = parameter.dtype self.tensor_device = parameter.device self.has_state_tensors = False + self.swap_buffers = [] self._add_tensors([parameter]) def numel(self): return self.tensor_numel def has_gradients(self): - return self.swapped_gradients or self.unswapped_gradients + return bool(self.swapped_gradients) or bool(self.unswapped_gradients) def _add_tensors(self, tensor_list): for t in tensor_list: - self.tensors.append(t) - self.swap_paths.append(os.path.join(self.swap_folder, f'{id(t)}.tensor.swp')) + self.tensors.append(SwapTensorContext(t, self.swap_folder)) def add_state_tensors(self, tensor_list): self.has_state_tensors = True self._add_tensors(tensor_list) + def num_tensors(self): + return len(self.tensors) + def device(self): return self.tensor_device @@ -62,13 +81,28 @@ def dtype(self): return self.tensor_dtype def release_memory(self): - for tensor in self.tensors: - tensor.data = torch.Tensor() + for t in self.tensors: + t.release_memory() + + def get_compute_tensors(self): + return [t.compute_tensor for t in self.tensors] + + def get_swap_paths(self): + return [t.swap_path for t in self.tensors] + + def get_swap_buffers_and_paths(self, pinned): + swap_buffers = [] + swap_paths = [] + select_tensors = [t for t in self.tensors if get_accelerator().is_pinned(t.compute_tensor) == pinned] + for t in select_tensors: + swap_buffers.append(t.swap_tensor if pinned else t.compute_tensor) + swap_paths.append(t.swap_path) + return swap_buffers, swap_paths def get_or_create_gradient_paths(self, offsets, lengths): gradient_paths = [] for offset, length in zip(offsets, lengths): - if not offset in self.swapped_gradients.keys(): + if offset not in self.swapped_gradients.keys(): path = os.path.join(self.swap_folder, f'{self.param_id}_gradient_{offset}_{length}.tensor.swp') self.swapped_gradients[offset] = FlattenedTensorSwapInfo(path, length, offset) @@ -76,11 +110,15 @@ def get_or_create_gradient_paths(self, offsets, lengths): return gradient_paths - def set_swap_buffers(self, buffers): - compute_lengths = [self.numel()] * len(self.tensors) + def set_swap_buffers(self, buffers, aligned_numel): + num_tensors = len(self.tensors) + compute_lengths = [self.numel()] * num_tensors compute_buffers = get_sized_buffers(buffers, compute_lengths) - for t, buffer in zip(self.tensors, compute_buffers): - t.data = buffer.data + swap_lengths = [aligned_numel] * num_tensors + swap_buffers = get_sized_buffers(buffers, swap_lengths) + + for i, t in enumerate(self.tensors): + t.set_buffers(compute_buffer=compute_buffers[i], swap_buffer=swap_buffers[i]) def get_swap_gradient_buffers(self, swap_buffer): assert self.numel() <= swap_buffer.numel() @@ -90,7 +128,7 @@ def get_swap_gradient_paths(self): return [grad.path for grad in self.swapped_gradients.values()] def get_unpinned_state_tensors(self): - return [t for t in self.tensors if not t.is_pinned()] + return [t.compute_tensor for t in self.tensors if not get_accelerator().is_pinned(t.compute_tensor)] def read_unswapped_gradients(self, dest_buffer): num_elem_count = 0 @@ -101,6 +139,15 @@ def read_unswapped_gradients(self, dest_buffer): return num_elem_count + def write_unswapped_gradients(self, src_buffer): + num_elem_count = 0 + for offset, grad_partition in self.unswapped_gradients.items(): + src_tensor = src_buffer.narrow(0, offset, grad_partition.numel()) + grad_partition.data.copy_(src_tensor.data) + num_elem_count += grad_partition.numel() + + return num_elem_count + def release_unswapped_gradients(self): self.unswapped_gradients = {} @@ -111,6 +158,10 @@ def release_unswapped_gradients(self): class OptimizerSwapper(object): + @staticmethod + def parameter_id(param): + return param.ds_id + def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers): self.swap_config = swap_config self.aio_config = aio_config @@ -125,7 +176,7 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume # Read/Write alignment for each thread during Intra-request parallelism self.min_aio_bytes = max(MIN_AIO_BYTES, aio_config[AIO_BLOCK_SIZE]) - self.aligned_bytes = AIO_ALIGNED_BYTES * aio_config[AIO_THREAD_COUNT] + self.aligned_bytes = AIO_ALIGNED_BYTES * aio_config[AIO_INTRA_OP_PARALLELISM] self.numel_alignment = self.aligned_bytes // self.swap_element_size # Swap buffer management @@ -148,10 +199,15 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume 'timer_names', ] - def swappable_tensor(self, param=None, numel=None): - assert param is not None or numel is not None, "Either param or numel must be provided" - if param is not None: - return self.min_aio_bytes <= (param.numel() * self.swap_element_size) + def purge_state(self): + for swap_info in self.swap_params_info.values(): + swap_info.tensors = [swap_info.tensors[0]] + swap_info.has_state_tensors = False + + def is_swappable_tensor(self, tensor=None, numel=None): + assert tensor is not None or numel is not None, "Either tensor or numel must be provided" + if tensor is not None: + return self.min_aio_bytes <= (tensor.numel() * self.swap_element_size) return self.min_aio_bytes <= (numel * self.swap_element_size) def init_timers(self): @@ -177,10 +233,10 @@ def _flush_gradient_swapper(self, gradient_swapper): self.timer_names.update(gradient_swapper.get_timer_names()) def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gradient_swapper): - if not id(parameter) in self.swap_params_info.keys(): + if OptimizerSwapper.parameter_id(parameter) not in self.swap_params_info.keys(): return - swap_info = self.swap_params_info[id(parameter)] + swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)] swappable_tensors = [] swappable_offsets = [] @@ -191,7 +247,7 @@ def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gra self._start_timer(SWAP_OUT_GRADIENT_TIMER) for tensor, offset in zip(aligned_gradients, aligned_offsets): - if not self.swappable_tensor(param=tensor): + if not self.is_swappable_tensor(tensor=tensor): swap_info.unswapped_gradients[offset] = tensor continue @@ -216,7 +272,7 @@ def _initialize_from_swapped_fp16_params(self, aio_handle, fp16_partitions_info, fp16_pinned_buffers, fp32_parameters): assert len(fp32_parameters) == len(fp16_partitions_info) assert len(fp32_parameters) == len(fp16_num_elems) - assert all([buffer.is_pinned() for buffer in fp16_pinned_buffers]) + assert all([get_accelerator().is_pinned(buffer) for buffer in fp16_pinned_buffers]) fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters, num_elems=fp16_num_elems) @@ -240,7 +296,7 @@ def _initialize_from_swapped_fp16_params(self, aio_handle, fp16_partitions_info, for i, tensor in enumerate(fp16_pinned_tensors): true_index = curr_index + i logger.info( - f'swap_in_fp16_param: fp32_id = {id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}' + f'swap_in_fp16_param: fp32_id = {OptimizerSwapper.parameter_id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}' ) swap_out_count = self._swap_out_fp16_params(aio_handle=aio_handle, @@ -287,7 +343,8 @@ def _swap_in_fp16_params(self, aio_handle, fp16_num_elems, fp16_partitions_info, for src, dst in zip(unswapped_srcs, unswapped_dsts): dst.data.copy_(src.data) - assert len(swap_tensors) == aio_handle.wait() + if len(swap_tensors) > 0: + assert len(swap_tensors) == aio_handle.wait() return swapped_fp16_tensors @@ -329,7 +386,7 @@ def _initialize_parameters(self, parameters, src_tensors, aio_handle): if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE: for i, tensor in enumerate(src_tensors): logger.info( - f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}' + f'copy_in_fp16_param: fp32_id = {OptimizerSwapper.parameter_id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}' ) self.swap_buffer_manager.free(pinned_buffers) @@ -345,7 +402,7 @@ def _get_swap_paths(self, parameters, num_elems): ] assert len(swap_info_list) == len(num_elems) - swap_paths = [info.swap_paths[0] for info in swap_info_list] + swap_paths = [info.tensors[0].swap_path for info in swap_info_list] return swap_paths def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers): @@ -376,7 +433,7 @@ def _adjust_for_misaligned_lengths(self, tensors, offsets): new_offsets = [] for orig_tensor, orig_offset in zip(tensors, offsets): - if not self.swappable_tensor(param=orig_tensor): + if not self.is_swappable_tensor(tensor=orig_tensor): new_tensors.append(orig_tensor) new_offsets.append(orig_offset) continue @@ -415,12 +472,13 @@ def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer): ) def _get_state_tensors(self, parameter): - if not parameter in self.optimizer.state: + if parameter not in self.optimizer.state: return [] tensor_list = [] - for value in self.optimizer.state[parameter].values(): - if torch.is_tensor(value): + for state_name, value in self.optimizer.state[parameter].items(): + if torch.is_tensor(value) and self.is_swappable_tensor(tensor=value): + value.ds_id = state_name + '-' + parameter.ds_id tensor_list.append(value) return tensor_list @@ -432,8 +490,8 @@ def _update_param_state_info(self, swap_info, parameter): swap_info.add_state_tensors(state_tensors) def _create_param_swap_info(self, parameter, numel): - param_id = id(parameter) - assert not param_id in self.swap_params_info + param_id = OptimizerSwapper.parameter_id(parameter) + assert param_id not in self.swap_params_info self.swap_params_info[param_id] = OptimizerStateSwapInfo(parameter=parameter, numel=numel, @@ -445,7 +503,7 @@ def _create_param_swap_info(self, parameter, numel): return swap_info def _get_param_swap_info(self, parameter): - param_id = id(parameter) + param_id = OptimizerSwapper.parameter_id(parameter) swap_info = self.swap_params_info.get(param_id, None) if swap_info is not None: diff --git a/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py index 677bc2aa4a8e..52b873ba58a1 100644 --- a/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py +++ b/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py @@ -6,8 +6,6 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ -import torch - from deepspeed.utils.logging import logger from deepspeed.ops.op_builder import AsyncIOBuilder from deepspeed import comm as dist @@ -17,6 +15,7 @@ get_sized_buffers from deepspeed.runtime.swap_tensor.async_swapper import AsyncTensorSwapper from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper +from deepspeed.accelerator import get_accelerator DEBUG_MODE = False @@ -32,9 +31,11 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume largest_numel, device, dtype, timers) aio_op = AsyncIOBuilder().load() - self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH], - aio_config[AIO_SINGLE_SUBMIT], aio_config[AIO_OVERLAP_EVENTS], - aio_config[AIO_THREAD_COUNT]) + self.aio_handle = aio_op.aio_handle(block_size=aio_config[AIO_BLOCK_SIZE], + queue_depth=aio_config[AIO_QUEUE_DEPTH], + single_submit=aio_config[AIO_SINGLE_SUBMIT], + overlap_events=aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM]) # Overlap swapping out self.gradient_swapper = AsyncTensorSwapper(aio_handle=self.aio_handle, @@ -60,6 +61,15 @@ def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_ele def flush_gradients(self): self._flush_gradient_swapper(self.gradient_swapper) + def release_swap_buffers(self, parameter): + swap_info = self._get_param_swap_info(parameter) + if swap_info is None: + return + swap_info.release_memory() + + self.swap_buffer_manager.free(swap_info.swap_buffers) + swap_info.swap_buffers = [] + def swap_in_optimizer_state(self, parameter, async_parameter=None): swap_info = self._get_param_swap_info(parameter) if swap_info is None: @@ -67,64 +77,82 @@ def swap_in_optimizer_state(self, parameter, async_parameter=None): self._flush_gradient_swapper(self.gradient_swapper) - required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0) + required_buffer_count = swap_info.num_tensors() + (1 if swap_info.has_gradients() else 0) aligned_numel = self._io_aligned_numel(swap_info.numel()) pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel, count=required_buffer_count, dtype=parameter.dtype) assert pinned_buffers is not None - self.allocated_swap_buffers = pinned_buffers.copy() + swap_info.swap_buffers = pinned_buffers.copy() self._start_timer(SWAP_IN_PARAM_TIMER) self._swap_in_parameter(aio_handle=self.aio_handle, parameter=parameter, - dest_buffers=pinned_buffers[:required_buffer_count]) + dest_buffers=pinned_buffers[:swap_info.num_tensors()]) self._stop_timer(SWAP_IN_PARAM_TIMER) self.timer_names.add(SWAP_IN_PARAM_TIMER) - self._start_timer(SWAP_IN_GRADIENT_TIMER) - self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1]) - self._stop_timer(SWAP_IN_GRADIENT_TIMER) - self.timer_names.add(SWAP_IN_GRADIENT_TIMER) - - def swap_out_optimizer_state(self, parameter, async_swap=False): - swap_info = self._get_param_swap_info(parameter=parameter) - - if swap_info is None: - return - - self._start_timer(SWAP_OUT_PARAM_TIMER) - pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info) - swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors]) + if swap_info.has_gradients(): + self._start_timer(SWAP_IN_GRADIENT_TIMER) + self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1]) + self._stop_timer(SWAP_IN_GRADIENT_TIMER) + self.timer_names.add(SWAP_IN_GRADIENT_TIMER) + def _swap_out_optimizer_state(self, swap_info): + pinned_tensors, pinned_paths = swap_info.get_swap_buffers_and_paths(True) WRITE_TIMER = 'swap_submit_write' self._start_timer(WRITE_TIMER) swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths) assert self.aio_handle.wait() == len(pinned_tensors) - for t in pinned_tensors: - t.data = torch.Tensor() + unpinned_tensors, unpinned_paths = swap_info.get_swap_buffers_and_paths(False) if len(unpinned_tensors) > 0: pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype) self._swap_out_unpinned_tensors(aio_handle=self.aio_handle, unpinned_tensors=unpinned_tensors, dest_paths=unpinned_paths, pinned_buffers=pinned_buffers) - self.allocated_swap_buffers += pinned_buffers + swap_info.swap_buffers += pinned_buffers.copy() - for t in unpinned_tensors: - t.data = torch.Tensor() self._stop_timer(WRITE_TIMER) + self._log_timers([WRITE_TIMER]) + + def writeback_optimizer_state_and_gradients(self, parameter, write_opt_state, write_gradients): + swap_info = self._get_param_swap_info(parameter=parameter) + + if swap_info is None: + return - self.swap_buffer_manager.free(self.allocated_swap_buffers) - self.allocated_swap_buffers = [] + if write_opt_state: + self._swap_out_optimizer_state(swap_info) + if write_gradients and swap_info.has_gradients(): + param_gradients = swap_info.swapped_gradients.values() + swap_buffers = [parameter.grad.narrow(0, grad.offset, grad.length) for grad in param_gradients] + swap_paths = [grad.path for grad in param_gradients] + swap_out_tensors(self.aio_handle, swap_buffers, swap_paths) + assert len(swap_buffers) == self.aio_handle.wait() + if swap_info.unswapped_gradients: + swap_info.write_unswapped_gradients(src_buffer=parameter.grad) + + self.release_swap_buffers(parameter) + + def swap_out_optimizer_state(self, parameter, async_swap=False): + swap_info = self._get_param_swap_info(parameter=parameter) + + if swap_info is None: + return + + swap_bytes = sum( + [self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.get_compute_tensors()]) + + self._start_timer(SWAP_OUT_PARAM_TIMER) + self._swap_out_optimizer_state(swap_info) + self.release_swap_buffers(parameter) self._stop_timer(SWAP_OUT_PARAM_TIMER) self.timer_names.add(SWAP_OUT_PARAM_TIMER) - self._log_timers([WRITE_TIMER]) - if DEBUG_MODE and dist.get_rank() == 0: logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB') @@ -139,16 +167,20 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers): if swap_info is None: return - assert len(swap_info.tensors) <= len(dest_buffers) + num_swap_tensors = swap_info.num_tensors() + assert num_swap_tensors <= len(dest_buffers) - swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors) + swap_lengths = [self._io_aligned_numel(swap_info.numel())] * num_swap_tensors swap_buffers = get_sized_buffers(dest_buffers, swap_lengths) + compute_lengths = [swap_info.numel()] * num_swap_tensors + compute_buffers = get_sized_buffers(dest_buffers, compute_lengths) + READ_TIMER = 'swap_submit_read_param' WAIT_TIMER = 'swap_wait_read_param' self._start_timer(READ_TIMER) - swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths) + swap_in_tensors(aio_handle, swap_buffers, swap_info.get_swap_paths()) self._stop_timer(READ_TIMER) swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers]) @@ -157,40 +189,19 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers): aio_handle.wait() self._stop_timer(WAIT_TIMER) - compute_lengths = [swap_info.numel()] * len(swap_info.tensors) - compute_buffers = get_sized_buffers(dest_buffers, compute_lengths) - for t, buffer in zip(swap_info.tensors, compute_buffers): - t.data = buffer.data + swap_info.set_swap_buffers(dest_buffers, self._io_aligned_numel(swap_info.numel())) self._log_timers([READ_TIMER, WAIT_TIMER]) if DEBUG_MODE and dist.get_rank() == 0: logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB') - def _separate_pinned_tensors(self, swap_info): - pinned_tensors = [] - pinned_paths = [] - - unpinned_tensors = [] - unpinned_paths = [] - - for tensor, path in zip(swap_info.tensors, swap_info.swap_paths): - if tensor.is_pinned(): - pinned_tensors.append(tensor) - pinned_paths.append(path) - else: - unpinned_tensors.append(tensor) - unpinned_paths.append(path) - - return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths - def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor): - swap_info = self.swap_params_info[id(parameter)] + swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)] param_gradients = swap_info.swapped_gradients.values() swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients] swap_paths = [grad.path for grad in param_gradients] SWAP_READ_GRADIENTS = 'swap_submit_read_gradient' SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient' - self._start_timer(SWAP_READ_GRADIENTS) swap_in_tensors(aio_handle, swap_buffers, swap_paths) self._stop_timer(SWAP_READ_GRADIENTS) @@ -202,11 +213,11 @@ def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor): self._log_timers([SWAP_READ_GRADIENTS, SWAP_WAIT_GRADIENTS]) def _swap_in_gradients(self, aio_handle, parameter, dest_buffer): - swap_info = self.swap_params_info.get(id(parameter), None) + swap_info = self.swap_params_info.get(OptimizerSwapper.parameter_id(parameter), None) if not (swap_info and swap_info.has_gradients()): return - assert dest_buffer.is_pinned() + assert get_accelerator().is_pinned(dest_buffer) assert parameter.numel() <= dest_buffer.numel() parameter.grad = dest_buffer.narrow(0, 0, parameter.numel()) diff --git a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py index 4109e0954148..06a0301462a4 100644 --- a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py +++ b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py @@ -13,6 +13,7 @@ from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import AsyncIOBuilder +from deepspeed.ops.op_builder import GDSBuilder from .constants import * from .utils import swap_in_tensors, swap_out_tensors, MIN_AIO_BYTES, AIO_ALIGNED_BYTES, print_object, SwapBufferPool @@ -37,8 +38,6 @@ class AsyncPartitionedParameterSwapper(object): def __init__(self, ds_config, model_dtype): - aio_op = AsyncIOBuilder().load(verbose=False) - self.aio_handle = aio_op.aio_handle self.dtype = model_dtype #set swap buffers, create aio handles @@ -93,9 +92,13 @@ def _configure_aio(self, ds_config): self.aio_config = ds_config.aio_config + self.use_gds = self.aio_config[AIO_USE_GDS] + self.aio_handle = GDSBuilder().load(verbose=False).gds_handle if self.use_gds else AsyncIOBuilder().load( + verbose=False).aio_handle + # Read/Write alignment for each thread during Intra-request parallelism self.min_aio_bytes = max(MIN_AIO_BYTES, self.aio_config[AIO_BLOCK_SIZE]) - self.aligned_bytes = AIO_ALIGNED_BYTES * self.aio_config[AIO_THREAD_COUNT] + self.aligned_bytes = AIO_ALIGNED_BYTES * self.aio_config[AIO_INTRA_OP_PARALLELISM] self.numel_alignment = self.aligned_bytes // self.swap_element_size self.elements_per_buffer = self.swap_config.buffer_size @@ -104,18 +107,28 @@ def _configure_aio(self, ds_config): self.available_buffer_ids = [i for i in range(self.param_buffer_count)] self.reserved_buffer_ids = [] - self.buffers = get_accelerator().pin_memory( - torch.empty(int(self.aligned_elements_per_buffer * self.param_buffer_count), - dtype=self.dtype, - requires_grad=False)) - - self.aio_read_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE], self.aio_config[AIO_QUEUE_DEPTH], - self.aio_config[AIO_SINGLE_SUBMIT], self.aio_config[AIO_OVERLAP_EVENTS], - self.aio_config[AIO_THREAD_COUNT]) - self.aio_write_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE], self.aio_config[AIO_QUEUE_DEPTH], - self.aio_config[AIO_SINGLE_SUBMIT], - self.aio_config[AIO_OVERLAP_EVENTS], self.aio_config[AIO_THREAD_COUNT]) + self.aio_read_handle = self.aio_handle(block_size=self.aio_config[AIO_BLOCK_SIZE], + queue_depth=self.aio_config[AIO_QUEUE_DEPTH], + single_submit=self.aio_config[AIO_SINGLE_SUBMIT], + overlap_events=self.aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=self.aio_config[AIO_INTRA_OP_PARALLELISM]) + + self.aio_write_handle = self.aio_handle(block_size=self.aio_config[AIO_BLOCK_SIZE], + queue_depth=self.aio_config[AIO_QUEUE_DEPTH], + single_submit=self.aio_config[AIO_SINGLE_SUBMIT], + overlap_events=self.aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=self.aio_config[AIO_INTRA_OP_PARALLELISM]) + + buffer_device = get_accelerator().device_name() if self.use_gds else "cpu" + self.buffers = torch.empty(int(self.aligned_elements_per_buffer * self.param_buffer_count), + dtype=self.dtype, + device=buffer_device, + requires_grad=False) + if self.use_gds: + self.aio_read_handle.pin_device_tensor(self.buffers) + else: + self.buffers = get_accelerator().pin_memory(self.buffers, align_bytes=0) self.swap_out_params = [] @@ -313,7 +326,8 @@ def swap_in(self, params, async_op=True, swap_in_buffers=None): def swap_into_buffer(self, param, dest_buffer): assert param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE, f"param {param.ds_id} is already available or inflight" - require_swap_buffer = not (dest_buffer.is_pinned() and self._is_io_aligned(dest_buffer.numel())) + require_swap_buffer = not (get_accelerator().is_pinned(dest_buffer) + and self._is_io_aligned(dest_buffer.numel())) if require_swap_buffer: assert len(self.available_buffer_ids) > 0, f"No buffer available to swap param {param.ds_id}." @@ -340,7 +354,7 @@ def get_buffer(self, param, numel): assert self.available_swap_in_buffers( ) > 0, f"No swap buffers to allocate for fp16 param {param_id} of numel = {numel}" - assert numel < self.elements_per_buffer, f"More elements {numel} than buffer size {self.elements_per_buffer}" + assert numel <= self.elements_per_buffer, f"More elements {numel} than buffer size {self.elements_per_buffer}" self.param_id_to_numel[param_id] = numel buffer_id = self.available_buffer_ids.pop() @@ -378,13 +392,15 @@ def _is_io_aligned(self, numel): def reserve_partitioned_swap_space(self, partition_num_elems): aligned_numel = sum([self._io_aligned_numel(numel) for numel in partition_num_elems]) - self.partitioned_swap_buffer = get_accelerator().pin_memory( - torch.zeros(aligned_numel, device='cpu', dtype=self.dtype)) + self.partitioned_swap_buffer = get_accelerator().pin_memory(torch.zeros(aligned_numel, + device='cpu', + dtype=self.dtype), + align_bytes=0) self.partitioned_swap_pool = SwapBufferPool([self.partitioned_swap_buffer]) def swap_out_partitioned_params(self, dst_fp16_params, src_fp32_params): - assert self.partitioned_swap_buffer is not None, f'partitioned swap buffers for fp16 params not initialized' - assert self.partitioned_swap_pool is not None, f'partitioned swap pool for fp16 params not initialized' + assert self.partitioned_swap_buffer is not None, 'partitioned swap buffers for fp16 params not initialized' + assert self.partitioned_swap_pool is not None, 'partitioned swap pool for fp16 params not initialized' assert len(dst_fp16_params) == len(src_fp32_params), \ f'mismatch in number of fp16 params {len(dst_fp16_params)} and fp32 params {len(src_fp32_params)}' diff --git a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py index cb00e3dc2fad..17d7a655c86f 100644 --- a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py +++ b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py @@ -8,6 +8,7 @@ from deepspeed.ops.op_builder import AsyncIOBuilder from deepspeed import comm as dist +import torch from deepspeed.runtime.swap_tensor.constants import * from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object @@ -28,7 +29,7 @@ def __init__(self, aio_handle, read_op, param_info, allocated_buffers, state_buf self.num_ops = num_ops def is_parameter(self, parameter): - return id(parameter) == self.param_info.param_id + return OptimizerSwapper.parameter_id(parameter) == self.param_info.param_id def wait(self): assert self.wait_required @@ -55,13 +56,17 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume device, dtype, timers) aio_op = AsyncIOBuilder().load() - self.write_aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH], - aio_config[AIO_SINGLE_SUBMIT], aio_config[AIO_OVERLAP_EVENTS], - aio_config[AIO_THREAD_COUNT]) - - self.read_aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH], - aio_config[AIO_SINGLE_SUBMIT], aio_config[AIO_OVERLAP_EVENTS], - aio_config[AIO_THREAD_COUNT]) + self.write_aio_handle = aio_op.aio_handle(block_size=aio_config[AIO_BLOCK_SIZE], + queue_depth=aio_config[AIO_QUEUE_DEPTH], + single_submit=aio_config[AIO_SINGLE_SUBMIT], + overlap_events=aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM]) + + self.read_aio_handle = aio_op.aio_handle(block_size=aio_config[AIO_BLOCK_SIZE], + queue_depth=aio_config[AIO_QUEUE_DEPTH], + single_submit=aio_config[AIO_SINGLE_SUBMIT], + overlap_events=aio_config[AIO_OVERLAP_EVENTS], + intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM]) # Overlap gradient swap out self.gradient_swapper = AsyncTensorSwapper(aio_handle=self.write_aio_handle, @@ -154,6 +159,8 @@ def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors): def _complete_swap_out(self, swap_out_type): self.swap_ops[swap_out_type].wait() + for buffer in self.swap_ops[swap_out_type].state_buffers: + buffer = torch.Tensor() self.swap_buffer_manager.free(self.swap_ops[swap_out_type].allocated_buffers) self.swap_ops[swap_out_type] = None @@ -180,7 +187,7 @@ def _swap_out_optimizer_state(self, aio_handle, parameter, swap_in_op): dst = get_sized_buffer(pinned_dst, unpinned_src.numel()) dst.data.copy_(unpinned_src.data) - swap_paths = param_info.swap_paths.copy() + swap_paths = param_info.get_swap_paths() assert len(swap_paths) == len(swap_buffers) swap_out_tensors(aio_handle, swap_buffers, swap_paths) @@ -199,19 +206,20 @@ def _swap_in_optimizer_state(self, aio_handle, parameter): if param_info is None: return None - required_buffer_count = len(param_info.tensors) + (1 if param_info.has_gradients() else 0) + num_swap_tensors = param_info.num_tensors() + required_buffer_count = num_swap_tensors + (1 if param_info.has_gradients() else 0) aligned_numel = self._io_aligned_numel(param_info.numel()) allocated_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel, count=required_buffer_count, dtype=parameter.dtype) assert allocated_buffers is not None, \ - f"PipelinedOptimizerSwapper ran out of swap buffers, try increasing 'buffer_count'" + "PipelinedOptimizerSwapper ran out of swap buffers, try increasing 'buffer_count'" - state_buffers = allocated_buffers[:len(param_info.tensors)] - param_info.set_swap_buffers(state_buffers) + state_buffers = allocated_buffers[:num_swap_tensors] + param_info.set_swap_buffers(state_buffers, aligned_numel) swap_buffers = state_buffers.copy() - swap_paths = param_info.swap_paths.copy() + swap_paths = param_info.get_swap_paths() if param_info.has_gradients(): parameter.grad = allocated_buffers[-1].narrow(0, 0, param_info.numel()) diff --git a/deepspeed/runtime/swap_tensor/utils.py b/deepspeed/runtime/swap_tensor/utils.py index 50a88f74351a..3cfe95c13088 100644 --- a/deepspeed/runtime/swap_tensor/utils.py +++ b/deepspeed/runtime/swap_tensor/utils.py @@ -14,22 +14,23 @@ MIN_AIO_BYTES = 1024**2 AIO_ALIGNED_BYTES = 1024 +MIN_SWAPPABLE_BYTES = MIN_AIO_BYTES def swap_in_tensors(swap_handle, tensor_buffers, swap_paths): for buffer, path in zip(tensor_buffers, swap_paths): - assert (swap_handle.async_pread(buffer, path) == 0) + assert (swap_handle.async_pread(buffer, path, 0) == 0) def swap_out_tensors(swap_handle, tensor_buffers, swap_paths): for buffer, path in zip(tensor_buffers, swap_paths): - assert (swap_handle.async_pwrite(buffer, path) == 0) + assert (swap_handle.async_pwrite(buffer, path, 0) == 0) def print_object(obj, name, exclude_list=[]): logger.info('{}:'.format(name)) for arg in sorted(vars(obj)): - if not arg in exclude_list: + if arg not in exclude_list: dots = '.' * (29 - len(arg)) logger.info(' {} {} {}'.format(arg, dots, getattr(obj, arg))) @@ -54,7 +55,7 @@ def insert_tensor(self, tensor, swap_path, aligned_numel): def allocate_tensor(self, swap_path, numel, aligned_numel): assert self.has_space(aligned_numel) - assert not self.offset in self.swap_tensors + assert self.offset not in self.swap_tensors allocate_offset = self.offset swap_tensor = self.buffer.narrow(0, allocate_offset, aligned_numel) @@ -96,7 +97,7 @@ def get_swap_path(self, offset): class SwapBufferPool(object): def __init__(self, buffers): - assert all([buf.is_pinned() for buf in buffers]) + assert all([get_accelerator().is_pinned(buf) for buf in buffers]) self.buffers = [SwapBuffer(buf) for buf in buffers] self.current_index = 0 @@ -184,7 +185,8 @@ def __init__(self, num_elems, count, dtype): self.count = count self.dtype = dtype self.all_buffers = [ - get_accelerator().pin_memory(torch.zeros(num_elems, device='cpu', dtype=dtype)) for _ in range(count) + get_accelerator().pin_memory(torch.zeros(num_elems, device='cpu', dtype=dtype), align_bytes=0) + for _ in range(count) ] self.free_buffer_index = [i for i in range(count)] self.used_buffer_index = {} diff --git a/deepspeed/runtime/tensor_parallel/__init__.py b/deepspeed/runtime/tensor_parallel/__init__.py new file mode 100644 index 000000000000..388239345351 --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .config import AUTOTP_MODE, get_tensor_parallel_config +from .tp_manager import TpTrainingManager diff --git a/deepspeed/runtime/tensor_parallel/config.py b/deepspeed/runtime/tensor_parallel/config.py new file mode 100644 index 000000000000..56abf5868d5d --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/config.py @@ -0,0 +1,163 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +import torch +from pydantic import Field +from typing import Optional, Dict, Any + + +class AUTOTP_MODE(Enum): + TRAINING = "TRAINING" + INFERENCE = "INFERENCE" + + +class TPConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + tp_size: int = 1 + """ Number of devices to split the model across using tensor parallelism. """ + + tp_grain_size: int = 1 + "The variable required by the autoTP parser has not been activated in training yet" + "as it depends on the gather logic that supports uneven partitioning. " + "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size." + + mpu: object = None + """ + A model parallelism unit object that implements + ``get_{model,data}_parallel_{rank,group,world_size}()``. + """ + + tp_group: object = None + + +class TPTrainingConfig(DeepSpeedConfigModel): + + dtype: torch.dtype = torch.float16 + """ + Desired model data type, will convert model to this type. + """ + + autotp_size: int = 0 + """ + In automatic tensor-parallelism training, 'tensor_parallel_size' + When set to 0, indicates that it is disabled. + """ + tp_overlap_comm: bool = False + """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """ + + tensor_parallel: TPConfig = Field({}, alias="tp") + """ + Configuration for tensor parallelism used to split the model across several + GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`. + """ + + injection_policy_tuple: Optional[tuple] = None + + # New configurable AutoTP settings + partition_config: Optional[Dict[str, Any]] = None + """ + Configuration for the new configurable AutoTP API. + Allows users to specify custom layer partitioning rules via TPLayerSpec. + + Example: + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + """ + + preset_model: Optional[str] = None + """ + Use a built-in preset for common model architectures. + Available presets: "llama", "bloom", "chatglm", "mixtral", "deepseek_v2", "qwen2", "phi3" + """ + + #The following parameters are required by autoTP parser. + ######################################## + keep_module_on_host: bool = False + """ + When loading checkpoints to model parameters, they are moved to the device. In very large models + this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on + host and not move them directly to the device (giving an option to quantize checkpoint data before + moving it to the device for example). + """ + + replace_with_kernel_inject: bool = Field(False, alias="kernel_inject") + """ + Set to true to inject inference kernels for models such as, Bert, GPT2, + GPT-Neo and GPT-J. Otherwise, the injection_dict provides the names of two + linear layers as a tuple: + `(attention_output projection, transformer output projection)` + """ + + ######################################## + + def get_partition_config_object(self): + """ + Get the AutoTPConfig object from the configuration. + Returns None if no custom config is specified. + """ + from deepspeed.module_inject.autotp_config import AutoTPConfig, AutoTPPresets, merge_autotp_configs + + config = None + + # First check for preset + if self.preset_model: + config = AutoTPPresets.get_preset(self.preset_model) + + # Then check for custom config + if self.partition_config: + custom_config = AutoTPConfig.from_dict(self.partition_config) + if config and custom_config.use_default_specs: + config = merge_autotp_configs(config, custom_config) + else: + config = custom_config + + if config: + config.tp_size = self.autotp_size + + return config + + +def get_tensor_parallel_config(ds_config): + + if 'tensor_parallel' in ds_config: + return TPTrainingConfig(**ds_config['tensor_parallel']) + return TPTrainingConfig() + + +def _get_hf_tp_plan(model): + """Extract tp_plan from HuggingFace model. + + Prefer base_model_tp_plan (from model config) over _tp_plan (runtime attribute) + because _tp_plan often contains duplicate entries with a 'model.' prefix added + by HuggingFace, which causes spurious duplicate-match warnings during conversion. + """ + config = getattr(model, 'config', None) + if config and getattr(config, 'base_model_tp_plan', None): + return model.config.base_model_tp_plan + + if getattr(model, '_tp_plan', None): + return model._tp_plan + + return None diff --git a/deepspeed/runtime/tensor_parallel/init_utils.py b/deepspeed/runtime/tensor_parallel/init_utils.py new file mode 100644 index 000000000000..453f00af6db9 --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/init_utils.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import base64 +import os +from typing import Optional, Union + +import hjson +import torch + +from deepspeed.runtime.config_utils import dict_raise_error_on_duplicate_keys + +_TP_MODEL_INIT_ARGS = None + + +def load_ds_config(config: Union[str, dict]) -> dict: + if isinstance(config, dict): + return config + if isinstance(config, str): + if os.path.exists(config): + return hjson.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys) + try: + config_decoded = base64.urlsafe_b64decode(config).decode('utf-8') + return hjson.loads(config_decoded) + except (UnicodeDecodeError, AttributeError, ValueError) as exc: + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. " + f"Received: {config}") from exc + raise ValueError(f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. " + f"Received: {config}") + + +def record_tp_model_init_args(tp_size, dtype, tp_group, dist_module): + global _TP_MODEL_INIT_ARGS + new_args = { + "tp_size": tp_size, + "dtype": dtype, + "tp_group": tp_group, + } + + if _TP_MODEL_INIT_ARGS is None: + _TP_MODEL_INIT_ARGS = new_args + return + + if _TP_MODEL_INIT_ARGS["tp_size"] != tp_size or _TP_MODEL_INIT_ARGS["dtype"] != dtype: + raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.") + + existing_group = _TP_MODEL_INIT_ARGS.get("tp_group") + if existing_group is None and tp_group is None: + return + if (existing_group is None) != (tp_group is None): + raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.") + + existing_group_size = tp_group_world_size(existing_group, dist_module) + new_group_size = tp_group_world_size(tp_group, dist_module) + if existing_group_size != new_group_size: + raise ValueError("Conflicting tp_model_init arguments detected across multiple calls.") + + +def tp_group_world_size(tp_group, dist_module): + if tp_group is None or dist_module is None: + return None + return dist_module.get_world_size(group=tp_group) + + +def infer_config_dtype(config_dict: dict) -> Optional[torch.dtype]: + bf16_config = config_dict.get("bf16", {}) + if isinstance(bf16_config, dict) and bf16_config.get("enabled", False): + return torch.bfloat16 + fp16_config = config_dict.get("fp16", {}) + if isinstance(fp16_config, dict) and fp16_config.get("enabled", False): + return torch.float16 + return None + + +def merge_tp_model_init_into_config(config_dict: dict, mpu, mesh_param, dist_module): + if _TP_MODEL_INIT_ARGS is None: + return + + tp_size = _TP_MODEL_INIT_ARGS["tp_size"] + dtype = _TP_MODEL_INIT_ARGS["dtype"] + tp_group = _TP_MODEL_INIT_ARGS["tp_group"] + + if tp_group is not None and mpu is not None: + raise ValueError("tp_model_init provided tp_group; deepspeed.initialize must not receive mpu.") + if tp_group is None and mpu is None and mesh_param is None: + # Auto-create TP groups for compatibility with HF Trainer (mpu is not passed). + from deepspeed.utils import groups + groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size) + + tp_section = config_dict.get("tensor_parallel") + if tp_section is None: + tp_section = {} + config_dict["tensor_parallel"] = tp_section + + config_autotp_size = tp_section.get("autotp_size") + if config_autotp_size is not None and config_autotp_size != tp_size: + raise ValueError( + f"Conflicting tensor_parallel.autotp_size in config ({config_autotp_size}) and tp_model_init ({tp_size}).") + + if config_autotp_size is None: + tp_section["autotp_size"] = tp_size + + tp_config = tp_section.get("tp") or {} + if not isinstance(tp_config, dict): + raise ValueError("tensor_parallel.tp must be a dict when provided.") + + config_tp_size = tp_config.get("tp_size") + if config_tp_size is not None and config_tp_size != tp_size: + raise ValueError( + f"Conflicting tensor_parallel.tp.tp_size in config ({config_tp_size}) and tp_model_init ({tp_size}).") + if config_tp_size is None: + tp_config["tp_size"] = tp_size + + if tp_group is not None: + config_tp_group = tp_config.get("tp_group") + if config_tp_group is not None and config_tp_group is not tp_group: + raise ValueError("Conflicting tensor_parallel.tp.tp_group in config and tp_model_init.") + tp_config["tp_group"] = tp_group + + tp_group_size = tp_group_world_size(tp_group, dist_module) + if tp_group_size is not None and tp_group_size != tp_size: + raise ValueError(f"tp_model_init tp_size ({tp_size}) does not match tp_group size ({tp_group_size}).") + + tp_section["tp"] = tp_config + + config_dtype = infer_config_dtype(config_dict) + if config_dtype is not None and config_dtype != dtype: + raise ValueError(f"Conflicting dtype: config uses {config_dtype} but tp_model_init requested {dtype}.") + + tp_dtype = tp_section.get("dtype") + if tp_dtype is not None: + if isinstance(tp_dtype, str): + tp_dtype_map = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + tp_dtype_value = tp_dtype_map.get(tp_dtype.lower()) + else: + tp_dtype_value = tp_dtype + if tp_dtype_value is not None and tp_dtype_value != dtype: + raise ValueError(f"Conflicting tensor_parallel.dtype in config ({tp_dtype}) and tp_model_init ({dtype}).") diff --git a/deepspeed/runtime/tensor_parallel/tp_manager.py b/deepspeed/runtime/tensor_parallel/tp_manager.py new file mode 100644 index 000000000000..cf0b5a75c92a --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/tp_manager.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from .config import TPTrainingConfig, TPConfig +from deepspeed.utils import groups +import deepspeed.comm as dist + + +class TpTrainingManager(): + + def __init__(self, model, tp_size, dtype): + self.module = model + self.config = self._initialize_config(dtype) + + from deepspeed.module_inject.auto_tp import AutoTP + from deepspeed import get_accelerator + + # Parse model configuration + parser_dict = AutoTP.tp_parser(model) + print("AutoTP: ", parser_dict) + + # Initialize TP configuration and model + self._initialize_tp_config(tp_size) + self._get_model_config_generate() + + # Synchronize random number generator state across devices + _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name()) + dist.broadcast(_rng_state, groups.get_tensor_model_parallel_src_rank(), self.tp_config.tp_group) + get_accelerator().set_rng_state(_rng_state.cpu()) + + # Apply injection policies + self._apply_policies(parser_dict) + + def _initialize_config(self, dtype): + """Initialize and return the DeepSpeed TP training configuration.""" + config = TPTrainingConfig() + config.dtype = dtype + return config + + def _apply_policies(self, parser_dict): + """Apply injection policies to the parsed modules.""" + for client_module, injection_policy in parser_dict: + self.config.injection_policy_tuple = injection_policy + self._apply_injection_policy(self.config, client_module) + + def _apply_injection_policy(self, config, client_module=None): + from deepspeed.module_inject import replace_transformer_layer + """Apply the given injection policy to a client module.""" + if isinstance(self.module, torch.nn.Module): + replace_transformer_layer(client_module, self.module, None, self.config, self.model_config) + + def _initialize_tp_config(self, tp_size): + """Perform TP configuration initialization.""" + self.tp_config = TPConfig() + self.tp_config.tp_size = tp_size + + groups._init_tp_mesh_device(tp_size) + self.tp_config.tp_group = groups.get_tensor_model_parallel_group() + self.config.tensor_parallel = self.tp_config + + def _get_model_config_generate(self): + """Generate and apply HF model configuration.""" + self.model_config = getattr(self.module, 'config', None) diff --git a/deepspeed/runtime/torch_autocast.py b/deepspeed/runtime/torch_autocast.py new file mode 100644 index 000000000000..299693fdaab5 --- /dev/null +++ b/deepspeed/runtime/torch_autocast.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Set, List, Union +import importlib +from contextlib import contextmanager + +import torch +import deepspeed.comm as dist +from deepspeed.utils import logger +from deepspeed.accelerator import get_accelerator + +LOWER_PRECISION_SAFE_MODULES = [ + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, +] + +PARAM_COMM_DTYPE_ATTR_NAME = "comm_dtype" +_WARNED_NESTED_AUTOCAST = False + +# TODO: Avoid using global variables +TORCH_AUTOCAST_INITIALIZED = False +TORCH_AUTOCAST_DTYPE = None + + +def _validate_auto_cast_settings(engine): + + assert not engine.zero_quantized_weights(), "Cannot enable both torch autocast and zero quantized weights" + + +def init_autocast_params(engine, dtype: torch.dtype, + torch_autocast_lower_precision_safe_modules: Union[None, List[str]]) -> None: + + _validate_auto_cast_settings(engine) + model = engine.module + + if torch_autocast_lower_precision_safe_modules is None: + lower_precision_safe_module_classes = LOWER_PRECISION_SAFE_MODULES + else: + lower_precision_safe_module_classes = [] + for module_name in torch_autocast_lower_precision_safe_modules: + try: + package_name, class_name = module_name.rsplit('.', 1) + module = importlib.import_module(package_name) + class_ = getattr(module, class_name) + lower_precision_safe_module_classes.append(class_) + except Exception as e: + raise ValueError(f"Failed to import lower precision safe module {module_name}: {e}") + + for module in model.modules(): + if module.__class__ in lower_precision_safe_module_classes: + for p in module.parameters(recurse=False): + setattr(p, PARAM_COMM_DTYPE_ATTR_NAME, dtype) + + global TORCH_AUTOCAST_INITIALIZED + TORCH_AUTOCAST_INITIALIZED = True + global TORCH_AUTOCAST_DTYPE + TORCH_AUTOCAST_DTYPE = dtype + + +def is_autocast_initialized() -> bool: + return TORCH_AUTOCAST_INITIALIZED + + +def get_default_autocast_lower_precision_modules() -> List[str]: + return [f"{cls.__module__}.{cls.__name__}" for cls in LOWER_PRECISION_SAFE_MODULES] + + +def get_autocast_dtype() -> torch.dtype: + return TORCH_AUTOCAST_DTYPE + + +def has_comm_dtype(param: torch.nn.Parameter) -> bool: + return hasattr(param, PARAM_COMM_DTYPE_ATTR_NAME) + + +def get_comm_dtype(param: torch.nn.Parameter) -> torch.dtype: + return getattr(param, PARAM_COMM_DTYPE_ATTR_NAME, param.dtype) + + +def get_all_comm_dtypes(params: Iterable) -> Set[torch.dtype]: + return {get_comm_dtype(p) for p in params} + + +def sort_dtypes(dtypes: List[torch.dtype]) -> List[torch.dtype]: + return sorted(dtypes, key=str) + + +@contextmanager +def autocast_if_enabled(engine): + """Context manager for DeepSpeed autocast with conditional support. + + This function manages `torch.autocast` contexts under DeepSpeed, allowing + autocast to be enabled or disabled dynamically based on runtime conditions. + It ensures consistency when autocast is already active outside of DeepSpeed, + or when it is configured within the DeepSpeed engine. + + Args: + engine: DeepSpeed engine instance. + """ + global _WARNED_NESTED_AUTOCAST + + if torch.is_autocast_enabled(): + if engine.torch_autocast_enabled(): + if not _WARNED_NESTED_AUTOCAST: + if dist.get_rank() == 0: + logger.info( + "torch.autocast is already enabled outside DeepSpeed. " + "Switching to the configuration defined in `torch_autocast` section of DeepSpeed config.") + _WARNED_NESTED_AUTOCAST = True + with torch.autocast(device_type=get_accelerator().device_name(), + dtype=engine.torch_autocast_dtype(), + enabled=True): + yield + else: + if not _WARNED_NESTED_AUTOCAST: + if dist.get_rank() == 0: + logger.warning( + "torch.autocast is enabled outside DeepSpeed but disabled within the DeepSpeed engine. " + "If you are using DeepSpeed's built-in mixed precision, the engine will follow the settings in bf16/fp16 section. " + "To use torch's native autocast instead, configure the `torch_autocast` section in the DeepSpeed config." + ) + _WARNED_NESTED_AUTOCAST = True + with torch.autocast(device_type=get_accelerator().device_name(), enabled=False): + yield + else: + if engine.torch_autocast_enabled(): + with torch.autocast(device_type=get_accelerator().device_name(), + dtype=engine.torch_autocast_dtype(), + enabled=True): + yield + else: + yield diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index ffb09677f046..2392683db81d 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -9,29 +9,28 @@ """ from collections.abc import Iterable -from deepspeed.moe.utils import is_moe_param import os import psutil import gc from math import sqrt -from math import floor -from bisect import bisect_left -import torch -from deepspeed import comm as dist +from numpy import prod +import torch +from torch.nn import functional as F try: from torch._six import inf except ModuleNotFoundError: from torch import inf - +from typing import Union, List, Dict, Sequence +from deepspeed import comm as dist +from deepspeed.moe.utils import is_moe_param from deepspeed.utils import groups, logger +from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size, + bwc_pipeline_parallel_group) from deepspeed.runtime.constants import PIPE_REPLICATED -from numpy import prod from deepspeed.accelerator import get_accelerator - from deepspeed.module_inject.policy import transpose -from torch.nn import functional as F torch_memory_reserved = get_accelerator().memory_reserved torch_max_memory_reserved = get_accelerator().max_memory_reserved @@ -48,10 +47,82 @@ def __init__(self, params): self.param_groups.append({'params': params}) +def filter_empty_parameters(params): + """Filter out empty parameters (numel == 0) from optimizer params. + + This is useful for optimizers that perform operations like division by numel, + which would produce NaNs for empty parameters. + + Args: + params: Either a list/tuple of Parameters, or a list of parameter group dicts + (each dict has 'params' key with list of Parameters) + + Returns: + Filtered params in the same format as input (list of Parameters or list of dicts) + """ + if not isinstance(params, (list, tuple)) or len(params) == 0: + return params + + # Check if first element is a dict (parameter groups) or a Parameter + if isinstance(params[0], dict): + # params is a list of parameter group dicts + filtered_params = [] + for param_group in params: + filtered_group = {} + trainable_params = [] + for key, value in param_group.items(): + if key == 'params': + # Filter out empty parameters + trainable_params = [p for p in value if p.numel() > 0] + else: + filtered_group[key] = value + # Only add group if it has non-empty parameters + if len(trainable_params) > 0: + filtered_group['params'] = trainable_params + filtered_params.append(filtered_group) + return filtered_params + else: + # params is a list of Parameters + return [p for p in params if p.numel() > 0] + + +graph_cache = {} + + +def graph_process(replay_first_step, func, *args, **kwargs): + # `func` should only contain operations on the GPU + # Please ensure that the memory address of the data required by 'func' remains constant + if func.__name__ not in graph_cache: + cuda_stream = get_accelerator().Stream() + cuda_stream.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(cuda_stream): + func(*args, **kwargs) + get_accelerator().current_stream().wait_stream(cuda_stream) + graph_cache[func.__name__] = get_accelerator().create_graph() + with get_accelerator().capture_to_graph(graph_cache[func.__name__]): + func(*args, **kwargs) + if replay_first_step: + get_accelerator().replay_graph(graph_cache[func.__name__]) + else: + get_accelerator().replay_graph(graph_cache[func.__name__]) + + def noop_decorator(func): return func +class noop_context(object): + + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def ensure_directory_exists(filename): """Create the directory path to ``filename`` if it does not already exist. @@ -71,8 +142,13 @@ def set_random_seed(seed): import numpy import random random.seed(seed) - numpy.random.seed(seed) - torch.manual_seed(seed) + + # pytest-randomly passes a too large seed + # `numpy.random.default_rng` could be a better approach, but it requires more changes to use rngs explicitly + # numpy.random accepts only 32-bit integers + numpy.random.seed(seed % (2**32)) + # torch.manual_seed accepts only 64-bit integers + torch.manual_seed(seed % (2**63)) def is_model_parallel_parameter(p) -> bool: @@ -85,44 +161,6 @@ def is_model_parallel_parameter(p) -> bool: return False -def bwc_tensor_model_parallel_rank(mpu=None): - """Backwards-compatible way of querying the tensor model parallel rank from - an ``mpu`` object. - - *Tensor* model parallelism means that tensors are physically split across - processes. This contrasts with *pipeline* model parallelism, in which the - layers are partitioned but tensors left intact. - - The API for tensor model parallelism has changed across versions and this - helper provides a best-effort implementation across versions of ``mpu`` - objects. The preferred mechanism is - ``mpu.get_tensor_model_parallel_rank()``. - - This should "just work" with both Megatron-LM and DeepSpeed's pipeline - parallelism. - - Args: - mpu (model parallel unit, optional): The tensor model parallel rank. - If ``mpu=None``, returns 0. Defaults to ``None``. - - Returns: - int: the rank - """ - if mpu is None: - # No model parallelism in easy :) - return 0 - - if hasattr(mpu, 'get_tensor_model_parallel_rank'): - # New Megatron and DeepSpeed convention (post pipeline-parallelism release) - return mpu.get_tensor_model_parallel_rank() - elif hasattr(mpu, 'get_slice_parallel_rank'): - # Some DeepSpeed + pipeline parallelism versions - return mpu.get_slice_parallel_rank() - else: - # Deprecated Megatron and DeepSpeed convention - return mpu.get_model_parallel_rank() - - def copy_to_device(item, device, criterion_func): """ Return a copy of tensor on specified device. @@ -147,19 +185,19 @@ def copy_to_device(item, device, criterion_func): return item -def move_to_device(item, device, criterion_func): +def move_to_device(item, device, criterion_func=None): """ Move tensor on to specified device by changing the storage. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts. Parameters: item: tensor to move or (possibly nested) container of tensors to move. device: target device - criterion_func: Function to restrict move operation to items meet criterion + criterion_func: Function to restrict move operation to items meet criterion, defaults to `None` which is an equivalent to always move Returns: None """ - if criterion_func(item): + if (criterion_func is not None and criterion_func(item)): device_copy = item.to(device) item.data = device_copy.data return item @@ -170,7 +208,18 @@ def move_to_device(item, device, criterion_func): elif isinstance(item, dict): return {k: move_to_device(v, device, criterion_func) for k, v in item.items()} else: - return item + return item.to(device) + + +def get_norm_with_moe_layers_fast(all_groups_norm, group): + # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'. + # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce + scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=group) + all_groups_norm = scaled_norm_tensor.item() + #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") + return all_groups_norm class CheckOverflow(object): @@ -240,7 +289,7 @@ def has_overflow(self, params, has_moe_params=None): # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs overflow_gpu = get_accelerator().ByteTensor([overflow]) - # deepspeeed.comm.all_reduce(overflow_gpu, + # deepspeed.comm.all_reduce(overflow_gpu, # op=deepspeed.comm.ReduceOp.MAX, # group=mpu.get_model_parallel_group()) if has_moe_params: @@ -252,8 +301,8 @@ def has_overflow(self, params, has_moe_params=None): elif self.mpu is not None: if self.deepspeed is not None: using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') - if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or ( - not using_pipeline and self.deepspeed.enable_backward_allreduce is False): + if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce + is False) or (not using_pipeline and self.deepspeed.enable_backward_allreduce is False): dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group()) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False: @@ -331,48 +380,55 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) - max_norm = float(max_norm) norm_type = float(norm_type) + all_norms = [] if norm_type == inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + for p in parameters: + all_norms.append(p.grad.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + total_norm = total_norm.to(get_accelerator().current_device_name()) # Take max across all GPUs. if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item() + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) else: total_norm = 0 for p in parameters: if mpu is not None: if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p): - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item()**norm_type + param_norm = p.grad.data.detach().float().norm(norm_type) + all_norms.append(param_norm) else: - param_norm = p.grad.data.float().norm(norm_type) - total_norm += param_norm.item()**norm_type - + param_norm = p.grad.data.detach().float().norm(norm_type) + all_norms.append(param_norm) + if len(all_norms) > 0: + total_norm = torch.stack(all_norms).square().sum().float() + else: + total_norm = get_accelerator().FloatTensor([0.0]) + total_norm = total_norm.to(get_accelerator().current_device_name()) # Sum across all model parallel GPUs. - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + total_norm = total_norm.pow(1. / norm_type) # Need to average total_norm across different GPUs due to the presence of moe params pg = groups._get_data_parallel_group() scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg)) + scaled_norm_tensor = scaled_norm - scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)]) dist.all_reduce(scaled_norm_tensor, group=pg) - total_norm = scaled_norm_tensor.item() + total_norm = scaled_norm_tensor + total_norm = total_norm.to(parameters[0].device) + max_norm = torch.tensor([float(max_norm)], device=total_norm.device) clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1: - for p in parameters: - p.grad.data.mul_(clip_coef) + tmp_tensor = torch.tensor([1.0], device=clip_coef.device) + clip_coef = torch.min(tmp_tensor, clip_coef) + for p in parameters: + p.grad.data.mul_(clip_coef) return total_norm -def get_grad_norm(parameters, norm_type=2, mpu=None): +def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None): """Get grad norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -382,10 +438,10 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): Arguments: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. - + grad_norm_mask (List[Tensor]): A list of Tensor, where + each Tensor is a 2D Tensor containing ranges of [start_index, end_index]. Returns: Total norm of the parameters (viewed as a single vector). """ @@ -403,18 +459,27 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): total_norm = total_norm_cuda[0].item() else: total_norm = 0. - tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) - for p in parameters: - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: - continue - - # Filter to avoid over-counting replicated tensors from tensor - # model parallelism - if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): - continue + for idx, p in enumerate(parameters): + # Use grad_norm_mask to avoid redundant computation of flattened gradient norm + if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0: + + # A loop-free implementation to create a mask tensor based on a range list + # which is logically equivalent to the following implementation. + # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) + # # for mask_idx in grad_norm_mask[idx]: + # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(), + dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) + mask_tensor = torch.zeros(p.shape[0] + 1, + device=get_accelerator().current_device_name(), + dtype=p.dtype) + mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), + cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] + + param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type) - param_norm = p.grad.data.float().norm(norm_type) + else: + param_norm = p.grad.data.float().norm(norm_type) total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. @@ -479,12 +544,12 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): Arguments: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. Returns: Total norm of the parameters (viewed as a single vector). + -1 if the norm value is NaN or Inf. """ if isinstance(parameters, torch.Tensor): parameters = [parameters] @@ -539,6 +604,7 @@ def prefix_sum_inc(weights): def partition_uniform(num_items, num_parts): + import numpy parts = [0] * (num_parts + 1) # First check for the trivial edge case if num_items <= num_parts: @@ -546,74 +612,55 @@ def partition_uniform(num_items, num_parts): parts[p] = min(p, num_items) return parts - chunksize = floor(num_items / num_parts) - for p in range(num_parts): - parts[p] = min(chunksize * p, num_items) - parts[num_parts] = num_items - return parts - - -def _lprobe(weights, num_parts, bottleneck): - num_items = len(weights) - total_weight = weights[-1] - - # initialize partitioning - parts = [0] * (num_parts + 1) - for p in range(1, num_parts + 1): - parts[p] = num_items - - bsum = bottleneck # running sum of target weight for pth partition chunksize = num_items // num_parts - step = chunksize - for p in range(1, num_parts): - # Jump to the next bucket - while (step < num_items) and (weights[step] < bsum): - step += chunksize - - # Find the end index of partition p - parts[p] = bisect_left(weights, bsum, lo=step - chunksize, hi=min(step, num_items)) - # Nothing more to partition, return early - if parts[p] == num_items: - # See if the current partition is overweight. - part_size = weights[-1] - weights[parts[p - 1]] - return parts, part_size < bottleneck - - # Next partition target - bsum = weights[parts[p] - 1] + bottleneck - - return parts, bsum >= total_weight - - -def _rb_partition_balanced(weights, num_parts, eps): - total_weight = weights[-1] - lower = total_weight / num_parts # best case heaviest partition - upper = total_weight # worst case heaviest partition - - # Do a binary search for the best partitioning - while upper > lower + eps: - mid = lower + ((upper - lower) / 2) - parts, success = _lprobe(weights, num_parts, mid) - if success: - upper = mid - else: - lower = mid + eps - return upper + residual = num_items - (chunksize * num_parts) + parts = numpy.arange(0, (num_parts + 1) * chunksize, chunksize) -def partition_balanced(weights, num_parts, eps=1e-3): - num_items = len(weights) - # First check for the trivial edge case - if num_items <= num_parts: - return partition_uniform(num_items, num_parts) + for i in range(residual): + parts[i + 1:] += 1 + parts = parts.tolist() - weights_ = prefix_sum_inc(weights) + return parts - # Find the smallest bottleneck (weight of heaviest partition) - bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps) - # Now compute that partitioning - parts, success = _lprobe(weights_, num_parts, bottleneck) - assert success +def partition_balanced(weights, num_parts): + """ + use dynamic programming solve `The Linear Partition Problem`. + see https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM + """ + import numpy as np + n = len(weights) + m = num_parts + + if n <= m: + return partition_uniform(n, m) + + dp_max = np.full((n + 1, m + 1), np.inf) + dp_min = np.full((n + 1, m + 1), np.inf) + dp_cost = np.full((n + 1, m + 1), np.inf) + position = np.zeros((n + 1, m + 1), dtype=int) + prefix_sum = np.zeros((n + 1)) + prefix_sum[1:] = np.cumsum(weights) + + dp_max[0, 0] = 0 + dp_cost[0, 0] = 0 + for i in range(1, n + 1): + for j in range(1, min(i, m) + 1): + for k in range(i): + max_sum = max(dp_max[k, j - 1], prefix_sum[i] - prefix_sum[k]) + min_sum = min(dp_min[k, j - 1], prefix_sum[i] - prefix_sum[k]) + cost = max_sum - min_sum + if dp_cost[i, j] >= cost: + dp_cost[i, j] = cost + dp_max[i, j] = max_sum + dp_min[i, j] = min_sum + position[i, j] = k + + parts = [n] + for i in reversed(range(1, m + 1)): + parts.append(position[parts[-1], i]) + parts.reverse() return parts @@ -626,10 +673,10 @@ def __init__(self, tensor, group, partition_meta=None): self.group = group self.num_parts = dist.get_world_size(group=self.group) self.rank = dist.get_rank(group=self.group) - self.orig_size = list(tensor.size()) self.orig_device = tensor.device self.local_data, self.partition = self._partition_tensor(tensor) + self.even_split = tensor.numel() % self.num_parts == 0 @classmethod def from_meta(cls, meta, local_part, group, device=get_accelerator().device_name()): @@ -672,23 +719,16 @@ def full(self, device=None): # Allocate the full tensor as a flat buffer. full_numel = prod(self.full_size()) flat_tensor = torch.zeros([full_numel], dtype=self.local_data.dtype, device=device) - - # Prepare all-gather buffer - partition_tensors = [] - for part_id in range(self.num_parts): - part_size = self.partition[part_id + 1] - self.partition[part_id] - buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size) - if part_id == self.rank: - buf.copy_(self.local_data) - partition_tensors.append(buf) - - # Collect the full tensor - dist.all_gather(partition_tensors, partition_tensors[self.rank], group=self.group) - - for i in range(len(partition_tensors)): - partition_tensors[i].data = torch.zeros(1) - partition_tensors[i] = None - + if self.even_split: + # Collect the full tensor + dist.all_gather_into_tensor(flat_tensor, self.local_data, group=self.group) + else: + for part_id in range(self.num_parts): + part_size = self.partition[part_id + 1] - self.partition[part_id] + buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size) + if part_id == self.rank: + buf.copy_(self.local_data) + dist.broadcast(buf, part_id, self.group) return flat_tensor.view(self.full_size()).clone().detach() def to_meta(self): @@ -782,15 +822,15 @@ def see_memory_usage(message, force=False): gc.collect() # Print message except when distributed but not rank 0 - logger.info(message) - logger.info(f"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ + print(message) + print(f"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ Max_MA {round(get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \ Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB ") vm_stats = psutil.virtual_memory() used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2) - logger.info(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%') + print(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%') # get the peak memory to report correct data, so reset the counter for the next call get_accelerator().reset_peak_memory_stats() @@ -827,26 +867,15 @@ def get_only_unique_item(items): return unique_item -def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6): - """Clip the gradient of a list of parameters. - Args: - parameters: List of parameters whose .grad will be clipped. - global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None. - mpu (optional): model parallelism unit. Defaults to None. - eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6 - Returns: - float: the global gradient norm - """ - if global_grad_norm is None: - global_grad_norm = get_grad_norm(parameters, mpu=mpu) - clip_coef = max_norm / (global_grad_norm + eps) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef) - return global_grad_norm +def mask_nan_or_inf_with_val_inplace(input, device=None, val=-1.): + norm_is_inf = input.isinf() + norm_is_nan = input.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + err = torch.tensor(-1.0, device=device, dtype=torch.float) + input.masked_fill_(inf_or_nan, err) -def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None): """Get norm of an iterable of tensors. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -860,31 +889,72 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): Returns: Total norm of the tensors (viewed as a single vector). """ - assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}' - assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors' + assert all([torch.is_tensor(t) for t in input_tensors]), 'expected list of only tensors' norm_type = float(norm_type) + all_norms = [] if norm_type == inf: - total_norm = max(t.data.abs().max() for t in input_tensors) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + for t in input_tensors: + all_norms.append(t.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + device_total_norm = total_norm.to(get_accelerator().current_device_name()) + # Max across model parallel if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item() + # For MoE grads, max over model parallel only if MoE-TP is enabled + if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) + # If MoE grads and MoE-TP disabled, max over pipeline parallel + elif bwc_pipeline_parallel_world_size(mpu) > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu)) + + # MoE grads: max across expert parallel group + if moe_ep_group is not None: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group) + total_norm = device_total_norm.to(input_tensors[0].device) else: - total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + + if 'norm_tensors_compute_buffer' not in graph_cache or len( + graph_cache['norm_tensors_compute_buffer']) != len(input_tensors): + graph_cache['norm_tensors_compute_buffer'] = [ + torch.empty([], dtype=torch.float, device=get_accelerator().current_device_name()) + for t in input_tensors + ] + compute_buffer = graph_cache['norm_tensors_compute_buffer'] + + def _norm_tensors(tensor_list, _compute_buffer, _norm_type): + for i, t in enumerate(tensor_list): + _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type) + if i != 0: + _compute_buffer[0].data.add_(_compute_buffer[i].data) + + if use_graph: + graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type) + else: + _norm_tensors(input_tensors, compute_buffer, norm_type) + + device_total_norm = compute_buffer[0].float().detach() + + # Sum across model parallel if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + # For MoE grads, sum over model parallel only if MoE-TP is enabled + if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + # If MoE grads and MoE-TP disabled, sum over pipeline parallel + elif bwc_pipeline_parallel_world_size(mpu) > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu)) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + # MoE grads: sum across expert parallel group + if moe_ep_group is not None: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group) + total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type) + + mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device) return total_norm -def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6): +def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False): """Clip list of tensors by global norm. Args: input_tensors: List of tensors to be clipped @@ -895,14 +965,26 @@ def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, m float: the global norm """ if global_norm is None: - global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu) - + global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph) clip_coef = max_norm / (global_norm + eps) - if clip_coef < 1: - for t in input_tensors: - t.detach().mul_(clip_coef) + if use_graph: + + def clip_tensors(_tensor_list, _clip_coef_tensor): + for t in _tensor_list: + t.detach().mul_(_clip_coef_tensor) + + if 'clip_coef_tensor' not in graph_cache: + # Alloc memory + graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef, + dtype=torch.float32).to(get_accelerator().device_name()) + clip_coef_tensor = graph_cache['clip_coef_tensor'] + clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32)) + graph_process(False, clip_tensors, input_tensors, clip_coef_tensor) + else: + for t in input_tensors: + t.detach().mul_(clip_coef) return global_norm @@ -920,12 +1002,31 @@ def align_dense_tensors(tensor_list, alignment): return padded_tensor_list -def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_alignment_factor, allgather_bucket_size): +def all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, dp_process_group): + for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)): + partition_id = dist.get_rank(group=dp_process_group[group_id]) + dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) + if dp_world_size == 1: + # no groups share optimizer states + # pipeline parallel with bf16 will default call this even if dp size = 1. + continue + dist.all_gather_into_tensor(group_flat, partitioned_params[partition_id], dp_process_group[group_id]) + + +def all_gather_dp_groups(groups_flat, partitioned_param_groups, dp_process_group, start_alignment_factor, + allgather_bucket_size): + if dist.has_all_gather_into_tensor(): + return all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, dp_process_group) + for group_id, partitioned_params in enumerate(partitioned_param_groups): # Sequential AllGather Best of both worlds partition_id = dist.get_rank(group=dp_process_group[group_id]) dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) + if dp_world_size == 1: + # no groups share optimizer states + # pipeline parallel with bf16 will default call this even if dp size = 1. + continue num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size) shard_size = partitioned_params[partition_id].numel() // num_shards @@ -950,6 +1051,37 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id]) +def get_tensor_bytes(item): + if torch.is_tensor(item): + return item.numel() * item.element_size() + elif isinstance(item, list): + return sum([get_tensor_bytes(v) for v in item]) + elif isinstance(item, tuple): + return sum([get_tensor_bytes(v) for v in item]) + elif isinstance(item, dict): + return sum([get_tensor_bytes(v) for v in item.values()]) + else: + return 0 + + +def _get_folder_size(folder): + size = 0 + for path, _, files in os.walk(folder): + size += sum([os.path.getsize(os.path.join(path, f)) for f in files]) + return size + + +def get_checkpoint_folder_size(save_dir, tag, local_rank=None): + if local_rank == 0: + folder = os.path.join(save_dir, tag) + size_tensor = torch.tensor(_get_folder_size(folder)).to(get_accelerator().device_name()) + else: + size_tensor = torch.tensor(0).to(get_accelerator().device_name()) + + dist.reduce(tensor=size_tensor, dst=0) + return int(size_tensor) + + class TLinear(torch.nn.Linear): def __init__(self, orig_layer, name=""): @@ -973,3 +1105,370 @@ def get_inactive_params(param_list): from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus return [param for param in param_list if (hasattr(param, 'ds_id') and \ param.ds_status == ZeroParamStatus.NOT_AVAILABLE)] + + +def get_norm_with_moe_layers(non_expert_norm, mpu, expert_tensors, norm_type=2): + """ Compute the global norm with MoE experts + + Inputs: + non_expert_norm (float) : the calculated norm of the non-expert params + expert_tensors (Dict[ep_name, List[Tensor]): Dictionary of expert group name to list of grad tensors + norm_type (int): the norm to use + + Returns: + if norm is (-/+) inf, returns -1 + otherwise the global norm (float) + """ + + def to_tensor(v): + return get_accelerator().FloatTensor(float(v)).detach() + + group_norms = [non_expert_norm] + for exp_name, tensors in expert_tensors.items(): + group_norm = get_global_norm_of_tensors(input_tensors=tensors, + mpu=mpu, + norm_type=norm_type, + use_graph=False, + moe_ep_group=groups._get_expert_parallel_group(exp_name)) + group_norms.append(group_norm) + + # check if all norms are valid + group_norms = torch.stack([to_tensor(norm) for norm in group_norms]) + if group_norms.eq(-1).any(): + return -1 + + # combine norms + if norm_type == inf: + total_norm = group_norms.max().item() + else: + total_norm = group_norms.pow(norm_type).sum() + total_norm = total_norm.item()**(1. / norm_type) + if total_norm == float('inf') or total_norm == -float('inf'): + total_norm = -1 + + return total_norm + + +def _make_offload_state_key(key): + return f"{key}_offload_buffer" + + +def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_key(state, key): + offload_buf_key = _make_offload_state_key(key) + if offload_buf_key not in state: + state[offload_buf_key] = torch.empty_like(state[key], device=device) + if pin_memory: + state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key]) + state[offload_buf_key].copy_(state[key], non_blocking=non_blocking) + state[key].data = state[offload_buf_key] + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_key(state, "exp_avg_sq") + + +def reload_adam_states(optimizer, device, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_back_key(state, key): + state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking) + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_back_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_back_key(state, "exp_avg_sq") + + +def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[List, Dict]) -> bool: + """ + Compare two lists or dictionaries for equality, including any tensors they may contain. + + Args: + inputs1: First input, either a list or a dictionary. + inputs2: Second input, either a list or a dictionary. + + Returns: + True if inputs1 and inputs2 are equal; False otherwise. + """ + if type(inputs1) != type(inputs2): # Ensure types match + return False + + if isinstance(inputs1, list) and isinstance(inputs2, list): + if len(inputs1) != len(inputs2): + return False + for val1, val2 in zip(inputs1, inputs2): + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + val1 = val1.to(torch.device(get_accelerator().current_device_name())) + val2 = val2.to(torch.device(get_accelerator().current_device_name())) + if not torch.equal(val1, val2): + return False + elif val1 != val2: + return False + return True + + elif isinstance(inputs1, dict) and isinstance(inputs2, dict): + if inputs1.keys() != inputs2.keys(): + return False + for key in inputs1: + val1, val2 = inputs1[key], inputs2[key] + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + val1 = val1.to(torch.device(get_accelerator().current_device_name())) + val2 = val2.to(torch.device(get_accelerator().current_device_name())) + if not torch.equal(val1, val2): + return False + elif val1 != val2: + return False + return True + + return False + + +def maybe_loss_for_backward(value) -> bool: + """Check if the value is a loss tensor. + Conditions: + - The value must be a tensor. + - The tensor must have exactly one element. + - The tensor must have grad_fn defined. + + Args: + value: The value to check. + """ + return isinstance(value, torch.Tensor) and value.numel() == 1 and value.grad_fn is not None + + +class OutputBackwardHookManager: + """ + Manages backward hooks on output tensors to trigger preprocessing only once. + + This is an alternative to register_full_backward_pre_hook that avoids warnings + and provides more fine-grained control over when preprocessing occurs. + + The hook manager automatically manages its lifetime by attaching itself to the + output tensors. When the outputs are freed, the hook manager is also freed. + + This manager handles two types of preprocessing: + 1. Global preprocessing (run once per backward pass): timers, flags, setup + 2. Per-tensor preprocessing (run for each output tensor): gradient scaling, loss logging + + Usage: + # Only global preprocessing (run once) + hook_manager = OutputBackwardHookManager( + preprocess_once_fn=lambda: start_timers() + ) + + # Both global and per-tensor preprocessing + hook_manager = OutputBackwardHookManager( + preprocess_once_fn=lambda: start_timers(), + preprocess_per_tensor_fn=lambda tensor: scale_gradient(tensor) + ) + + outputs = model(*inputs) + hook_manager.register_hooks_on_outputs(outputs) + # No need to manually clean up - it's freed when outputs are freed + """ + + def __init__(self, preprocess_once_fn, preprocess_per_tensor_fn=None): + """ + Args: + preprocess_once_fn: A callable that takes no arguments and performs + one-time preprocessing before backward (e.g., start timers). + Will only be called once per backward pass. + preprocess_per_tensor_fn: Optional callable that takes a tensor and returns + a potentially modified tensor. Called for each output + tensor during backward (e.g., gradient scaling). + If None, no per-tensor processing is done. + """ + self.preprocess_once_fn = preprocess_once_fn + self.preprocess_per_tensor_fn = preprocess_per_tensor_fn + self.preprocess_done = False + self.hook_handles = [] + + def _make_backward_hook(self, tensor): + """ + Creates a backward hook for a specific tensor. + + Args: + tensor: The output tensor this hook is attached to + """ + + def backward_hook(grad): + # First, ensure global preprocessing happens once + if not self.preprocess_done: + self.preprocess_done = True + self.preprocess_once_fn() + + # Then apply per-tensor preprocessing if provided + if self.preprocess_per_tensor_fn is not None: + # Per-tensor preprocessing receives the tensor + # It can perform operations like gradient scaling + grad = self.preprocess_per_tensor_fn(grad) + + return grad + + return backward_hook + + def _traverse_and_register_hooks(self, outputs, first_tensor_holder): + """ + Recursively traverse outputs to find tensors with grad_fn and register hooks. + + Args: + outputs: Can be a tensor, tuple, list, dict, or nested structure of these. + first_tensor_holder: List to hold the first tensor found (for attaching self) + """ + if isinstance(outputs, torch.Tensor): + if outputs.grad_fn is not None: + # Store reference to first tensor to attach hook manager lifetime + if not first_tensor_holder: + first_tensor_holder.append(outputs) + # Pass the tensor to _make_backward_hook so per-tensor processing can access it + hook_handle = outputs.register_hook(self._make_backward_hook(outputs)) + self.hook_handles.append(hook_handle) + elif isinstance(outputs, (tuple, list)): + for item in outputs: + self._traverse_and_register_hooks(item, first_tensor_holder) + elif isinstance(outputs, dict): + for value in outputs.values(): + self._traverse_and_register_hooks(value, first_tensor_holder) + + def register_hooks_on_outputs(self, outputs): + """ + Register backward hooks on all output tensors that have grad_fn. + + Args: + outputs: The outputs from the forward pass. Can be a tensor or nested structure. + """ + # Reset state for new forward pass + self.preprocess_done = False + self.remove_hooks() + + # Register hooks on all tensors with grad_fn + first_tensor_holder = [] + self._traverse_and_register_hooks(outputs, first_tensor_holder) + + # Attach this hook manager instance to the first output tensor + # This ensures the hook manager is kept alive as long as the outputs are alive + # and automatically freed when outputs are freed + if first_tensor_holder: + first_tensor = first_tensor_holder[0] + if not hasattr(first_tensor, '_backward_hook_managers'): + first_tensor._backward_hook_managers = [] + first_tensor._backward_hook_managers.append(self) + + def remove_hooks(self): + """Remove all registered hooks.""" + for handle in self.hook_handles: + handle.remove() + self.hook_handles.clear() + + def reset(self): + """Reset the preprocessing flag without removing hooks.""" + self.preprocess_done = False + + +def register_output_backward_hooks(outputs, preprocess_once_fn, preprocess_per_tensor_fn=None): + """ + Convenience function to register backward hooks on outputs. + + This function creates a hook manager that is automatically tied to the lifetime + of the output tensors. When outputs are freed, the hook manager is also freed. + + Args: + outputs: The outputs from forward pass (tensor, tuple, list, dict, or nested) + preprocess_once_fn: A callable that takes no arguments and performs one-time + preprocessing before backward. Will only be called once per backward pass. + preprocess_per_tensor_fn: Optional callable that takes a tensor and performs + per-tensor preprocessing (e.g., gradient scaling). + Called for each output tensor during backward. + + Returns: + The hook manager instance (usually not needed, as lifetime is automatic) + + Example: + # Only global preprocessing + outputs = model(x) + register_output_backward_hooks(outputs, lambda: print("Backward starting!")) + + # Both global and per-tensor preprocessing + outputs = model(x) + register_output_backward_hooks( + outputs, + preprocess_once_fn=lambda: start_timers(), + preprocess_per_tensor_fn=lambda tensor: scale_tensor(tensor) + ) + # Hook manager is automatically freed when outputs are freed + """ + hook_manager = OutputBackwardHookManager(preprocess_once_fn, preprocess_per_tensor_fn) + hook_manager.register_hooks_on_outputs(outputs) + return hook_manager + + +def check_internal_apis_for_count_used_parameters() -> bool: + """ + Ensure the Torch internal APIs needed by `count_used_parameters_in_backward` exist. + """ + if not hasattr(torch.autograd.graph, '_get_grad_fn_or_grad_acc'): + return False + + missing = [attr for attr in ("_current_graph_task_id", "_will_engine_execute_node") if not hasattr(torch._C, attr)] + + if missing: + return False + + return True + + +def count_used_parameters_in_backward(parameters: Sequence[torch.nn.Parameter]) -> int: + """ + Count the number of parameters that participate in the currently running backward graph. + + This helper is designed to be invoked from within a backward hook where a graph task + is active. Parameters that do not require gradients, are detached, or are not touched + by the current backward pass are ignored. + + torch.autograd.graph.register_multi_grad_hook is used for the purpose, but + its verification on tensor shapes throws an error with ZeRO3 (it expects original tensor shape). + So this function simplifies register_multi_grad_hook just to count used parameters. + + Args: + parameters: Iterable of model parameters to inspect. + + Returns: + The number of parameters whose gradient nodes will be executed by the autograd engine + for the active backward call. + """ + assert check_internal_apis_for_count_used_parameters(), ( + "count_used_parameters_in_backward requires internal PyTorch APIs that are not available " + "in this PyTorch build.") + + from torch.autograd.graph import _get_grad_fn_or_grad_acc + if torch._C._current_graph_task_id() == -1: + raise RuntimeError("count_used_parameters_in_backward must be called during backward execution") + + seen_nodes = set() + for param in parameters: + if not isinstance(param, torch.Tensor) or not param.requires_grad: + continue + + # Backward hooks run with grad mode disabled, but PyTorch <=2.4's + # _get_grad_fn_or_grad_acc() requires grad mode for leaf params. + with torch.enable_grad(): + grad_fn = _get_grad_fn_or_grad_acc(param) + if grad_fn is None: + continue + + if grad_fn in seen_nodes: + continue + + seen_nodes.add(grad_fn) + + if not seen_nodes: + return 0 + + participating = sum(map(torch._C._will_engine_execute_node, seen_nodes)) + return int(participating) diff --git a/deepspeed/runtime/zenflow/__init__.py b/deepspeed/runtime/zenflow/__init__.py new file mode 100644 index 000000000000..6f5f5619004b --- /dev/null +++ b/deepspeed/runtime/zenflow/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/runtime/zenflow/engine.py b/deepspeed/runtime/zenflow/engine.py new file mode 100644 index 000000000000..2236d097169b --- /dev/null +++ b/deepspeed/runtime/zenflow/engine.py @@ -0,0 +1,151 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed import comm as dist +from typing import TYPE_CHECKING +from deepspeed.utils.torch import required_torch_version + +if TYPE_CHECKING: + from deepspeed.runtime.engine import DeepSpeedEngine + + +def configure_zenflow(engine: "DeepSpeedEngine") -> None: + """Configure ZenFlow-related scheduling parameters on the engine. + + This function initializes ZenFlow flags (e.g., `zenflow`, `auto_update`, + `select_interval`, etc.) based on the `zenflow_config` object. It handles + selection/update strategy resolution and performs basic validation. + + Args: + engine (DeepSpeedEngine): The DeepSpeed engine to configure. + """ + zenflow_config = engine.zenflow_config() + if zenflow_config == None: + engine.zenflow = False + return + if not required_torch_version(min_version=2.1): + raise ValueError( + "Please use PyTorch 2.1 or later to enable ZenFlow. Alternatively, omit `zenflow` config in the config file to fall back to the default ZeRO-Offload optimizer." + ) + + engine.zenflow = True + select_strategy = zenflow_config.select_strategy + + if select_strategy == 'auto': + select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." + ) + engine.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + engine.select_interval = zenflow_config.select_interval + + if isinstance(zenflow_config.update_interval, str): + engine.auto_update = True + engine.update_interval = 0 + else: + engine.auto_update = False + engine.update_interval = int(zenflow_config.update_interval) + + if select_strategy == 'epoch': + if engine.training_dataloader is not None: + zenflow_config.steps_per_epoch = len(engine.training_dataloader) + engine.select_interval = engine.select_interval * len(engine.training_dataloader) + else: + engine.select_interval = 0 + + if not engine.auto_update and engine.select_interval != 0 and engine.select_interval < engine.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + engine.overlap_step = zenflow_config.overlap_step + + engine.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + + engine._config.gradient_accumulation_steps = engine.update_interval + + +def is_zenflow_update_boundary(engine: "DeepSpeedEngine"): + """Determine whether the current step is an update boundary for ZenFlow. + + This function checks whether the engine should trigger an optimizer update + based on gradient accumulation, warmup phase, and selection/update intervals. + + Returns: + bool: True if this step is an update boundary, otherwise False. + """ + if engine.auto_update: + if (engine.micro_steps + 1) <= engine.full_warm_up_rounds: + return True + return (engine.optimizer.zenflow_need_update[engine.optimizer.zenflow_state ^ 1] + or (engine.select_interval != 0 and (engine.micro_steps + 1) % engine.select_interval == 0)) + else: + if (engine.micro_steps + 1) < engine.full_warm_up_rounds: + return True + return ((engine.micro_steps + 1 - engine.full_warm_up_rounds) % engine.gradient_accumulation_steps() == 0 + or (engine.select_interval != 0 and (engine.micro_steps + 1) % engine.select_interval == 0)) + + +def zenflow_step(engine: "DeepSpeedEngine", lr_kwargs): + """Main step logic for ZenFlow update scheduling. + + This function performs either: + - a selective optimizer update (if at accumulation boundary), + - or just a learning rate scheduler step and logging (if at accumulation iteration). + + Args: + engine (DeepSpeedEngine): The engine managing training state. + lr_kwargs (dict): Optional kwargs passed to the LR scheduler step. + """ + if engine.is_gradient_accumulation_boundary(): + if engine.micro_steps + 1 >= engine.full_warm_up_rounds: + _take_selective_parameter_step(engine) + if engine.auto_update: + if dist.get_rank() == 0: + print(f"Zenflow: This is an update iter. update_interval: {engine.update_interval}") + engine.update_interval = 0 + else: + _take_lr_scheduler_step(engine, lr_kwargs) + _log_selective_optimizer_timers(engine) + + +def _take_selective_parameter_step(engine: "DeepSpeedEngine"): + """ + Trigger a step on the selective optimizer. + """ + engine.optimizer.selective_optimizer_step() + + +def _take_lr_scheduler_step(engine: "DeepSpeedEngine", lr_kwargs): + """ + Take a step on the learning rate scheduler. + """ + if engine.lr_scheduler is not None: + try: + engine.lr_scheduler.step(**(lr_kwargs or {})) + except TypeError: + # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines. + # We don't currently have a way to specify lr_kwargs from + # pipe_engine.train_batch() + engine.lr_scheduler.step(engine.train_batch_size()) + + +def _log_selective_optimizer_timers(engine): + """ + Log the selective optimizer timers. + """ + engine.optimizer.log_selective_optimizer_timers() + + +def sync_zenflow_optimizer_lr(engine: "DeepSpeedEngine"): + """ + Synchronize the learning rate of the selective optimizer. + If auto_update is enabled, increment the update interval. + """ + engine.optimizer._sync_selective_optimizer_lr() + if engine.auto_update: + engine.update_interval += 1 diff --git a/deepspeed/runtime/zenflow/engine_stage3.py b/deepspeed/runtime/zenflow/engine_stage3.py new file mode 100644 index 000000000000..0d95be9f2f8a --- /dev/null +++ b/deepspeed/runtime/zenflow/engine_stage3.py @@ -0,0 +1,643 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.zero.partition_parameters import * + +import torch +import math +from deepspeed import comm as dist +from deepspeed.utils import logger +from deepspeed.ops.adam import ZenFlowSelectiveAdamW_stage3 +from deepspeed.runtime.utils import see_memory_usage +from typing import List +from deepspeed.accelerator import get_accelerator +from typing import TYPE_CHECKING +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process + +if TYPE_CHECKING: + from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 + +OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state' +INIT_OPTIMIZER_TIMER = 'init_optimizer_state' +OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state' +OPTIMIZER_STEP_TIMER = 'optimizer_step' + + +def configure_zenflow(optimizer_z3, zenflow_config): + + optimizer_z3.select_strategy = zenflow_config.select_strategy + if optimizer_z3.select_strategy == 'auto': + optimizer_z3.select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." + ) + optimizer_z3.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + optimizer_z3.select_interval = int(zenflow_config.select_interval) + + if isinstance(zenflow_config.update_interval, str): + optimizer_z3.auto_update = True + optimizer_z3.update_interval = 0 + else: + optimizer_z3.auto_update = False + optimizer_z3.update_interval = int(zenflow_config.update_interval) + + if optimizer_z3.select_strategy == 'epoch': + if zenflow_config.steps_per_epoch is not None: + optimizer_z3.select_interval = optimizer_z3.select_interval * zenflow_config.steps_per_epoch + else: + optimizer_z3.select_interval = 0 + + if not optimizer_z3.auto_update and optimizer_z3.select_interval != 0 and optimizer_z3.select_interval < optimizer_z3.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + optimizer_z3.topk_ratio = zenflow_config.topk_ratio + + optimizer_z3.param_id_grad_sum_buffer_offset = {} + + optimizer_z3.zf_stage3 = True + + if optimizer_z3.auto_update: + optimizer_z3.param_id_sum_buffer_offset = {} + optimizer_z3.auto_ratio = zenflow_config.auto_ratio + optimizer_z3.zenflow_need_update = [False, False] + optimizer_z3.zenflow_state = 0 + optimizer_z3.num_need_update = 0 + + +def _initialize_zenflow_stage3_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", + module, + zenflow_config: dict = None): + + optimizer_z3.zenflow = True if zenflow_config is not None else False + + if not optimizer_z3.zenflow: + return + + optimizer_z3.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc + + for p in module.parameters(): + p.data = p.data.t().contiguous() if len(p.shape) != 1 else p.data + + +def _initialize_zenflow_stage3_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", + zenflow_config: dict = None, + overlap_comm: bool = False): + + if not optimizer_z3.zenflow: + return + + optimizer_z3.micro_step = -1 + optimizer_z3.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + optimizer_z3.offload_selective_optimizer = zenflow_config.offload + optimizer_z3.zenflow_overlap_step = zenflow_config.overlap_step + + if optimizer_z3.offload_selective_optimizer: + assert overlap_comm, "offload selective optimizer should be used with overlap_comm" + + if optimizer_z3.zenflow_overlap_step: + optimizer_z3.process_optimizer_established = False + optimizer_z3.first_update_round_after_warmup = True + optimizer_z3.initialize_optimizer_states = lambda: initialize_optimizer_states(optimizer_z3) + optimizer_z3.step = lambda closure=None: step(optimizer_z3, closure) + optimizer_z3.zenflow_cpu_optimizer_overlap_step = lambda now_state, scaled_global_grad_norm: zenflow_cpu_optimizer_overlap_step( + optimizer_z3, now_state, scaled_global_grad_norm) + optimizer_z3.wait_last_update_and_copy = lambda timer_names: wait_last_update_and_copy( + optimizer_z3, timer_names) + optimizer_z3.partition_grads = lambda params_to_release, grad_partitions: partition_grads( + optimizer_z3, params_to_release, grad_partitions) + optimizer_z3.get_overlap_step_state = lambda: get_overlap_step_state(optimizer_z3) + optimizer_z3.start_optimizer_process = lambda: start_optimizer_process(optimizer_z3) + optimizer_z3.unscale_and_clip_grads = lambda sub_group_id, total_norm, now_state: unscale_and_clip_grads( + optimizer_z3, sub_group_id, total_norm, now_state) + + configure_zenflow(optimizer_z3, zenflow_config) + optimizer_z3.selective_optimizer = ZenFlowSelectiveAdamW_stage3([{ + k: v + for k, v in group.items() if k != "params" + } | { + "params": group["params"] + } for group in optimizer_z3.optimizer.param_groups], + offload=optimizer_z3.offload_selective_optimizer) + optimizer_z3.num_total_param = sum( + sum(1 for param in group["params"] if len(param.ds_shape) != 1) + for group in optimizer_z3.optimizer.param_groups) + + +def zenflow_cpu_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + return optimizer_z3.optimizer.step(step_id=optimizer_z3.micro_step + 1) + + +def _sync_selective_optimizer_lr(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + for group_selected, group in zip(optimizer_z3.selective_optimizer.param_groups, + optimizer_z3.optimizer.param_groups): + group_selected["lr"] = group["lr"] + + +def selective_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + optimizer_z3.selective_optimizer.step() + + +def is_zenflow_select_boundary(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> bool: + return optimizer_z3.zenflow and (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) >= 0 and ( + (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) == 0 or + (optimizer_z3.select_interval != 0 and optimizer_z3.micro_step % optimizer_z3.select_interval == 0)) + + +def update_selected_channels(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", params_to_update, grad_partitions): + src_rk = dist.get_rank(optimizer_z3.dp_process_group) + total_rk = dist.get_world_size(optimizer_z3.dp_process_group) + + total_chunk_size = 0 + param_local_offset = [0 for _ in range(total_rk)] + + for param, grad_partition in zip(params_to_update, grad_partitions): + param_max_chunk_size = 0 + param_rk_offset = 0 + for rk in range(total_rk): + contains_real_data = param.partition_numel() * rk < param.ds_numel + if not contains_real_data: + param.grad = None + continue + + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + if num_row == 1: + continue + + partition_size = param.partition_numel() + start = partition_size * rk + end = min(start + partition_size, param.ds_numel) + + start_idx = math.ceil(start / num_row) + end_idx = end // num_row + num_cols = end_idx - start_idx + + if param.ds_id not in optimizer_z3.param_id_grad_sum_buffer_offset: + optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = [] + + optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id].append( + (param_local_offset[rk], num_cols, param_rk_offset)) + + param_max_chunk_size = max(param_max_chunk_size, num_cols) + param_rk_offset += num_cols + param_local_offset[rk] += num_cols + + total_chunk_size += param_max_chunk_size + + optimizer_z3.grad_sum_buffer = torch.zeros(total_chunk_size, dtype=optimizer_z3.dtype, device='cuda') + + for param, grad_partition in zip(params_to_update, grad_partitions): + contains_real_data = param.partition_numel() * src_rk < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + #ds_shape is the transposed shape, it should not be same as param.shape + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row == 1: + continue + + partition_size = param.partition_numel() + start = partition_size * src_rk + end = min(start + partition_size, param.ds_numel) + + start_idx = math.ceil(start / num_row) + end_idx = end // num_row + + num_elements = (end_idx - start_idx) * num_row + + param.complete_column_offset = start_idx * num_row - start + param.complete_numel = (end_idx - start_idx) * num_row + + sum_per_column = grad_partition.narrow(0, param.complete_column_offset, num_elements) + sum_per_column = sum_per_column.view(end_idx - start_idx, num_row) + sum_array = sum_per_column.abs().sum(dim=1) + + offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk] + optimizer_z3.grad_sum_buffer.narrow(0, offset, length).copy_(sum_array) + + gathered_chunks = [torch.zeros_like(optimizer_z3.grad_sum_buffer) for _ in range(total_rk)] + dist.all_gather(gathered_chunks, optimizer_z3.grad_sum_buffer, group=optimizer_z3.dp_process_group) + + for param in params_to_update: + + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row == 1: + continue + + param_column_sum = [] + for rk in range(total_rk): + offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][rk] + param_column_sum.append(gathered_chunks[rk].narrow(0, offset, length)) + global_param_column_sum = torch.cat(param_column_sum, dim=0) + + num_select = max(1, int(global_param_column_sum.numel() * optimizer_z3.topk_ratio)) + _, global_topk_indices = torch.topk(global_param_column_sum, num_select, largest=True) + + _, length, rk_offset = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk] + local_indices = [(idx.item() - rk_offset) for idx in global_topk_indices + if rk_offset <= idx < rk_offset + length] + param.selected_indices = torch.tensor(local_indices, device='cuda') + optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = [] + + optimizer_z3.grad_sum_buffer = None + + +def _process_selected_fp32_groups_grad(optimizer_z3, params_to_update, grad_partitions): + + if optimizer_z3.auto_update: + optimizer_z3.sum_buffer = torch.zeros(optimizer_z3.num_total_param, dtype=optimizer_z3.dtype, device='cuda') + optimizer_z3.critic_sum_buffer = torch.zeros(optimizer_z3.num_total_param, + dtype=optimizer_z3.dtype, + device='cuda') + curr_buffer_idx = 0 + + for param, grad_partition in zip(params_to_update, grad_partitions): + + rk = dist.get_rank(optimizer_z3.dp_process_group) + + contains_real_data = param.partition_numel() * rk < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + #ds_shape is the transposed shape, it should not be same as param.shape + num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1) + + if num_row == 1: + param.selected_grad = grad_partition.clone().detach() + else: + grad_2d = grad_partition.narrow(0, param.complete_column_offset, + param.complete_numel).view(param.complete_numel // num_row, num_row) + if param.selected_indices.numel() == 0: + param.selected_indices = param.selected_indices.long() + param.selected_grad = grad_2d[param.selected_indices, :].clone().detach() + + if optimizer_z3.auto_update: + optimizer_z3.sum_buffer[curr_buffer_idx] = grad_partition.abs().sum() + optimizer_z3.critic_sum_buffer[curr_buffer_idx] = param.selected_grad.abs().sum() + curr_buffer_idx += 1 + + if optimizer_z3.offload_selective_optimizer and not hasattr(param, 'exp_avg_cpu_data'): + buffer = torch.zeros(param.selected_grad.numel(), dtype=param.dtype, device=optimizer_z3.device) + param.exp_avg_cpu_data = get_accelerator().pin_memory( + buffer) if optimizer_z3.offload_optimizer_pin_memory else buffer + param.exp_avg_sq_cpu_data = get_accelerator().pin_memory( + buffer.clone()) if optimizer_z3.offload_optimizer_pin_memory else buffer.clone() + + if optimizer_z3.auto_update: + total_rk = dist.get_world_size(optimizer_z3.dp_process_group) + sum_gather_list = [torch.zeros_like(optimizer_z3.sum_buffer) for _ in range(total_rk)] + critic_gather_list = [torch.zeros_like(optimizer_z3.critic_sum_buffer) for _ in range(total_rk)] + curr_buffer_idx = 0 + + dist.all_gather(sum_gather_list, optimizer_z3.sum_buffer, group=optimizer_z3.dp_process_group) + dist.all_gather(critic_gather_list, optimizer_z3.critic_sum_buffer, group=optimizer_z3.dp_process_group) + + for param in params_to_update: + if len(param.ds_shape) == 1: + continue + + if not hasattr(param, 'non_critic_sum'): + param.non_critic_sum = 0 + if not hasattr(param, 'avg_critic_sum'): + param.avg_critic_sum = 0 + + grad_total_sum = sum(sum_gather_list[rk][curr_buffer_idx] for rk in range(total_rk)) + grad_critic_sum = sum(critic_gather_list[rk][curr_buffer_idx] for rk in range(total_rk)) + + param.avg_critic_sum = (param.avg_critic_sum * (optimizer_z3.update_interval - 1) + + grad_critic_sum) / optimizer_z3.update_interval / (optimizer_z3.topk_ratio * 10) + param.non_critic_sum += (grad_total_sum - grad_critic_sum) / ((1 - optimizer_z3.topk_ratio) * 10) + if param.non_critic_sum >= param.avg_critic_sum: + optimizer_z3.num_need_update += 1 + if optimizer_z3.num_need_update >= int(optimizer_z3.auto_ratio * optimizer_z3.num_total_param): + optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = True + + curr_buffer_idx += 1 + + if not optimizer_z3.is_gradient_accumulation_boundary: + optimizer_z3.selective_optimizer.group_step(params_to_update) + else: + optimizer_z3.selective_optimizer.temp_copy_param(params_to_update) + + if optimizer_z3.auto_update: + optimizer_z3.sum_buffer = None + optimizer_z3.critic_sum_buffer = None + + +def sync_fp32_param_from_gpu(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + + if optimizer_z3.micro_step == 0: + return + + for fp16_partitions, fp32_partition in zip(optimizer_z3.fp16_partitioned_groups_flat, + optimizer_z3.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + +def zenflow_backward_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + optimizer_z3.micro_step += 1 + if optimizer_z3.auto_update: + optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = False + optimizer_z3.num_need_update = 0 + if optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state ^ 1]: + optimizer_z3.update_interval = 0 + for group in optimizer_z3.fp16_groups: + for p in group: + p.non_critic_sum = 0 + optimizer_z3.update_interval += 1 + if optimizer_z3.is_zenflow_select_boundary(): + sync_fp32_param_from_gpu(optimizer_z3) + optimizer_z3.selective_optimizer.clear_selected_mv() + + +def zenflow_backward_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + optimizer_z3._partition_all_parameters() + + +def log_selective_optimizer_timers(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + pass + + +def initialize_optimizer_states(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"): + num_subgroups = len(optimizer_z3.fp16_groups) + + largest_numel = max([sum([p.ds_numel for p in psg]) for psg in optimizer_z3.fp16_partitioned_groups]) + gradient_dtype = optimizer_z3.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=optimizer_z3.device) + + timer_names = set() + + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. + is_adagrad = isinstance(optimizer_z3.optimizer, torch.optim.Adagrad) + + if optimizer_z3.swap_optimizer: + optimizer_z3.optimizer_swapper.init_timers() + + timer_names.add(INIT_OPTIMIZER_TIMER) + optimizer_z3.timers(INIT_OPTIMIZER_TIMER).start() + + for i, group in enumerate(optimizer_z3.fp16_groups): + swappable_optimizer_subgroup = optimizer_z3._swappable_optimizer_subgroup(i) + swappable_param_subgroup = optimizer_z3.fp16_partitioned_groups_flat[i] is None + + num_elements = int(optimizer_z3.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + optimizer_z3._optimizer_states_and_gradient_swap_in(i, timer_names) + + if optimizer_z3.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=optimizer_z3.device) + if optimizer_z3.offload_optimizer_pin_memory: + subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer) + + optimizer_z3.fp32_partitioned_groups_flat[i].grad = None + optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad = [ + subgroup_gradient_buffer.to(optimizer_z3.subgroup_to_device[i]), + subgroup_gradient_buffer.clone().to(optimizer_z3.subgroup_to_device[i]) + ] + else: + optimizer_z3.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) + + if swappable_param_subgroup: + optimizer_z3._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + optimizer_z3._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + # Initialize the optimizer states with the flattened fp32 partition. + if is_adagrad: + optimizer_z3.optimizer = torch.optim.Adagrad(optimizer_z3.fp32_partitioned_groups_flat, + **optimizer_z3.optimizer.defaults) + + optimizer_z3.timers(INIT_OPTIMIZER_TIMER).stop() + optimizer_z3.timers.log(timer_names) + + if optimizer_z3.swap_optimizer: + optimizer_z3.optimizer_swapper.log_timers() + + if not optimizer_z3.offload_optimizer: + for group in optimizer_z3.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + +def get_overlap_step_state(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> int: + if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds: + return optimizer_z3.micro_step & 1 + else: + if not optimizer_z3.auto_update: + return (optimizer_z3.micro_step // optimizer_z3.update_interval) & 1 + else: + return optimizer_z3.zenflow_state + + +@instrument_w_nvtx +def partition_grads(optimizer_z3, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + offload_fp32_gradients = {} + offload_fp32_offsets = {} + buffers = [] + for param, grad_partition in zip(params_to_release, grad_partitions): + + contains_real_data = param.partition_numel() * dist.get_rank(optimizer_z3.dp_process_group) < param.ds_numel + if not contains_real_data: + # this grad partition is empty - don't need to do anything + param.grad = None + continue + + # move or accumulate gradient partition to target buffer + param_id_to_grad_partition = getattr(optimizer_z3, + f"_{optimizer_z3.__class__.__name__}__param_id_to_grad_partition") + grad_buffer = param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel()) + buffers.append(grad_buffer) + if optimizer_z3.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif get_accelerator().on_accelerator(grad_buffer): + grad_buffer.add_(grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(grad_buffer.shape)) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + cuda_grad_buffer.add_( + grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(cuda_grad_buffer.shape)) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + # offload the gradient partition if applicable + if optimizer_z3.offload_optimizer: + i, dest_offset, _ = optimizer_z3.grad_position[optimizer_z3.get_param_id(param)] + now_state = optimizer_z3.get_overlap_step_state() + + if optimizer_z3.is_gradient_accumulation_boundary: + optimizer_z3.norm_for_param_grads[optimizer_z3.get_param_id( + param)] = optimizer_z3._constant_buffered_norm2(grad_buffer) + + if optimizer_z3._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad[now_state].narrow( + 0, dest_offset, grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer.float()) + + # free the gradient + if not get_accelerator().is_synchronized_device(): + if param.grad is not None: + param.grad.record_stream(get_accelerator().current_stream()) + param.grad = None + + if optimizer_z3.offload_optimizer and optimizer_z3.swap_optimizer: + for i in offload_fp32_gradients.keys(): + optimizer_z3.optimizer_swapper.swap_out_gradients(parameter=optimizer_z3.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + return buffers + + +@instrument_w_nvtx +def unscale_and_clip_grads(self, sub_group_id, total_norm, now_state): + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale + + self.fp32_partitioned_groups_flat[sub_group_id].overlap_grad[now_state].mul_(1. / combined_scale) + + +def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_grad_norm): + + if not optimizer_z3.process_optimizer_established: + optimizer_z3.start_optimizer_process() + + group_infos = [] + for group_no, group in enumerate(optimizer_z3.fp16_groups): + optimizer_z3.unscale_and_clip_grads(group_no, scaled_global_grad_norm, now_state) + param_group_id = optimizer_z3.sub_group_to_group_id[group_no] + + group_info = { + "lr": optimizer_z3.optimizer.param_groups[param_group_id]["lr"], + "betas": optimizer_z3.optimizer.param_groups[param_group_id]["betas"], + "eps": optimizer_z3.optimizer.param_groups[param_group_id]["eps"], + "weight_decay": optimizer_z3.optimizer.param_groups[param_group_id]["weight_decay"], + "bias_correction": optimizer_z3.optimizer.param_groups[param_group_id]["bias_correction"], + } + + group_infos.append(group_info) + + optimizer_z3.parent_conn.send({ + "type": "step", + "now_state": now_state, + "micro_step": optimizer_z3.micro_step, + "group_infos": group_infos + }) + + +def wait_last_update_and_copy(optimizer_z3, timer_names): + + if not hasattr(optimizer_z3, 'parent_conn'): + return + + if optimizer_z3.micro_step + 1 > optimizer_z3.full_warm_up_rounds and optimizer_z3.first_update_round_after_warmup: + optimizer_z3.first_update_round_after_warmup = False + return + + msg = optimizer_z3.parent_conn.recv() + assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + + for sub_group_id, group in enumerate(optimizer_z3.fp16_groups): + if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None: + optimizer_z3.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + optimizer_z3.fp32_partitioned_groups_flat[sub_group_id].stale_param.data) + + #unflatten fp16 parameter subgroup + optimizer_z3._unflatten_partitioned_parameters(sub_group_id) + else: + optimizer_z3._partitioned_params_swap_out(sub_group_id) + + optimizer_z3._post_step(timer_names) + + # warn user about caching allocator flushes + memory_stats = get_accelerator().memory_stats() + alloc_retries = memory_stats.get("num_alloc_retries") + if alloc_retries is None: + alloc_retries = 0 + if alloc_retries > optimizer_z3.n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "get_accelerator().empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - optimizer_z3.n_caching_allocator_flushes) + optimizer_z3.n_caching_allocator_flushes = alloc_retries + + +@instrument_w_nvtx +def step(optimizer_z3, closure=None): + """ + Not supporting closure. + """ + optimizer_z3._pre_step() + optimizer_z3._partition_all_parameters() + + #checks for overflow, adjust the loss scale accordingly + if optimizer_z3._overflow_check_and_loss_scale_update(): + if optimizer_z3.swap_optimizer: + optimizer_z3.optimizer_swapper.log_timers() + return + + norm_groups = optimizer_z3._get_norm_groups() + scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups)) + + # Stash unscaled gradient norm + optimizer_z3._global_grad_norm = scaled_global_grad_norm / optimizer_z3.loss_scale + + if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds: + optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm) + + timer_names = set() + + timer_names.add(OPTIMIZER_STEP_TIMER) + + optimizer_z3.wait_last_update_and_copy(timer_names) + + if optimizer_z3.micro_step >= optimizer_z3.full_warm_up_rounds: + optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm) + + return diff --git a/deepspeed/runtime/zenflow/zenflow_config.py b/deepspeed/runtime/zenflow/zenflow_config.py new file mode 100644 index 000000000000..9482961d642d --- /dev/null +++ b/deepspeed/runtime/zenflow/zenflow_config.py @@ -0,0 +1,69 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from pydantic import Field, model_validator +from typing import Optional, Union + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + + +class ZenFlowConfig(DeepSpeedConfigModel): + """Configuration options for ZenFlow optimization module.""" + + topk_ratio: float = Field(0.1, ge=0.0, le=1.0) + """Ratio of top-k important gradient columns to retain (range: 0.0 to 1.0).""" + + select_strategy: str = "auto" + """Strategy for selecting important gradient indices. + Options: "auto", "step", or "epoch".""" + + select_interval: Union[str, int] = "auto" + """Interval at which to reselect important gradient indices. + Can be "auto" or a fixed integer step/epoch interval.""" + + update_interval: Union[str, int] = "auto" + """Interval for applying accumulated unimportant gradients to model parameters. + Can be "auto" or a fixed integer step interval.""" + + overlap_step: bool = False + """Whether to overlap CPU-side optimizer steps with forward/backward computation.""" + + offload: bool = False + """Whether to offload selective optimizer states to CPU to save memory.""" + + auto_ratio: float = Field(0.99, ge=0.0, le=1.0) + """Threshold used in the "auto" strategy to determine update_interval.""" + + full_warm_up_rounds: int = 0 + """Number of initial rounds during which all gradients are fully updated (no selection).""" + + pt_reserved_cores_perc: float = Field(0.5, ge=0.0, le=1.0) + """Number of cores reserved for pytorch threads, + the remaining cores will be used by zenflow optimizer workers""" + + steps_per_epoch: Optional[int] = Field( + default=None, + description= + "Number of steps per epoch. This field is initialized during execution and should not be set by users.", + exclude=True) + + @model_validator(mode="after") + def validate_fields(self): + if self.select_strategy not in ["auto", "step", "epoch"]: + raise ValueError('select_strategy must be one of "auto", "step", or "epoch"') + + if isinstance(self.select_interval, str) and self.select_interval != "auto": + raise ValueError('If select_interval is a string, it must be "auto"') + + if isinstance(self.update_interval, str) and self.update_interval != "auto": + raise ValueError('If update_interval is a string, it must be "auto"') + + if not isinstance(self.full_warm_up_rounds, int): + raise ValueError('full_warm_up_rounds must be an integer') + + if not isinstance(self.pt_reserved_cores_perc, float): + raise ValueError('pt_reserved_cores_perc must be a float') + + return self diff --git a/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py new file mode 100644 index 000000000000..2f5e423f1320 --- /dev/null +++ b/deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py @@ -0,0 +1,793 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed import comm as dist + +from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer +from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process +from deepspeed.runtime.utils import (see_memory_usage) +from deepspeed.ops.adam import ZenFlowSelectiveAdamW + +from deepspeed.moe.utils import is_moe_param + +from deepspeed.accelerator import get_accelerator + +from deepspeed.runtime.utils import all_gather_dp_groups + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +OPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather' +OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients' +OPTIMIZER_STEP_TIMER = 'optimizer_step' +OPTIMIZER_TRANSMIT_TIMER = 'optimizer_transmit_time' +OPTIMIZER_CALC_TIMER = 'optimizer_calc_time' +OPTIMIZER_RECV_PARAMS_TIMER = 'optimizer_receive_params_time' +OPTIMIZER_UPDATE_MODEL_TIMER = 'optimizer_update_model_time' +OPTIMIZER_TIMERS = [ + OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER, OPTIMIZER_TRANSMIT_TIMER, + OPTIMIZER_CALC_TIMER, OPTIMIZER_RECV_PARAMS_TIMER, OPTIMIZER_UPDATE_MODEL_TIMER +] +INITIAL_MICRO_STEP_ID = -1 + +SELECTIVE_OPTIMIZER_UPDATE_TIMER = 'selective_optimizer_update' +SELECTIVE_OPTIMIZER_PROCESS_TIMER = 'selective_optimizer_process' +SELECTIVE_OPTIMIZER_STEP_TIMER = 'selective_optimizer_step' +SELECTIVE_OPTIMIZER_SYNC_TIMER = 'selective_optimizer_sync' +SELECTIVE_OPTIMIZER_TIMERS = [ + SELECTIVE_OPTIMIZER_UPDATE_TIMER, SELECTIVE_OPTIMIZER_PROCESS_TIMER, SELECTIVE_OPTIMIZER_STEP_TIMER, + SELECTIVE_OPTIMIZER_SYNC_TIMER +] + + +class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer): + + def __init__( + self, + init_optimizer, + param_names, + timers, + optimizer_params, + **kwargs, + ): + + super().__init__(init_optimizer, param_names, timers, optimizer_params, **kwargs) + + zenflow_config = kwargs.get("zenflow_config", None) + + self.micro_step = -1 + self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds + self.offload_selective_optimizer = zenflow_config.offload + self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc + self.start_optimizer_process = lambda: start_optimizer_process(self) + self.zf_stage3 = False + + if self.offload_selective_optimizer: + assert kwargs.get("overlap_comm", False), "offload selective optimizer should be used with overlap_comm" + + self._configure_zenflow(zenflow_config) + + + self.selective_optimizer = ZenFlowSelectiveAdamW([{"params": group} for group in self.bit16_groups], \ + offload=zenflow_config.offload, + bucket_size=self.allgather_bucket_size, + **optimizer_params) + self.num_total_param = sum(sum(1 for param in group if len(param.shape) != 1) for group in self.bit16_groups) + + @classmethod + def create(cls, zenflow_config): + if zenflow_config.overlap_step: + return ZenFlowZeroOptimizerParallel + else: + return ZenFlowZeroOptimizerSequential + + def _configure_zenflow(self, zenflow_config): + """ + Configure ZenFlow optimizer + """ + if not self.cpu_offload: + raise ValueError("Zenflow must be used with cpu offload") + + self.select_strategy = zenflow_config.select_strategy + if self.select_strategy == 'auto': + self.select_strategy = "epoch" + if isinstance(zenflow_config.select_interval, int): + raise Warning( + "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten." + ) + self.select_interval = 1 + else: + if isinstance(zenflow_config.select_interval, str): + raise ValueError("If don't use auto select strategy, select_interval must be a number.") + self.select_interval = int(zenflow_config.select_interval) + + if isinstance(zenflow_config.update_interval, str): + self.auto_update = True + self.update_interval = 0 + else: + self.auto_update = False + self.update_interval = int(zenflow_config.update_interval) + + if self.select_strategy == 'epoch': + if zenflow_config.steps_per_epoch is not None: + self.select_interval = self.select_interval * zenflow_config.steps_per_epoch + else: + self.select_interval = 0 + + if not self.auto_update and self.select_interval != 0 and self.select_interval < self.update_interval: + raise ValueError("Select interval must be greater or equal to update interval") + + self.topk_ratio = zenflow_config.topk_ratio + + self.param_id_index_buffer_offset = {} + self.param_id_grad_buffer_offset = {} + + if self.auto_update: + self.param_id_sum_buffer_offset = {} + self.auto_ratio = zenflow_config.auto_ratio + self.zenflow_need_update = [False, False] + self.zenflow_state = 0 + self.num_need_update = 0 + + def is_zenflow_select_boundary(self): + return self.zenflow and (self.micro_step - self.full_warm_up_rounds) >= 0 and ( + (self.micro_step - self.full_warm_up_rounds) == 0 or + (self.select_interval != 0 and self.micro_step % self.select_interval == 0)) + + def sync_fp32_param_from_gpu(self): + if self.micro_step == 0: + return + + for i, group in enumerate(self.bit16_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + + bit16_partitions = self.parallel_partitioned_bit16_groups[i] + fp32_partition = self.single_partition_of_fp32_groups[i] + + with torch.no_grad(): + fp32_partition.copy_(bit16_partitions[partition_id].to(dtype=fp32_partition.dtype, + device=fp32_partition.device)) + + def update_selected_channels(self, tensor, total_size, communication_data_type): + curr_size = 0 + curr_index_buffer_size = 0 + rank_and_offsets = [] + prev_id, prev_process_group = -1, None + + process_group = self.dp_process_group + rank = dist.get_rank(process_group) + + self.index_buffer = torch.empty(total_size, dtype=torch.int32, device=get_accelerator().current_device_name()) + + bucket = self.ipg_buckets[communication_data_type] + for i, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[i][param_idx_in_group] + + if len(param.shape) == 1: + continue + + if not hasattr(param, 'selected_indices'): + param.selected_indices = None + + partition_ids = self.param_to_partition_ids[i][param_id] + + # Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + if idx == len(partition_ids_w_offsets) - 1: + numel = param.numel() - offset + else: + numel = partition_ids_w_offsets[idx + 1][1] - offset + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + start_column = 0 if not offset else int((offset - 1) / num_row) + 1 + end_column = int((offset + numel) / num_row) + num_select = int(self.topk_ratio * (end_column - start_column)) + + if partition_id == rank: + + start_idx = int(curr_size + start_column * num_row - offset) + num_elements = (end_column - start_column) * num_row + sum_per_column = tensor.narrow(0, start_idx, num_elements) + sum_per_column = sum_per_column.view(end_column - start_column, num_row) + sum_array = sum_per_column.abs().sum(dim=1) + + _, top_indices = torch.topk(sum_array, num_select) + top_indices += start_column + self.index_buffer.narrow(0, curr_index_buffer_size, num_select).copy_(top_indices) + + if partition_id == prev_id and process_group == prev_process_group: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + num_select) + else: + rank_and_offsets.append((partition_id, curr_index_buffer_size, num_select)) + + if param_id not in self.param_id_index_buffer_offset: + self.param_id_index_buffer_offset[param_id] = [] + self.param_id_index_buffer_offset[param_id].append((curr_index_buffer_size, num_select)) + + curr_size += numel + curr_index_buffer_size += num_select + + for src_rank, offset, num_select in rank_and_offsets: + index_slice = self.index_buffer.narrow(0, offset, num_select) + dist.broadcast(index_slice, src=src_rank, group=process_group) + + for i, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[i][param_idx_in_group] + + if len(param.shape) == 1: + continue + + param.selected_indices = None + param.partition_selected_indices = [] + + for offset, num_select in self.param_id_index_buffer_offset[param_id]: + selected = self.index_buffer.narrow(0, offset, num_select).clone().sort()[0] + if param.selected_indices is None: + param.selected_indices = selected + else: + param.selected_indices = torch.cat([param.selected_indices, selected]) + param.partition_selected_indices.append(selected) + + self.param_id_index_buffer_offset[param_id] = [] + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + param.selected_indices.sort() + param.selected_shape = (param.selected_indices.shape[0], + num_row) if num_row != 1 else (param.selected_indices.shape[0], ) + + self.index_buffer = None + + def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_data_type): + """ + Process gradients for selected columns in FP32 groups + + Args: + param: The parameter to process + param_id: ID of the parameter + """ + + curr_size = 0 + curr_grad_buffer_size = 0 + curr_sum_buffer_size = 0 + rank_and_offsets = [] + prev_id, prev_process_group = -1, None + + process_group = self.dp_process_group + rank = dist.get_rank(process_group) + + self.grad_buffer = torch.empty(total_size, dtype=self.dtype, device=get_accelerator().current_device_name()) + + bucket = self.ipg_buckets[communication_data_type] + if self.auto_update: + self.sum_buffer = torch.empty(len(bucket.params) + dist.get_world_size(group=process_group), + dtype=torch.bfloat16, + device=get_accelerator().current_device_name()) + + group_to_paramlist = {} + + for i, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[i][param_idx_in_group] + + if not hasattr(param, 'selected_indices'): + param.selected_indices = None + + partition_ids = self.param_to_partition_ids[i][param_id] + + # Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + if idx == len(partition_ids_w_offsets) - 1: + numel = param.numel() - offset + else: + numel = partition_ids_w_offsets[idx + 1][1] - offset + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + start_column = 0 if not offset else int((offset - 1) / num_row) + 1 + end_column = int((offset + numel) / num_row) + num_select = int(self.topk_ratio * (end_column - start_column)) if len(param.shape) == 2 else numel + grad_size = num_select * num_row + + if partition_id == rank: + selected_grad = param.grad[ + param.partition_selected_indices[idx], :] if num_row != 1 else param.grad[offset:offset + + numel] + self.grad_buffer.narrow(0, curr_grad_buffer_size, grad_size).copy_(selected_grad.view(-1)) + + if self.auto_update: + self.sum_buffer[curr_sum_buffer_size] = tensor.narrow(0, int(curr_size), + int(numel)).abs().sum() + + if partition_id == prev_id and process_group == prev_process_group: + if self.auto_update: + prev_pid, prev_size, prev_numel, prev_sum_size, prev_sum_num = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + grad_size, prev_sum_size, + prev_sum_num + 1) + else: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + grad_size) + else: + if self.auto_update: + rank_and_offsets.append( + (partition_id, curr_grad_buffer_size, grad_size, curr_sum_buffer_size, 1)) + else: + rank_and_offsets.append((partition_id, curr_grad_buffer_size, grad_size)) + + if param_id not in self.param_id_grad_buffer_offset: + self.param_id_grad_buffer_offset[param_id] = [] + if self.auto_update and param_id not in self.param_id_sum_buffer_offset: + self.param_id_sum_buffer_offset[param_id] = [] + self.param_id_grad_buffer_offset[param_id].append((curr_grad_buffer_size, grad_size)) + if self.auto_update: + self.param_id_sum_buffer_offset[param_id].append(curr_sum_buffer_size) + + curr_size += numel + curr_grad_buffer_size += grad_size + curr_sum_buffer_size += 1 + + for item in rank_and_offsets: + if self.auto_update: + src_rank, offset, grad_size, sum_offset, sum_num = item + else: + src_rank, offset, grad_size = item + + grad_slice = self.grad_buffer.narrow(0, offset, grad_size) + dist.broadcast(grad_slice, src=src_rank, group=process_group) + + if self.auto_update: + sum_slice = self.sum_buffer.narrow(0, sum_offset, sum_num) + dist.broadcast(sum_slice, src=src_rank, group=process_group) + + for i, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[i][param_idx_in_group] + + selected_grad = None + for offset, grad_size in self.param_id_grad_buffer_offset[param_id]: + selected_grad_buffer = self.grad_buffer.narrow(0, offset, grad_size).clone().detach() + if selected_grad is None: + selected_grad = selected_grad_buffer + else: + selected_grad = torch.cat([selected_grad, selected_grad_buffer]) + param.selected_grad = selected_grad.view(param.selected_shape).t() if len( + param.shape) != 1 else selected_grad + + if self.offload_selective_optimizer and not hasattr(param, 'exp_avg_cpu_data'): + buffer = torch.zeros(param.selected_grad.numel(), dtype=param.dtype, device=self.device) + param.exp_avg_cpu_data = get_accelerator().pin_memory( + buffer) if self.cpu_offload_pin_memory else buffer + param.exp_avg_sq_cpu_data = get_accelerator().pin_memory( + buffer.clone()) if self.cpu_offload_pin_memory else buffer.clone() + + param_list = group_to_paramlist.setdefault(i, []) + param_list.append(param) + + self.param_id_grad_buffer_offset[param_id] = [] + + if self.auto_update: + grad_total_sum = 0 + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + if num_row == 1: + continue + + for offset in self.param_id_sum_buffer_offset[param_id]: + grad_total_sum += self.sum_buffer.narrow(0, offset, 1) + + grad_critic_sum = param.selected_grad.abs().sum() + + if not hasattr(param, 'non_critic_sum'): + param.non_critic_sum = 0 + if not hasattr(param, 'avg_critic_sum'): + param.avg_critic_sum = 0 + + param.avg_critic_sum = (param.avg_critic_sum * (self.update_interval - 1) + + grad_critic_sum) / self.update_interval / (self.topk_ratio * 10) + param.non_critic_sum += (grad_total_sum - grad_critic_sum) / ((1 - self.topk_ratio) * 10) + if param.non_critic_sum >= param.avg_critic_sum: + self.num_need_update += 1 + + if self.num_need_update >= int(self.auto_ratio * self.num_total_param): + self.zenflow_need_update[self.zenflow_state] = True + + self.param_id_sum_buffer_offset[param_id] = [] + + if not self.is_gradient_accumulation_boundary: + self.selective_optimizer.group_step(group_to_paramlist) + else: + self.selective_optimizer.temp_copy_param(group_to_paramlist) + + self.grad_buffer = None + if self.auto_update: + self.sum_buffer = None + + def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype): + if self.overlap_comm: + stream = self.reduction_stream + if not get_accelerator().resolves_data_dependency(): + stream.wait_stream(get_accelerator().current_stream()) + get_accelerator().current_stream().wait_stream(stream) + else: + stream = get_accelerator().current_stream() + + with get_accelerator().stream(stream): + if not self.reduce_scatter: + self.gradient_reduction_w_predivide(tensor) + return + + # Accumulate destination ranks and bucket offsets for each gradient slice. + # Note: potential future optimization, record access pattern of parameters + # in backward pass and partition gradients w.r.t. access pattern so that our + # bucket is guaranteed to be contiguous w.r.t. ranks + rank_and_offsets = [] + real_dp_process_group = [] + curr_size = 0 + prev_id, prev_process_group = -1, None + + curr_column_size = 0 + curr_selected_reduce_size = 0 + + process_group = self.dp_process_group + bucket = self.ipg_buckets[communication_data_type] + for i, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[i][param_idx_in_group] + + process_group = self.dp_process_group + + if bucket.has_moe_params: + process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( + param) else self.dp_process_group + + partition_ids = self.param_to_partition_ids[i][param_id] + assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids + ]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}" + partition_size = self.partition_size[i] + # Get all partition ids + their offsets + partition_ids_w_offsets = [] + for partition_id in partition_ids: + offset = self.grad_start_offset[i][partition_id][param_id] + partition_ids_w_offsets.append((partition_id, offset)) + partition_ids_w_offsets.sort(key=lambda t: t[1]) + + num_row, num_col = param.shape if len(param.shape) == 2 else (1, param.shape[0]) + curr_column_size += int(num_col * self.topk_ratio) if num_row != 1 else 0 + + # Calculate rank and offsets for grad slices + for idx in range(len(partition_ids_w_offsets)): + partition_id, offset = partition_ids_w_offsets[idx] + + # Calculate numel for grad slice depending on partition location + if idx == len(partition_ids_w_offsets) - 1: + # Last partition_id uses its own offset + numel = param.numel() - offset + else: + # Set numel to next partition's offset + numel = partition_ids_w_offsets[idx + 1][1] - offset + + # Merge bucket ranges if they belong to the same rank + if partition_id == prev_id and process_group == prev_process_group: + prev_pid, prev_size, prev_numel = rank_and_offsets[-1] + rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel) + else: + rank_and_offsets.append((partition_id, curr_size, numel)) + real_dp_process_group.append(process_group) + curr_size += numel + curr_selected_reduce_size += int(numel * self.topk_ratio) if num_row != 1 else numel + + prev_id, prev_process_group = partition_id, process_group + + tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) + + buckets = {} + for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets): + grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) + bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else ( + dst, real_dp_process_group[i]) + if bucket_key not in buckets: + buckets[bucket_key] = [] + if self.use_multi_rank_bucket_allreduce: + buckets[bucket_key].append((dst, grad_slice)) + else: + buckets[bucket_key].append(grad_slice) + + for bucket_key in buckets: + if self.use_multi_rank_bucket_allreduce: + self.allreduce_and_scatter(buckets[bucket_key], + communication_data_type, + numel_per_bucket=self.reduce_bucket_size, + divide=False, + process_group=bucket_key) + else: + dst, process_group = bucket_key + self.allreduce_no_retain(buckets[bucket_key], + communication_data_type, + numel_per_bucket=self.reduce_bucket_size, + rank=dst, + divide=False, + process_group=process_group) + + if self.is_zenflow_select_boundary(): + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start() + self.update_selected_channels(tensor, curr_column_size, communication_data_type) + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop() + elif self.zenflow: + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start() + self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop() + + if self.zenflow and self.micro_step >= self.full_warm_up_rounds: + self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).start() + self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type) + self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop() + + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + self.backward_prologue() + self.micro_step += 1 + + if self.auto_update: + self.zenflow_need_update[self.zenflow_state] = False + self.num_need_update = 0 + if self.zenflow_need_update[self.zenflow_state ^ 1]: + self.update_interval = 0 + for group in self.bit16_groups: + for p in group: + p.non_critic_sum = 0 + self.update_interval += 1 + + if self.is_zenflow_select_boundary(): + self.timers(SELECTIVE_OPTIMIZER_SYNC_TIMER).start() + self.sync_fp32_param_from_gpu() + self.selective_optimizer.clear_selected_mv() + self.timers(SELECTIVE_OPTIMIZER_SYNC_TIMER).stop() + + self.enter_backward() + if self.custom_loss_scaler: + scaled_loss = self.external_loss_scale * loss + scaled_loss.backward(retain_graph=retain_graph) + else: + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self.backward_epilogue() + self.exit_backward() + + def log_selective_optimizer_timers(self): + self.timers.log(SELECTIVE_OPTIMIZER_TIMERS) + + def _sync_selective_optimizer_lr(self): + for group_selected, group in zip(self.selective_optimizer.param_groups, self.optimizer.param_groups): + group_selected["lr"] = group["lr"] + + def _selective_optimizer_step(self, group_no): + original_param_groups = self.selective_optimizer.param_groups + self.selective_optimizer.param_groups = [original_param_groups[group_no]] + self.selective_optimizer.step() + self.selective_optimizer.param_groups = original_param_groups + + def selective_optimizer_step(self, closure=None): + for i, group in enumerate(self.bit16_groups): + self.timers(SELECTIVE_OPTIMIZER_STEP_TIMER).start() + self._selective_optimizer_step(i) + self.timers(SELECTIVE_OPTIMIZER_STEP_TIMER).stop() + + self.timers.log(SELECTIVE_OPTIMIZER_TIMERS) + + +class ZenFlowZeroOptimizerSequential(ZenFlowZeroOptimizer): + + def __init__(self, *args, **kwargs): + super(ZenFlowZeroOptimizerSequential, self).__init__(*args, **kwargs) + + def zenflow_cpu_optimizer_step(self, group_no): + self.optimizer.step(step_id=self.micro_step + 1) + + +class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer): + + def __init__(self, *args, **kwargs): + super(ZenFlowZeroOptimizerParallel, self).__init__(*args, **kwargs) + self.process_optimizer_established = False + self.first_update_round_after_warmup = True + + def initialize_optimizer_states(self): + + for i, group in enumerate(self.bit16_groups): + single_grad_partition = torch.zeros(int(self.partition_size[i]), + dtype=self.single_partition_of_fp32_groups[i].dtype, + device=self.device) + self.single_partition_of_fp32_groups[i].grad = None + buffer = get_accelerator().pin_memory( + single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition + self.single_partition_of_fp32_groups[i].overlap_grad = [buffer, buffer.clone()] + + # Initialize the optimizer states with the flattened fp32 partition. + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. + if isinstance(self.optimizer, torch.optim.Adagrad): + self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) + + if not self.cpu_offload: + for group in self.single_partition_of_fp32_groups: + group.grad = None #class init + + return + + def _get_offload_gradient_dict(self): + for param_group_index, _ in enumerate(self.optimizer.param_groups): + self.offload_gradient_dict[param_group_index] = [] + for lp_param in self.params_in_partition[param_group_index]: + param_id = self.get_param_id(lp_param) + [_, _, dest_offset, num_elements] = self.grad_position[param_id] + dest_tensor = self.single_partition_of_fp32_groups[param_group_index].overlap_grad[0].view(-1).narrow( + 0, dest_offset, num_elements) + self.offload_gradient_dict[param_group_index].append(dest_tensor) + + def get_overlap_step_state(self): + if self.micro_step < self.full_warm_up_rounds: + return self.micro_step & 1 + else: + if not self.auto_update: + return (self.micro_step // self.update_interval) & 1 + else: + return self.zenflow_state + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): + param_id = self.get_param_id(param) + now_state = self.get_overlap_step_state() + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + dest_tensor = self.single_partition_of_fp32_groups[i].overlap_grad[now_state].view(-1).narrow( + 0, dest_offset, num_elements) + + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is None: + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + else: + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + if src_tensor.dtype != self.master_weights_and_grads_dtype: + src_tensor = src_tensor.to(self.master_weights_and_grads_dtype) + + dest_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None #offload only + + def wait_last_update_and_copy(self): + + if not hasattr(self, 'parent_conn'): + return + + if self.micro_step + 1 > self.full_warm_up_rounds and self.first_update_round_after_warmup: + self.first_update_round_after_warmup = False + return + + self.timers(OPTIMIZER_RECV_PARAMS_TIMER).start() + msg = self.parent_conn.recv() + assert msg["type"] == "done", "Optimizer process did not finish stepping correctly." + self.timers(OPTIMIZER_RECV_PARAMS_TIMER).stop() + + for i, group in enumerate(self.bit16_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + bit16_partitions = self.parallel_partitioned_bit16_groups[i] + fp32_partition = self.optimizer.param_groups[i]['params'][0].stale_param.data + self.timers(OPTIMIZER_TRANSMIT_TIMER).start() + bit16_partitions[partition_id].data.copy_(fp32_partition.to(get_accelerator().current_device_name()).data) + self.timers(OPTIMIZER_TRANSMIT_TIMER).stop() + + see_memory_usage('After optimizer before all-gather') + if self.cpu_offload: + self.reset_cpu_buffers() + + self.timers(OPTIMIZER_ALLGATHER_TIMER).start() + # Gather the updated weights from everyone. + # Then all partitions of the model parameters are updated and ready for next round forward. + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) + self.timers(OPTIMIZER_ALLGATHER_TIMER).stop() + + self.timers(OPTIMIZER_UPDATE_MODEL_TIMER).start() + # TODO: we probably don't need this? just to be safe + for i in range(len(self.bit16_groups)): + self._update_model_bit16_weights(i) + self.timers(OPTIMIZER_UPDATE_MODEL_TIMER).stop() + + self.timers.log(OPTIMIZER_TIMERS) + see_memory_usage('After zero_optimizer step') + + def zenflow_cpu_optimizer_step(self, now_state, scaled_global_grad_norm): + + if not self.process_optimizer_established: + self.start_optimizer_process() + + group_infos = [] + for group_no, group in enumerate(self.bit16_groups): + single_grad_partition = self.single_partition_of_fp32_groups[group_no].overlap_grad[now_state] + self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) + + group_info = { + "lr": self.optimizer.param_groups[group_no]["lr"], + "betas": self.optimizer.param_groups[group_no]["betas"], + "eps": self.optimizer.param_groups[group_no]["eps"], + "weight_decay": self.optimizer.param_groups[group_no]["weight_decay"], + "bias_correction": self.optimizer.param_groups[group_no]["bias_correction"], + } + + group_infos.append(group_info) + + self.parent_conn.send({ + "type": "step", + "now_state": now_state, + "micro_step": self.micro_step, + "group_infos": group_infos + }) + + def step(self, closure=None): + """ + Not supporting closure. + """ + self.micro_step_id = INITIAL_MICRO_STEP_ID + + see_memory_usage(f"In step before checking overflow") + + # First compute norm for all group so we know if there is overflow + if self.dtype == torch.float16: + self.check_overflow() + + self._update_scale(self.overflow) + if self.overflow: + see_memory_usage('After overflow before clearing gradients') + self.zero_grad(set_to_none=True) + if self.cpu_offload: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients') + + for timer in OPTIMIZER_TIMERS: + self.timers(timer).start() + self.timers(timer).stop() + return + + prev_scale = self.loss_scale + # Step 1:- Calculate gradient norm using bit-16 grads + see_memory_usage('Before norm calculation') + scaled_global_grad_norm = self.scaled_global_norm() + self._global_grad_norm = scaled_global_grad_norm / prev_scale + see_memory_usage('After norm before optimizer') + + if self.micro_step < self.full_warm_up_rounds: + self.zenflow_cpu_optimizer_step(self.get_overlap_step_state(), scaled_global_grad_norm) + + self.wait_last_update_and_copy() + + if self.micro_step >= self.full_warm_up_rounds: + self.zenflow_cpu_optimizer_step(self.get_overlap_step_state(), scaled_global_grad_norm) + + return diff --git a/deepspeed/runtime/zenflow/zenflow_utils.py b/deepspeed/runtime/zenflow/zenflow_utils.py new file mode 100644 index 000000000000..f238b3626506 --- /dev/null +++ b/deepspeed/runtime/zenflow/zenflow_utils.py @@ -0,0 +1,191 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import math +import torch +import psutil +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator + + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + + Returns: + A contiguous 1D buffer containing input tensors. + """ + transposed_tensors = [t.transpose(0, 1).contiguous() if t.dim() == 2 else t for t in tensors] + return torch._C._nn.flatten_dense_tensors(transposed_tensors) + + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors] + unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors) + return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat] + + +def disable_accelerator(): + accelerator = get_accelerator() + accelerator.is_available = lambda: False + accelerator.device_count = lambda: 0 + accelerator.current_device = lambda: -1 + # Optionally mark it as initialized if needed + if hasattr(accelerator, "_initialized"): + accelerator._initialized = True + + +def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity): + disable_accelerator() + + current_process = psutil.Process() + current_process.cpu_affinity(zf_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity)) + + from deepspeed.ops.adam import ZenFlowCPUAdam + optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True) + + pipe.send({"type": "ready"}) + + # TODO: replace this with rpc + + while True: + cmd = pipe.recv() + if cmd["type"] == "step": + now_state = cmd["now_state"] + micro_step = cmd["micro_step"] + group_infos = cmd["group_infos"] + + for group_no, group_info in enumerate(group_infos): + original_param_groups = optimizer.param_groups + optimizer.param_groups = [original_param_groups[group_no]] + group = optimizer.param_groups[0] + + for param_idx, param in enumerate(group["params"]): + key = (group_no, param_idx) + if key in shared_overlap_grad_map: + param.overlap_grad = shared_overlap_grad_map[key] + if key in shared_stale_param_map: + param.stale_param = shared_stale_param_map[key] + + optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info) + + optimizer.param_groups = original_param_groups + + pipe.send({"type": "done"}) + elif cmd["type"] == "exit": + break + + +def all_tensors_equal(tensor_list): + first_tensor = tensor_list[0] + for tensor in tensor_list[1:]: + if not torch.equal(first_tensor, tensor): + return False + return True + + +def start_optimizer_process(zf_optimizer): + from multiprocessing import Pipe, get_context, Manager + + ctx = get_context("spawn") + zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe() + + manager = Manager() + zf_optimizer.shared_overlap_grad_map = manager.dict() + zf_optimizer.shared_stale_param_map = manager.dict() + + if zf_optimizer.zf_stage3: + params_iter = [((group_no, 0), param) + for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)] + else: + params_iter = [((group_no, param_idx), param) + for group_no, group in enumerate(zf_optimizer.optimizer.param_groups) + for param_idx, param in enumerate(group["params"])] + + for key, param in params_iter: + param.data.share_memory_() + + if not hasattr(param, "stale_param"): + param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device) + param.stale_param.data.share_memory_() + zf_optimizer.shared_stale_param_map[key] = param.stale_param + + if getattr(param, "overlap_grad", None) is not None: + param.overlap_grad[0].data.share_memory_() + param.overlap_grad[1].data.share_memory_() + zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad + + param_groups_data = ([{ + "params": [param] + } for param in zf_optimizer.fp32_partitioned_groups_flat] + if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups) + + curr_rank = dist.get_rank() + total_rank = dist.get_world_size() + + current_process = psutil.Process() + current_affinity = current_process.cpu_affinity() + all_affinities = [ + torch.zeros(len(current_affinity), + dtype=type(current_affinity[0]), + device=get_accelerator().current_device_name()) for _ in range(total_rank) + ] + dist.all_gather( + all_affinities, + torch.tensor(current_affinity, dtype=type(current_affinity[0]), + device=get_accelerator().current_device_name())) + # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here + if all_tensors_equal(all_affinities): + num_phy_cores = psutil.cpu_count(logical=False) + available_phy_cores = [i for i in current_affinity if i < num_phy_cores] + num_available_phy_cores = len(available_phy_cores) + my_rank = curr_rank + my_size = total_rank + cores_per_rank = num_available_phy_cores // my_size + current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank] + pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity)) + if pt_num_cores > 0 and pt_num_cores < len(current_affinity): + zf_affinity = current_affinity[pt_num_cores:] + pt_affinity = current_affinity[:pt_num_cores] + else: + zf_affinity = current_affinity + pt_affinity = current_affinity + + zf_optimizer.process = ctx.Process( + target=zenflow_optimizer_process, + args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map, + zf_optimizer.shared_stale_param_map, zf_affinity), + ) + zf_optimizer.process.daemon = True + zf_optimizer.process.start() + + current_process.cpu_affinity(pt_affinity) + os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity)) + + msg = zf_optimizer.parent_conn.recv() + assert msg["type"] == "ready", "Optimizer process did not initialize correctly." + + zf_optimizer.process_optimizer_established = True diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py index 1307fc911625..8b045dbc8f97 100644 --- a/deepspeed/runtime/zero/__init__.py +++ b/deepspeed/runtime/zero/__init__.py @@ -8,6 +8,12 @@ from .partition_parameters import Init from .partition_parameters import GatheredParameters from .partition_parameters import register_external_parameter +from .parameter_offload import DeepSpeedZeRoOffload +from .partition_parameters import DeepSpeedTensorOverride from .tiling import TiledLinear from .tiling import TiledLinearReturnBias + +from .mics import MiCS_Init + +from .stage3 import unwrap_model_for_generation diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index cc204afa7bf2..79fbcb97a188 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -3,13 +3,15 @@ # DeepSpeed Team -from pydantic import Field, validator import sys -from typing import Optional +from typing import Optional, Dict, Any from enum import Enum +from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum +from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig +from .leaf_module_config import DeepSpeedZeroLeafModuleConfig # ZeRO optimization. By default, this optimization is not enabled. # Users have to configure the desired optimization (0 means disabled) in params.json as below example: @@ -20,7 +22,11 @@ "stage": [0|1|2], "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, + "stage3_use_all_reduce_for_fetch_params": [true|false], + "stage3_module_granularity_threshold": 0, "allgather_partitions": [true|false], + "use_multi_rank_bucket_allreduce": [true|false], + "stage3_allgather_sequential": [true|false], "allgather_bucket_size": 500000000, "reduce_scatter": [true|false], "contiguous_gradients" : [true|false] @@ -28,14 +34,22 @@ "reduce_bucket_size": 500000000, "load_from_fp32_weights": [true|false], "cpu_offload": [true|false] (deprecated), - "cpu_offload_params" : [true|false] (deprecated), + "cpu_offload_param" : [true|false] (deprecated), "cpu_offload_use_pin_memory": [true|false] (deprecated), "sub_group_size" : 1000000000000, "offload_param": {...}, "offload_optimizer": {...}, "ignore_unused_parameters": [true|false], "round_robin_gradients": [true|false], - "memory_efficient_linear": [true|false] + "zero_hpz_partition_size": 1, + "zero_quantized_weights": [true|false], + "zero_quantized_nontrainable_weights": [true|false], + "zero_quantized_gradients": [true|false], + "memory_efficient_linear": [true|false], + "override_module_apply": [true|false], + "zeropp_loco_param": {...}, + "log_trace_cache_warnings" : [true|false], + "enable_sanity_checks": [true|false], } } """ @@ -102,6 +116,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): for the allgather for large model sizes """ + use_multi_rank_bucket_allreduce: bool = True + """ + Combine the reduce buckets of the different ranks and do an All-Reduce instead of multiple Reduce ops. + This feature is useful when the model is small and we want to scale it on too many GPUs which therefore + reduces the message sizes of each packet. + """ + allgather_partitions: bool = True """ Chooses between allgather collective or a series of broadcast collectives @@ -114,7 +135,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): the allgather for large model sizes """ - overlap_comm: bool = None # None for dynamic default value (see validator `overlap_comm_valid` below) + overlap_comm: Optional[bool] = None # None for dynamic default value (see validator `overlap_comm_valid` below) """ Attempts to overlap the reduction of the gradients with backward computation """ @@ -148,33 +169,46 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): for :any:`DeepSpeedZeroOffloadOptimizerConfig`. """ + zenflow: Optional[ZenFlowConfig] = None + """Enable ZenFlow""" + sub_group_size: int = Field(pp_int(1e9), ge=0) """ Tile size for parameter processing to fit massive models (with trillions of parameters). Used by ZeRO3-Offload and ZeRO-Infinity """ - cpu_offload_param: bool = Field( + cpu_offload_param: Optional[bool] = Field( None, - deprecated=True, - new_param="offload_param", - new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None), + json_schema_extra={ + "deprecated": True, + "new_param": "offload_param", + "new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) + if val else None) + }, ) """ Deprecated, please use ``offload_param`` """ - cpu_offload_use_pin_memory: bool = Field( + cpu_offload_use_pin_memory: Optional[bool] = Field( None, - deprecated=True, - new_param="offload_param or offload_optimizer", - set_new_param=False, + json_schema_extra={ + "deprecated": True, + "new_param": "offload_param or offload_optimizer", + "set_new_param": False + }, ) """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """ - cpu_offload: bool = Field( + cpu_offload: Optional[bool] = Field( None, - deprecated=True, - new_param="offload_optimizer", - new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None), + json_schema_extra={ + "deprecated": + True, + "new_param": + "offload_optimizer", + "new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) + if val else None) + }, ) """ Deprecated, please use ``offload_optimizer`` """ @@ -221,9 +255,33 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ + module_granularity_threshold: int = Field(pp_int(0), alias="stage3_module_granularity_threshold") + """ + The granularity of a module is determined by the ratio of "parameter_count / (1 + descendant count)". + ZeRO3 classifies modules with a granularity below the threshold as fine-grained, + which are treated as integral units during parameter fetching. This reduces host overhead + and the separate allgather overhead introduced by hooks for fine-grained layers when fetching parameters. + """ + + use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params") + """ + Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing + the overhead of concatenation and slicing on the host. + """ + + allgather_sequential: bool = Field(default=False, alias="stage3_allgather_sequential") + """ + Performs allgather on individual parameters sequentially, bypassing the standard parameter bucketing + mechanism in stage3. This significantly reduces data copy overhead (eliminating copy-to-bucket operations) + and lowers peak memory usage by avoiding the allocation of large temporary flattening buffers. + Recommended for scenarios with high memory pressure. + """ + stage3_gather_fp16_weights_on_model_save: bool = Field(False, - deprecated=True, - new_param="gather_16bit_weights_on_model_save") + json_schema_extra={ + "deprecated": True, + "new_param": "gather_16bit_weights_on_model_save" + }) """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """ ignore_unused_parameters: bool = True @@ -231,7 +289,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. - This is set to ``False`` by default, which means unused parameters are + This is set to ``True`` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. """ @@ -248,16 +306,88 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): Performance benefit grows with gradient accumulation steps (more copying between optimizer steps) or GPU count (increased parallelism). """ + zero_hpz_partition_size: int = Field(1, ge=0) + """ + Number of ranks in zero parameters partitioning secondary group + """ + zero_quantized_weights: bool = False + """ + Boolean indicating whether to quantize zero parameters (weights) + for efficient all_gather comm + """ + zero_quantized_nontrainable_weights: bool = False + """ + Boolean indicating whether to quantize non-trainable zero parameters (weights) + for efficient memory usage and communication. Different from zero_quantized_weights + that stores the weights in original precision and only perform quantization during communication, + this flag will store the weights in quantized precision. This is useful for LoRA training. + """ + zero_quantized_gradients: bool = False + """ + Boolean indicating whether to use quantized zero gradients + for efficient all_2_all_reduce comm + """ + zeropp_loco_param: Optional[Dict[str, Any]] = None + """ + This dictionary contains parameters for using LoCo-Zero++, with two key parameters: + - `err_beta`: A coefficient for the moving average of quantization errors before and after gradient computation. + It ranges between 0 and 1, with a default value of 0.8. + - `reset_T`: The number of steps after which the moving-average error buffer is cleared. The default value is 1024. + These parameters can be adjusted based on performance needs. Example configuration in ds config: + "zeropp_loco_param": { "err_beta": 0.8, "reset_T": 1024 }. + See LoCo paper for more details: (https://arxiv.org/abs/2407.04480). + """ + + mics_shard_size: int = Field(-1, json_schema_extra={"new_param": "mics_shard_size"}) + + mics_hierarchical_params_gather: bool = False memory_efficient_linear: bool = True """ Use memory efficient linear implementation, for Stage 3. """ + """ + Whether force load checkpoint in pipeline mode, current only for Stage 3. + """ + pipeline_loading_checkpoint: bool = False + + override_module_apply: bool = True + """ + Override nn.Module apply function, for Stage 3. + """ + + log_trace_cache_warnings: bool = False + """ + Whether to log warnings from trace cache, such as invalidation events. + """ + + enable_sanity_checks: bool = False + """ + Enable internal sanity checks, which could be useful for debugging + """ + + save_muon_momentum_buffer_in_memory: bool = False + """ + When using the Muon optimizer with ZeRO Stage 3, keeps the Muon momentum + buffer in GPU/CPU memory instead of swapping to NVMe with other optimizer + states. Only relevant when using NVMe offloading. + """ + + leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig) + """ + Configuration for modules that should be treated as ZeRO3 leaf modules. + """ # Validators - @validator("overlap_comm") - def overlap_comm_valid(cls, field_value, values): - if field_value is None: - assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'" - field_value = values["stage"] == ZeroStageEnum.weights - return field_value + @model_validator(mode="after") + def overlap_comm_valid(self): + if self.overlap_comm is None: + self.overlap_comm = self.stage == ZeroStageEnum.weights + return self + + @model_validator(mode="after") + def offload_ratio_check(self): + offload_config = self.offload_optimizer + if offload_config and offload_config.ratio < 1.0: + assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3." + return self diff --git a/deepspeed/runtime/zero/contiguous_memory_allocator.py b/deepspeed/runtime/zero/contiguous_memory_allocator.py index 79f01e5466b9..3e3b11b68cd4 100644 --- a/deepspeed/runtime/zero/contiguous_memory_allocator.py +++ b/deepspeed/runtime/zero/contiguous_memory_allocator.py @@ -78,14 +78,14 @@ def allocate_tensor(self, size): return ret_tensor #assigns the tensor data to the param data and keeps track of the assignment - #any change the the underlying buffer from defragmentation will cause a + #any change the underlying buffer from defragmentation will cause a #reassignment of the param data def assign_to_param(self, tensor, param, numel, shape): tensor_id = id(tensor) assert tensor_id in self.tensor_map.keys(), "No such tensor allocated by the allocator." assert tensor.numel() >= numel, "Assert tensor buffer does is not large enough" - assert not tensor_id in self.id_to_params.keys(), "This tensor has already been assigned to a param" + assert tensor_id not in self.id_to_params.keys(), "This tensor has already been assigned to a param" self.id_to_params[tensor_id] = [param] @@ -193,7 +193,7 @@ def _defragment_memory(self): tensor = self.tensor_map[self.tensor_ids[tensor_addr]] assert tensor_size == tensor.numel(), \ - "Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} " + f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} " assert empty_addr != tensor_addr, \ f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}" diff --git a/deepspeed/runtime/zero/leaf_module_config.py b/deepspeed/runtime/zero/leaf_module_config.py new file mode 100644 index 000000000000..e225801b5c44 --- /dev/null +++ b/deepspeed/runtime/zero/leaf_module_config.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List +from pydantic import Field, model_validator + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + +DEFAULT_LEAF_MODULE_CLASSES: List[str] = [ + "transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock", + "transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock", + "transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock", +] +DEFAULT_LEAF_MODULE_NAMES: List[str] = [] +DEFAULT_LEAF_MODULE_NAME_SUFFIXES: List[str] = [] + + +class DeepSpeedZeroLeafModuleConfig(DeepSpeedConfigModel): + """Configuration for ZeRO leaf modules that should bypass hook installation.""" + + classes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_CLASSES)) + names: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAMES)) + name_suffixes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAME_SUFFIXES)) + + @model_validator(mode="before") + def _coerce_container_types(cls, values): + if values is None: + return {} + if isinstance(values, dict): + coerced = dict(values) + for key in ("classes", "names", "name_suffixes"): + if key in coerced and isinstance(coerced[key], str): + coerced[key] = [coerced[key]] + return coerced + raise TypeError("leaf_module configuration must be a mapping of fields to values") + + @model_validator(mode="after") + def _validate_entries(self): + normalized_classes = [str(cls) for cls in self.classes] + normalized_names = [str(name) for name in self.names] + normalized_suffixes = [str(suffix) for suffix in self.name_suffixes] + + deduped_classes = list(dict.fromkeys(normalized_classes)) + deduped_names = list(dict.fromkeys(normalized_names)) + deduped_suffixes = list(dict.fromkeys(normalized_suffixes)) + + object.__setattr__(self, "classes", deduped_classes) + object.__setattr__(self, "names", deduped_names) + object.__setattr__(self, "name_suffixes", deduped_suffixes) + return self diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index b97a833beacb..033cb77430dc 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -16,6 +16,7 @@ #when implemented outside of torch.autograd.Function import math +import functools import torch from torch import Tensor @@ -26,38 +27,62 @@ from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator -tensor_map = {} - def print_rank_0(message, debug=False, force=False): if dist.get_rank() == 0 and (debug or force): print(message) -try: - autocast_custom_fwd = get_accelerator().amp().custom_fwd - autocast_custom_bwd = get_accelerator().amp().custom_bwd -except (ImportError, AttributeError) as exp: - autocast_custom_fwd = noop_decorator - autocast_custom_bwd = noop_decorator +def _get_legacy_autocast_decorators(device_type): + legacy_amp = getattr(getattr(torch, device_type, None), 'amp', None) + custom_fwd = getattr(legacy_amp, 'custom_fwd', None) + custom_bwd = getattr(legacy_amp, 'custom_bwd', None) + if custom_fwd is not None and custom_bwd is not None: + return custom_fwd, custom_bwd + return noop_decorator, noop_decorator -class LinearFunctionForZeroStage3(torch.autograd.Function): +def _get_autocast_decorators(): + amp = getattr(torch, 'amp', None) + custom_fwd = getattr(amp, 'custom_fwd', None) + custom_bwd = getattr(amp, 'custom_bwd', None) + if custom_fwd is not None and custom_bwd is not None: + device_type = get_accelerator().device_name() + return functools.partial(custom_fwd, device_type=device_type), functools.partial(custom_bwd, + device_type=device_type) + return _get_legacy_autocast_decorators(get_accelerator().device_name()) - # Note that both forward and backward are @staticmethods - @staticmethod - @autocast_custom_fwd - # bias is an optional argument - def forward(ctx, input, weight, bias=None): - weight_id = id(weight) - bias_id = id(bias) +autocast_custom_fwd, autocast_custom_bwd = _get_autocast_decorators() + - #ctx.save_for_backward(input, weight, bias) - ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id)) +def _is_autocast_enabled(device_type): + try: + return torch.is_autocast_enabled(device_type) + except TypeError: + legacy_getter = getattr(torch, f'is_autocast_{device_type}_enabled', None) + if legacy_getter is not None: + return legacy_getter() + return torch.is_autocast_enabled() - tensor_map[weight_id] = weight - tensor_map[bias_id] = bias + +def _get_autocast_dtype(device_type): + try: + return torch.get_autocast_dtype(device_type) + except TypeError: + legacy_getter = getattr(torch, f'get_autocast_{device_type}_dtype', None) + if legacy_getter is not None: + return legacy_getter() + return None + + +class LinearFunctionForZeroStage3(torch.autograd.Function): + + generate_vmap_rule = True + + @staticmethod + # bias is an optional argument + def forward(input, weight, bias=None): if input.dim() == 2 and bias is not None: # fused op is marginally faster @@ -70,55 +95,45 @@ def forward(ctx, input, weight, bias=None): return ret + @staticmethod + def setup_context(ctx, inputs, output): + device_type = get_accelerator().device_name() + ctx._dtype = _get_autocast_dtype(device_type) + ctx._fwd_used_autocast = _is_autocast_enabled(device_type) + input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None + ctx.save_for_backward(input, weight, bias) + # This function has only a single output, so it gets only one gradient @staticmethod - @autocast_custom_bwd def backward(ctx, grad_output): - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - #input, weight, bias = ctx.saved_tensors - - input, weight_id, bias_id = ctx.saved_tensors - weight = tensor_map[weight_id.item()] - bias = tensor_map[bias_id.item()] - - grad_input = grad_weight = grad_bias = None - - #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. - if ctx.needs_input_grad[0]: - #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") - grad_input = grad_output.matmul(weight) - #print(f"Computed grad input {grad_input.shape}") - if ctx.needs_input_grad[1]: - #print("Computing grad weight") + # Match @custom_bwd semantics: always run backward under the same + # autocast state as forward — including explicitly disabling autocast + # when forward did not use it, to guard against outer autocast regions. + device_type = get_accelerator().device_name() + with torch.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype): + input, weight, bias = ctx.saved_tensors + + grad_input = grad_weight = grad_bias = None + dim = grad_output.dim() - if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) - else: - grad_weight = grad_output.t().matmul(input) - #print(f"Computed grad weight grad_weight {grad_weight.shape}") - if bias is not None and ctx.needs_input_grad[2]: - #print("Computing grad bias") - grad_bias = grad_output.sum(0) - #print("Done computing grad bias") - #print("needs bias") - #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") - return grad_input, grad_weight, grad_bias + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight) + if ctx.needs_input_grad[1]: + if dim > 2: + grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul( + input.reshape(-1, input.shape[-1])) + else: + grad_weight = grad_output.t().matmul(input) + if bias is not None and ctx.needs_input_grad[2]: + if dim > 2: + grad_bias = grad_output.sum([i for i in range(dim - 1)]) + else: + grad_bias = grad_output.sum(0) + return grad_input, grad_weight, grad_bias def zero3_linear_wrap(input, weight, bias=None): - if bias is None: - return LinearFunctionForZeroStage3.apply(input, weight) - else: - return LinearFunctionForZeroStage3.apply(input, weight, bias) + return LinearFunctionForZeroStage3.apply(input, weight, bias) class LinearModuleForZeroStage3(Module): diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py new file mode 100755 index 000000000000..f1b6b955239c --- /dev/null +++ b/deepspeed/runtime/zero/mics.py @@ -0,0 +1,447 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import deepspeed +import torch +from deepspeed import comm as dist +from deepspeed.runtime.zero.utils import is_zero_param +from deepspeed.runtime.zero.mics_utils import (MiCS_CommGroups, create_mics_comm_groups, scale_tensors) +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +from deepspeed.runtime.zero.partition_parameters import Init, AllGatherCoalescedHandle, ZeroParamStatus +from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 +from deepspeed.utils import instrument_w_nvtx, log_dist, logger +from deepspeed.accelerator import get_accelerator +from torch import Tensor +from torch.nn import Parameter + + +def has_hierarchical_all_gather_groups(comm_groups: MiCS_CommGroups): + result = False + if comm_groups.param_intra_node_group is not None and comm_groups.param_inter_node_shard_group is not None: + result = True + return result + + +class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle): + """ This handle assumes that no need to + copy data out from a contiguous tensor + """ + + def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None: + super().__init__(allgather_handle, params, partitions, world_size) + + def wait(self, **kwargs) -> None: + """ + """ + # let the current stream to op + try: + # print("HANDLE", self.allgather_handle) + instrument_w_nvtx(self.allgather_handle.wait)() + except (ValueError, RuntimeError) as e: + log_dist( + "WARNING: Runtime Error while waiting the collective all-gather, possibly due to the _IllegalWork", + ranks=[0]) + log_dist(f"Error message: {e}", ranks=[0]) + + if self.complete: + return + + for _, param in enumerate(self.params): + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.complete = True + + +class MiCS_Init(Init): + + def __init__(self, + module=None, + data_parallel_group=None, + sequence_data_parallel_group=None, + mem_efficient_linear=True, + remote_device=None, + pin_memory=False, + config_dict_or_path=None, + config=None, + enabled=True, + dtype=None, + mpu=None): + """A context manager to partition the model parameters during the model + construction with MiCS partition strategy. Model states are partitioned + to the number of devices specified via ``mics_shard_size`` field in the + deepspeed config json file. The context manager also introduces + hierarchical communication method to reduce the cost of inter-node + communications, which can be enabled with + ``mics_hierarchical_params_gather`` field in deepspeed config. + + Args: + module (``torch.nn.Module``, optional): If provided, partition the model as + if it was constructed in the context. + data_parallel_group (``deepspeed.comm`` process group, optional): + The group of processes to partition among. Defaults to all processes. + Synonymous with sequence data parallel group for param partitioning + across both sequence and data parallel groups. + mem_efficient_linear (bool, optional): Replace + torch.nn.functional.linear with an implementation that allows + DeepSpeed to partition parameters. Defaults to ``True``. + remote_device (string, optional): The initial device to store model + weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU + memory. The model may still be moved to GPU based on the + offload settings for training. Defaults to param offload device if a config is + defined, otherwise GPU. + pin_memory (bool, optional): Potentially increase performance by + using pinned memory for model weights. ``remote_device`` must be + ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``. + config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration + for swapping fp16 params to NVMe. + config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead. + enabled (bool, optional): If ``False``, this context has no + effect. Defaults to ``True``. + dtype (``dtype``, optional): Can be used to change the data type of the parameters. + Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None`` + mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}. + + This context follows the same logic as ``deepspeed.zero.Init()``, but + with the modification for partition size of each parameter. + + Examples + -------- + + #. Allocate a model and partition it among all processes: + + .. code-block:: python + # the config_dict_or_path is required to let the context manager know + # how partition the parameters. + # The configuration has to include the field ``mics_shard_size`` + with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config): + model = MyLargeModel() + + + #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes: + + .. code-block:: python + + with deepspeed.zero.MiCS_Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device="cpu", + pin_memory=True + config_dict_or_path=ds_config): + model = MyLargeModel() + + + #. Partition an already-allocated model in CPU memory: + + .. code-block:: python + + model = deepspeed.zero.MiCS_Init(module=model, + config_dict_or_path=ds_config) + """ + + assert config_dict_or_path is not None, "Must provide configuration for MiCS Initialization" + _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" + + if data_parallel_group is None: + ds_process_group = dist.get_world_group() + else: + ds_process_group = data_parallel_group + + if sequence_data_parallel_group is not None: + logger.warning( + "sequence_data_parallel_group' is deprecated and will be removed. Use 'data_parallel_group' instead.") + if data_parallel_group is not None: + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + self.ds_process_group = sequence_data_parallel_group + + self.mics_comm_groups = create_mics_comm_groups( + _ds_config.mics_shard_size, + ds_process_group, + hierarchical_allgather=_ds_config.mics_hierarchial_params_gather, + mpu=mpu) + + super().__init__(module, data_parallel_group, mem_efficient_linear, remote_device, pin_memory, + config_dict_or_path, config, enabled, dtype, mpu) + + def _convert_to_deepspeed_param(self, param): + super()._convert_to_deepspeed_param(param) + # attach communication groups to every param + param.comm = self.mics_comm_groups + + # record existing all_gather_coalesced implementation + # so that we can fallback later + old_all_gather_coalesced = param.all_gather_coalesced + + def _param_all_gather_coalesced(params, param_buffers=None, **kwargs): + """""" + mics_comm_groups: MiCS_CommGroups = params[0].comm + hierarchical_all_gather = has_hierarchical_all_gather_groups(mics_comm_groups) + if dist.has_coalescing_manager() and hierarchical_all_gather: + return self._hierarchical_all_gather_params(params, param_buffers) + elif dist.has_coalescing_manager(): + return self._flat_all_gather_with_coalescing_manager(params, param_buffers) + else: + return old_all_gather_coalesced(params, **kwargs) + + # change the all_gather_coalesced method + param.all_gather_coalesced = _param_all_gather_coalesced + + def _pre_all_gather(self, params, params_buffers=None): + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. + params = sorted(params, key=lambda p: p.ds_id) + return params, params_buffers + + def _flat_all_gather_with_coalescing_manager(self, params, params_buffers=None): + """""" + # must have to change the status of the param + # and ensure they are on the device + params, params_buffers = self._pre_all_gather(params, params_buffers) + + mics_comm_groups: MiCS_CommGroups = params[0].comm + param_shard_size = mics_comm_groups.param_shard_size + + output_tensors = [] + input_tensors = [] + for i, p in enumerate(params): + t_size = p.ds_tensor.ds_numel * param_shard_size + if params_buffers is not None and params_buffers[i] is not None: + assert params_buffers[i].numel( + ) == t_size, f'params_to_gather_buffers[{i}] size {params_buffers[i].numel()} does not match with t_size {t_size}' + flat_out = params_buffers[i] + else: + flat_out = torch.empty(t_size, dtype=p.dtype, device=self.local_device, requires_grad=False).view(-1) + output_tensors.append(flat_out) + _flat_input = p.ds_tensor.data.view(-1) + input_tensors.append(_flat_input) + + all_gather_handle = dist.all_gather_coalesced(output_tensors, + input_tensors, + group=mics_comm_groups.param_shard_group, + async_op=True) + + for idx, param in enumerate(params): + param.data = output_tensors[idx].narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + return MiCS_AllGatherCoalescedHandle(allgather_handle=all_gather_handle, + params=params, + partitions=[], + world_size=param_shard_size) + + def _hierarchical_all_gather_params(self, params, params_buffers=None): + """""" + params, params_buffers = self._pre_all_gather(params, params_buffers) + + mics_comm_groups: MiCS_CommGroups = params[0].comm + local_rank = dist.get_rank(group=mics_comm_groups.param_intra_node_group) + inter_node_comm_group = mics_comm_groups.param_inter_node_shard_group + intra_node_comm_group = mics_comm_groups.param_intra_node_group + param_shard_size = mics_comm_groups.param_shard_size + + inter_node_size = dist.get_world_size(group=inter_node_comm_group) + intra_node_size = dist.get_world_size(group=intra_node_comm_group) + param_tensors = [] + for i, p in enumerate(params): + param_size = p.ds_tensor.ds_numel * param_shard_size + if params_buffers is not None and params_buffers[i] is not None: + assert params_buffers[i].numel( + ) == param_size, f'param_buffers[{i}] size {params_buffers[i].numel()} does not match with param_size {param_size}' + param_tensor = params_buffers[i] + else: + param_tensor = torch.empty(param_size, dtype=p.dtype, device=self.local_device, + requires_grad=False).view(-1) + param_tensors.append(param_tensor) + + # inter node all-gather + inter_outputs = [] + inter_inputs = [] + for i, p in enumerate(params): + inter_size = p.ds_tensor.ds_numel * inter_node_size + _out = param_tensors[i].narrow(0, local_rank * inter_size, inter_size) + inter_outputs.append(_out) + inter_inputs.append(p.ds_tensor.data.view(-1).to(self.local_device)) + # sync enqueue + dist.all_gather_coalesced(inter_outputs, inter_inputs, group=inter_node_comm_group, async_op=False) + + # intra node all-gather + intra_outputs = [] + intra_inputs = [] + for i, p in enumerate(params): + # partition param into multiple chunks for allgather + # because inter-node all-gather outputs are in a continues memory + # while in param memory, those inter-node data are placed in different + # location. + # each chunk is an intra-node output + param_chunk = param_tensors[i].view( + (inter_node_size, intra_node_size, p.ds_tensor.ds_numel)).narrow(1, local_rank, 1) + param_chunk.copy_(inter_outputs[i].detach().clone().view(param_chunk.size())) + output_chunks = torch.chunk(param_tensors[i], inter_node_size) + for j, _out in enumerate(output_chunks): + intra_chunk_size = intra_node_size * p.ds_tensor.ds_numel + local_offset = local_rank * p.ds_tensor.ds_numel + _in = param_tensors[i].narrow(0, j * intra_chunk_size + local_offset, p.ds_tensor.ds_numel) + intra_outputs.append(_out) + intra_inputs.append(_in) + + all_gather_handle = dist.all_gather_coalesced(intra_outputs, + intra_inputs, + group=intra_node_comm_group, + async_op=True) + for i, param in enumerate(params): + param.data = param_tensors[i].narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + return MiCS_AllGatherCoalescedHandle( + allgather_handle=all_gather_handle, + params=params, + partitions=[], + world_size=param_shard_size, + ) + + def get_partition_dp_group(self, param): + return param.comm.param_shard_group + + def get_partition_rank(self): + return self.mics_comm_groups.param_shard_rank + + @property + def num_partitions(self): + return self.mics_comm_groups.param_shard_size + + +class MiCS_Offload(DeepSpeedZeRoOffload): + """ Wrapper to change the behavior for parameter sharding + """ + + def _convert_to_zero_parameters(self, ds_config, module, mpu): + """ overload the parent class function for convert the parameters + + """ + log_dist('Convert to zero parameters from MiCS Offload manager', ranks=[0]) + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + + MiCS_Init(module=module, + data_parallel_group=group, + dtype=self.dtype, + config_dict_or_path=ds_config, + remote_device=self.offload_device, + pin_memory=self.offload_param_pin_memory, + mpu=mpu) + + +class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3): + """ + MiCS Optimizer + """ + + def __init__(self, + module, + init_optimizer, + param_names, + timers, + ds_config, + gradient_accumulation_dtype=torch.float16, + **kwargs): + + log_dist("Init MiCS optimizer", ranks=[0]) + super().__init__(module, + init_optimizer, + param_names, + timers, + ds_config, + gradient_accumulation_dtype=gradient_accumulation_dtype, + **kwargs) + first_param = next(module.parameters()) + # overload the dp_process_group and partition_count + assert hasattr(first_param, "comm"), " ".join([ + "Sharded parameters don't have the MiCS_CommGroups attached.", + "Might due to the use of deepspeed.zero.Init context for initializing the weights.", + "To use MiCS sharding, please use deepspeed.zero.MiCS_Init instead for initializing parameter." + ]) + self.dp_process_group = first_param.comm.param_shard_group + self.partition_count = first_param.comm.param_shard_size + + def initialize_ds_offload( + self, + *args, + **kwargs, + ): + return MiCS_Offload(*args, **kwargs) + + def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + grad_buffers = super().partition_grads(params_to_release, grad_partitions) + # perform all-reduce among replication groups + # the function will perform accumulation boundary check + self.allreduce_mics_shard_grads(params_to_release, grad_buffers) + + @instrument_w_nvtx + def allreduce_mics_shard_grads(self, params, partitioned_grads_buffers: List[Tensor]): + """ + """ + # TODO: improve the condition check + if not self.is_gradient_accumulation_boundary or \ + len(partitioned_grads_buffers) == 0: + return + + mics_comm_groups: MiCS_CommGroups = params[0].comm + param_repli_group = mics_comm_groups.param_repli_group + param_repli_size = mics_comm_groups.param_repli_size + + if param_repli_size is None or param_repli_size <= 1: + return + if not get_accelerator().on_accelerator(partitioned_grads_buffers[0]): + raise RuntimeError("Local sharding has no support for CPU offloading") + + if dist.has_all_reduce_coalesced(): + scale_tensors(partitioned_grads_buffers, param_repli_size) + dist.all_reduce_coalesced(tensors=partitioned_grads_buffers, group=param_repli_group) + else: + # manually coalescing all-reduce + aggregated_buffer: Tensor = torch.cat(partitioned_grads_buffers) + aggregated_buffer.div_(param_repli_size) + dist.all_reduce(aggregated_buffer, group=param_repli_group) + offset = 0 + for grad_buff in partitioned_grads_buffers: + grad_buff.view(-1).copy_(aggregated_buffer.narrow(0, offset, grad_buff.numel())) + offset += grad_buff.numel() + + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False, + checkpoint_folder=None, + load_serial=None): + r""" Loading the ZeRO-3/MiCS partitioned checkpoints + Because the self.dp_process_group is replaced with the communicator for + partition group we can call the load_state_dict logic from ZeRO-3. + """ + super().load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights, checkpoint_folder) diff --git a/deepspeed/runtime/zero/mics_utils.py b/deepspeed/runtime/zero/mics_utils.py new file mode 100644 index 000000000000..18e94c501231 --- /dev/null +++ b/deepspeed/runtime/zero/mics_utils.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +from dataclasses import dataclass +from typing import List + +import numpy as np +from torch import Tensor + +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import logger +from deepspeed.utils.torch import jit_script_compat + + +def _log_rank0(msg): + if dist.get_rank() == 0: + logger.info(msg) + + +@jit_script_compat +def scale_tensors(tensors: List[Tensor], scale: int): + for t in tensors: + t.div_(scale) + + +@dataclass +class MiCS_CommGroups: + """""" + param_shard_group = None + param_shard_size = -1 + param_shard_rank = -1 + + param_repli_group = None + param_repli_size = -1 + param_repli_rank = -1 + + param_intra_node_group = None + param_inter_node_shard_group = None + + +def create_mics_comm_groups( + shard_size, + dp_group, + hierarchical_allgather=False, + mpu=None, +): + """ + create shard-group, replicate-group from config_file + TODO: consider broadcast the config from rank0 + + Returns: + MiCS_CommGroups + """ + # env var for debugging purpose + ndevices_per_node = int(os.environ.get("NDEV_PER_NODE", get_accelerator().device_count())) + _log_rank0(f'creating MiCS communication groups with per node device size {ndevices_per_node}') + groups = MiCS_CommGroups() + + if mpu is not None: + assert dp_group == mpu.get_data_parallel_group() + + # full size of the world + world_size = dist.get_world_size() + # global rank + global_rank = dist.get_rank() + + config = _generate_mics_config(world_size, ndevices_per_node, shard_size, 1) + ranks_of_shard_group = config['shard_groups'] + ranks_of_repli_group = config['replicate_groups'] + if len(ranks_of_repli_group) == 0: + assert len(ranks_of_shard_group) == 1, "replicate groups are empty only for single shard group" + for r in ranks_of_shard_group[0]: + ranks_of_repli_group.append([r]) + + # for simplicity + assert _sizes_all_same(ranks_of_repli_group), "replicate groups must have the same size" + assert _sizes_all_same(ranks_of_shard_group), "shard groups must have the same size" + + assert sum([len(g) for g in ranks_of_shard_group]) == dist.get_world_size(), "all sharded ranks " + if len(ranks_of_shard_group) > 1: # if only shard on one group then no need for replicate groups + assert len(ranks_of_shard_group) == len( + ranks_of_repli_group[0]), "number of shard groups must equal to the size of each replicate group" + + global_rank = dist.get_rank() + # create shard groups + for shard_ranks in ranks_of_shard_group: + _group = dist.new_group(shard_ranks) + if global_rank in shard_ranks: + groups.param_shard_group = _group + groups.param_shard_size = len(shard_ranks) + groups.param_shard_rank = dist.get_rank(_group) + logger.info(f'rank {global_rank}, shard group' + f' {groups.param_shard_rank}/{dist.get_world_size(group=_group)}') + + # create replicate groups + for repli_ranks in ranks_of_repli_group: + if len(repli_ranks) > 1: + _group = dist.new_group(repli_ranks) + if global_rank in repli_ranks: + groups.param_repli_group = _group + groups.param_repli_size = len(repli_ranks) + groups.param_repli_rank = dist.get_rank(group=_group) + logger.info(f'rank {global_rank} ' + f'replicate group {groups.param_repli_rank}/{dist.get_world_size(group=_group)}') + else: + groups.param_repli_group = None + groups.param_repli_size = 1 + groups.param_repli_rank = 0 + logger.info(f'rank {global_rank} replicate group 0/1') + + # assign shard group size as world size + assert groups.param_shard_size == len(ranks_of_shard_group[0]) + + if hierarchical_allgather: + # create hierarchy inter-node, intra-node groups + # n_span_nodes = config['shard_span'] + n_span_nodes = config['span_nodes'] + assert n_span_nodes > 1, "sharding spans on single node, no need for hierarchy allgather" + assert len(ranks_of_shard_group[0]) % n_span_nodes == 0 + + n_gpu_per_node = len(ranks_of_shard_group[0]) // n_span_nodes + intra_node_ranks_group = [] + inter_node_ranks_group = [] + for shard_group in ranks_of_shard_group: + _intra_node_ranks = [] + for i in range(0, len(shard_group), n_gpu_per_node): + _intra_node_ranks.append(shard_group[i:i + n_gpu_per_node]) + _inter_node_ranks = [] + for i in range(n_gpu_per_node): + _ranks = [_g[i] for _g in _intra_node_ranks] + _inter_node_ranks.append(_ranks) + + intra_node_ranks_group.append(_intra_node_ranks) + inter_node_ranks_group.append(_inter_node_ranks) + + _log_rank0(f"create for hierarchy all-gather groups: intra nodes {intra_node_ranks_group}") + _log_rank0(f"create for hierarchy all-gather groups: inter nodes {inter_node_ranks_group}") + + # create communicators + for shard_group in intra_node_ranks_group: + for intra_node_ranks in shard_group: + _group = dist.new_group(intra_node_ranks) + if global_rank in intra_node_ranks: + groups.param_intra_node_group = _group + _log_rank0(f'create group for intra node ranks {intra_node_ranks}') + + for shard_group in inter_node_ranks_group: + for inter_node_ranks in shard_group: + _group = dist.new_group(inter_node_ranks) + if global_rank in inter_node_ranks: + groups.param_inter_node_shard_group = _group + _log_rank0(f'create group for inter node ranks {inter_node_ranks}') + return groups + + +def _generate_mics_config(world_size, ndev_per_node, shard_size, pp_size=1): + """Generating the configuration for sharding This shard config generation assume + that the pipeline stages are partitioned in order, i.e., first ranks + hold the stage0, etc. + + Args: + + shard_size (int): zero3 data-parallel shard size, FIXME: + change the name later + + pp_size (int): pipeline parallel size, currently, only work with + pipeline parallelism + zero + + """ + assert world_size % pp_size == 0 + assert (world_size // pp_size) % shard_size == 0, \ + f"dp group size is not dividable by dp_shard_size, "\ + f" (world_size {world_size}, pp_size {pp_size}, dp_shard_size {shard_size})" + + config = {} + shard_groups = np.arange(world_size).reshape(-1, shard_size) + replicate_groups = [] + for i in range(shard_size): + same_shard_ranks = shard_groups[:, i].tolist() + n_ranks = len(same_shard_ranks) + replicate_size = n_ranks // pp_size + replicate_groups.extend([same_shard_ranks[j:j + replicate_size] for j in range(0, n_ranks, replicate_size)]) + + config['replicate_groups'] = replicate_groups + config['shard_groups'] = shard_groups.tolist() + config["span_nodes"] = len(shard_groups[0]) // ndev_per_node + return config + + +def _sizes_all_same(groups): + """all groups have same length""" + all_same = True + for g in groups: + if len(g) != len(groups[0]): + return False + return all_same diff --git a/deepspeed/runtime/zero/muon/__init__.py b/deepspeed/runtime/zero/muon/__init__.py new file mode 100644 index 000000000000..3fa53f31d420 --- /dev/null +++ b/deepspeed/runtime/zero/muon/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2025 Peng Du and Zhipeng Wang +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/runtime/zero/muon/muon_optimizer.py b/deepspeed/runtime/zero/muon/muon_optimizer.py new file mode 100644 index 000000000000..b2810c556eeb --- /dev/null +++ b/deepspeed/runtime/zero/muon/muon_optimizer.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025 Peng Du and Zhipeng Wang +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +try: + from deepspeed.runtime.zero.muon.original_muon import MuonWithAuxAdam as BaseMuonWithAuxAdam + from deepspeed.runtime.zero.muon.original_muon import adam_update +except ImportError: + pass + + +class MuonWithAuxAdam(BaseMuonWithAuxAdam): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + for group in self.param_groups: + if group["use_muon"]: + # we move the muon update part to the deepspeed's optimizer since the parameter here is a flat version + # thus not suitable for muon update + for p in group["params"]: + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(p.grad.reshape(p.shape), alpha=-group["lr"]) + else: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], + group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + return loss diff --git a/deepspeed/runtime/zero/muon/original_muon.py b/deepspeed/runtime/zero/muon/original_muon.py new file mode 100644 index 000000000000..2bb745da4f5f --- /dev/null +++ b/deepspeed/runtime/zero/muon/original_muon.py @@ -0,0 +1,429 @@ +# Copyright (c) 2024 Keller Jordan +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +MIT License + +Copyright (c) 2024 Keller Jordan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" + +import torch +import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check +from deepspeed.runtime import compiler +from deepspeed.accelerator import get_accelerator + + +@compiler.compile() +def zeropower_via_newtonschulz5(G, steps: int): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + # Use bf16 when hardware supports it; fp32 otherwise + compute_dtype = torch.bfloat16 if get_accelerator().is_bf16_supported() else torch.float32 + X = G.to(compute_dtype) + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +@compiler.compile() +def zeropower_via_gram_newtonschulz(G, steps: int): + """ + Gram Newton-Schulz iteration for orthogonalization. + + Mathematically equivalent to standard Newton-Schulz but iterates on the + small square Gram matrix R = X @ X.T (n x n) instead of the full rectangular + X (n x m). This reduces FLOPs significantly when m >> n (typical for + transformer weight matrices with aspect ratio ~5). + + Uses fp16 instead of bf16 for better numerical precision at the same + compute cost. Includes a restart at iteration 2 to maintain stability + in half-precision. + + Falls back to standard Newton-Schulz for square matrices (n == m) + where there is no FLOP advantage. + + Reference: https://tridao.me/blog/2026/gram-newton-schulz/ + """ + assert G.ndim >= 2 + a, b, c = (3.4445, -4.7750, 2.0315) + # Use fp16 for better precision than bf16 when hardware supports it; fp32 otherwise + compute_dtype = torch.float16 if get_accelerator().is_fp16_supported() else torch.float32 + X = G.to(compute_dtype) + if G.size(-2) > G.size(-1): + X = X.mT + + n, m = X.size(-2), X.size(-1) + + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # For square matrices, no FLOP advantage; use standard iteration + if m <= n: + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(-2) > G.size(-1): + X = X.mT + return X + + # Gram NS: iterate on R = X @ X.T (n x n) instead of X (n x m) + R = X @ X.mT + Q = None + restart_at = 2 + + for i in range(steps): + if i == restart_at and i != 0: + X = Q @ X + R = X @ X.mT + Q = None + + Z = b * R + c * R @ R + + if Q is None: + Q = Z.clone() + Q.diagonal().add_(a) + else: + Q = torch.addmm(Q, Z, Q, beta=a, alpha=1.0) + + if i < steps - 1 and (i + 1) != restart_at: + RZ = torch.addmm(R, Z, R, beta=a, alpha=1.0) + R = torch.addmm(RZ, Z, RZ, beta=a, alpha=1.0) + + if G.size(-2) > G.size(-1): + X = X.mT @ Q.mT + else: + X = Q @ X + return X + + +NS_METHODS = {"standard", "gram"} + + +@compiler.compile() +def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"): + orig_dtype = grad.dtype + momentum.lerp_(grad, 1 - beta) + update = grad.lerp_(momentum, beta) if nesterov else momentum + if update.ndim == 4: # for the case of conv filters + update = update.view(len(update), -1) + if ns_method == "gram": + update = zeropower_via_gram_newtonschulz(update, steps=ns_steps) + else: + update = zeropower_via_newtonschulz5(update, steps=ns_steps) + update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + if update.dtype != orig_dtype: + update = update.to(orig_dtype) + return update + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the + advantage that it can be stably run in bfloat16 on the GPU. + + Muon should only be used for hidden weight layers. The input embedding, final output layer, + and any internal gains or biases should be optimized using a standard method such as AdamW. + Hidden convolutional weights can be trained using Muon by viewing them as 2D and then + collapsing their last 3 dimensions. + + Arguments: + lr: The learning rate, in units of spectral norm per update. + weight_decay: The AdamW-style weight decay. + momentum: The momentum. A value of 0.95 here is usually fine. + ns_method: Newton-Schulz method. "gram" (default) uses Gram NS for ~2x speedup + on rectangular matrices. "standard" uses the original iteration. + """ + + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method) + assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) + params = sorted(params, key=lambda x: x.size(), reverse=True) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1]) + ] * (dist.get_world_size() - len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram")) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], + params_pad[base_i + dist.get_rank()]) + + return loss + + +class SingleDeviceMuon(torch.optim.Optimizer): + """ + Muon variant for usage in non-distributed settings. + """ + + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram")) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + + return loss + + +def adam_update(grad, buf1, buf2, step, betas, eps): + buf1.lerp_(grad, 1 - betas[0]) + buf2.lerp_(grad.square(), 1 - betas[1]) + buf1c = buf1 / (1 - betas[0]**step) + buf2c = buf2 / (1 - betas[1]**step) + return buf1c / (buf2c.sqrt() + eps) + + +class MuonWithAuxAdam(torch.optim.Optimizer): + """ + Distributed Muon variant that can be used for all parameters in the network, since it runs an + internal AdamW for the parameters that are not compatible with Muon. The user must manually + specify which parameters shall be optimized with Muon and which with Adam by passing in a + list of param_groups with the `use_muon` flag set. + + The point of this class is to allow the user to have a single optimizer in their code, rather + than having both a Muon and an Adam which each need to be stepped. + + You can see an example usage below: + + https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470 + ``` + hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] + embed_params = [p for n, p in model.named_parameters() if "embed" in n] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + head_params = [model.lm_head.weight] + + from muon import MuonWithAuxAdam + adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)] + adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups] + muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True) + param_groups = [*adam_groups, muon_group] + optimizer = MuonWithAuxAdam(param_groups) + ``` + """ + + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True) + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + group["ns_method"] = group.get("ns_method", "gram") + assert group[ + "ns_method"] in NS_METHODS, f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}" + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon", "ns_method"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1]) + ] * (dist.get_world_size() - len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram")) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], + params_pad[base_i + dist.get_rank()]) + else: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], + group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + return loss + + +class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): + """ + Non-distributed variant of MuonWithAuxAdam. + """ + + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + group["ns_method"] = group.get("ns_method", "gram") + assert group[ + "ns_method"] in NS_METHODS, f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}" + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon", "ns_method"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self, closure=None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, + state["momentum_buffer"], + beta=group["momentum"], + ns_method=group.get("ns_method", "gram")) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update.reshape(p.shape), alpha=-group["lr"]) + else: + for p in group["params"]: + if p.grad is None: + # continue + p.grad = torch.zeros_like(p) # Force synchronization + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], + group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + return loss diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index c3a6dc7af530..ac88d32266f8 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -3,9 +3,11 @@ # DeepSpeed Team -from pydantic import Field, validator from enum import Enum from pathlib import Path +from pydantic import Field, model_validator +from typing import Optional + from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int @@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel): `nvme`. """ - nvme_path: Path = None + nvme_path: Optional[Path] = None """ Filesystem path for NVMe device for parameter offloading. """ buffer_count: int = Field(5, ge=0) @@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): `nvme`. Optimizer computation is offload to CPU regardless of device option. """ - nvme_path: Path = None + nvme_path: Optional[Path] = None """ Filesystem path for NVMe device for optimizer state offloading. """ buffer_count: int = Field(4, ge=0) @@ -88,7 +90,26 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): fast_init: bool = False """ Enable fast optimizer initialization when offloading to NVMe. """ - @validator("pipeline_read", "pipeline_write", always=True) - def set_pipeline(cls, field_value, values): - values["pipeline"] = field_value or values.get("pipeline", False) - return field_value + ratio: float = Field(1.0, ge=0.0, le=1.0) + """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3.""" + + super_offload: bool = False + """ Enable high performance CPU offloading for Superchips. Only valid with ZeRO Stage 3.""" + + cpuadam_cores_perc: float = Field(0.8, ge=0.0, le=1.0) + """ Percentage of CPU cores to use for CPU Adam. Only valid with ZeRO Stage 3 and super_offload=True.""" + + @model_validator(mode="after") + def set_pipeline(self): + pipeline = self.pipeline_read or self.pipeline_write + self.__dict__["pipeline"] = pipeline + return self + + +class OffloadStateTypeEnum(str, Enum): + """ Enum for internal buffer types """ + optim_states = "optim_states" + hp_params = "hp_params" + lp_params = "lp_params" + lp_grads = "lp_grads" + contiguous_grad_buffer = "contiguous_grad_buffer" diff --git a/deepspeed/runtime/zero/offload_states.py b/deepspeed/runtime/zero/offload_states.py new file mode 100644 index 000000000000..f0a45f12b7c6 --- /dev/null +++ b/deepspeed/runtime/zero/offload_states.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Set +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum + + +def _make_offload_state_key(key): + return f"{key}_offload_buffer" + + +def offload_optimizer_states(optimizer, device, pin_memory=False, non_blocking=False): + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + if pin_memory and v.device.type != 'cpu': + pinned_buffer = torch.empty_like(v, device='cpu').pin_memory() + pinned_buffer.copy_(v, non_blocking=non_blocking) + state[k] = pinned_buffer + else: + state[k] = v.to(device, non_blocking=non_blocking) + + +def reload_optimizer_states(optimizer, device, non_blocking=False): + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.to(device, non_blocking=non_blocking) + + +def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_key(state, key): + offload_buf_key = _make_offload_state_key(key) + if offload_buf_key not in state: + state[offload_buf_key] = torch.empty_like(state[key], device=device) + if pin_memory: + state[offload_buf_key] = get_accelerator().pin_memory(state[offload_buf_key]) + state[offload_buf_key].copy_(state[key], non_blocking=non_blocking) + state[key].data = state[offload_buf_key] + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_key(state, "exp_avg_sq") + + +def reload_adam_states(optimizer, device, non_blocking: bool = False): + """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam.""" + + def move_back_key(state, key): + state[key].data = state[_make_offload_state_key(key)].to(device, non_blocking=non_blocking) + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + move_back_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_back_key(state, "exp_avg_sq") + + +def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]: + """Retrieve the devices of the specified state of the model. + + Args: + model (DeepSpeedEngine): The model whose device allocations are to be checked. + state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved. + + Returns: + Set[torch.device]: A set of devices of the specified state. + + """ + if state == OffloadStateTypeEnum.hp_params: + return set(model.optimizer.get_hp_param_device(p) for p in model.parameters()) + elif state == OffloadStateTypeEnum.lp_params: + return set(p.ds_tensor.device for p in model.parameters()) + elif state == OffloadStateTypeEnum.lp_grads: + return {model.optimizer.grad_partitions_flat_buffer.device} + elif state == OffloadStateTypeEnum.optim_states: + return set(model.optimizer.get_hp_param_device(p, "exp_avg") for p in model.parameters()) | \ + set(model.optimizer.get_hp_param_device(p, "exp_avg_sq") for p in model.parameters()) + elif state == OffloadStateTypeEnum.contiguous_grad_buffer: + return set(bucket.buffer.device for bucket in model.optimizer.ipg_buckets.values() + if bucket.buffer is not None) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 55beff336740..aba0cde6266d 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -6,91 +6,19 @@ import sys import torch from collections import OrderedDict +from deepspeed.utils import z3_leaf_module, set_z3_leaf_module from deepspeed.runtime.utils import see_memory_usage +from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.partition_parameters import * -from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params -from deepspeed import comm as dist +from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params from deepspeed.accelerator import get_accelerator +from deepspeed import utils FWD_MODULE_STACK = list() -def is_builtin_type(obj): - # https://stackoverflow.com/a/17795199 - return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins" - - -def isinstance_namedtuple(obj: object) -> bool: - """ - Is this an instance of namedtuple/NamedTuple? - From: https://stackoverflow.com/a/62692640 - - Args: - obj (object): An object. - - Returns: - bool: True if namedtuple/NamedTuple else False. - """ - return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields') - - -# ensure we only warn once, otherwise every iteration will trigger a warning -warned = False - - -def _apply_to_tensors_only(module, functional, backward_function, outputs): - """ - Apply a torch.autograd.Function that calls a `backward_function` to every Tensor in `outputs`. - - Args: - module (torch.nn.Module): A torch module - functional (Type[torch.autograd.Function]): The function class to apply. - backward_function (Callable[[torch.nn.Module], None]): A backward_function to pass to - `functional.apply`. - outputs (Any): The output of `module`. - - Returns: - Any: The output of `module`. - """ - if isinstance(outputs, (tuple, list)): - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, functional, backward_function, output) - touched_outputs.append(touched_output) - - if isinstance_namedtuple(outputs): - # namedtuples require a slightly different syntax. - return outputs.__class__(*touched_outputs) - - return outputs.__class__(touched_outputs) - elif isinstance(outputs, dict): - # apply inplace to avoid recreating dict inherited objects - for key in outputs.keys(): - outputs[key] = _apply_to_tensors_only(module, functional, backward_function, outputs[key]) - return outputs - - elif isinstance(outputs, torch.Tensor): - # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter - touched_outputs = functional.apply(module, backward_function, outputs) - - # restore zero param attributes if those get stripped by `backward_function` - if not is_zero_param(touched_outputs) and is_zero_param(outputs): - touched_outputs.ds_param_alias = outputs - return touched_outputs - else: - if not is_builtin_type(outputs): - global warned - if not warned and dist.get_rank() == 0: - logger.warning( - f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. " - "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " - "output tensors and therefore may not get triggered properly.") - warned = True - return outputs - - #for each tensor in outputs run the forward_function and register backward_function as hook def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs): if type(outputs) is tuple: @@ -122,6 +50,10 @@ def __init__(self, parent_module, *args, **kwargs): self._parent_module = parent_module self._in_forward = False + def __reduce__(self): + r0, _, *r2 = super().__reduce__() + return (r0, (self._parent_module, )) + tuple(r2) + def __getitem__(self, key): param = super().__getitem__(key) @@ -129,7 +61,8 @@ def __getitem__(self, key): if param is None: return param - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + # TODO: only weaken this check during compilation + if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if self._parent_module._parameters._in_forward: register_external_parameter(FWD_MODULE_STACK[-1], param) param.all_gather() @@ -140,6 +73,8 @@ def __getitem__(self, key): def _inject_parameters(module, cls): for module in module.modules(): + module._original_parameters = module._parameters + if cls == ZeROOrderedDict: new_param = cls(parent_module=module) else: @@ -147,80 +82,72 @@ def _inject_parameters(module, cls): for key, param in module._parameters.items(): new_param[key] = param + module._parameters = new_param -class PreBackwardFunction(torch.autograd.Function): +def ensure_zero_ordered_dict(module): + """Wrap ``module._parameters`` in :class:`ZeROOrderedDict` if not already. - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.pre_backward_function = pre_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.pre_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args + PyTorch 2.5+ defaults ``nn.Module._parameters`` to a plain ``dict`` + (pytorch/pytorch#129164), which rejects the ``_in_forward`` attribute + the forward prologue sets. Modules not converted by ``_inject_parameters`` + at engine init (e.g. submodules attached after ``deepspeed.initialize``, + or restored by ``deepspeed/compile/init_z3.py``) hit issue #6961. + Idempotent; no-op if already wrapped, missing, or a non-dict container. + """ + params = getattr(module, "_parameters", None) + if isinstance(params, ZeROOrderedDict) or not isinstance(params, dict): + return + # Preserve the original container only on first wrap so the un-injection + # path in ``deepspeed/compile/init_z3.py`` can restore it. + if not hasattr(module, "_original_parameters"): + module._original_parameters = params + new_param = ZeROOrderedDict(parent_module=module) + for key, param in params.items(): + new_param[key] = param + module._parameters = new_param class DeepSpeedZeRoOffload(object): - def __init__(self, - module, - timers, - ds_config, - overlap_comm=True, - prefetch_bucket_size=50000000, - max_reuse_distance=1000000000, - max_live_parameters=1000000000, - param_persistence_threshold=100000, - model_persistence_threshold=sys.maxsize, - offload_param_config=None, - mpu=None): - - see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True) + def __init__( + self, + module, + timers, + ds_config, + zenflow=False, + overlap_comm=True, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + model_persistence_threshold=sys.maxsize, + dp_process_group=None, + offload_param_config=None, + mpu=None, + zero_param_parallel_group=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + zero_module_granularity_threshold=0, + log_trace_cache_warnings=False, + ): + + see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=False) print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False) self.module = module + self.timers = timers + self.zenflow = zenflow self.dtype = list(module.parameters())[0].dtype + self.dp_process_group = dp_process_group self.offload_device = None self.offload_param_pin_memory = False + self.zero_param_parallel_group = zero_param_parallel_group + self.zero_quantized_weights = zero_quantized_weights + self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights + self.log_trace_cache_warnings = log_trace_cache_warnings if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none: self.offload_device = offload_param_config.device @@ -238,42 +165,62 @@ def __init__(self, self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold, self.model_persistence_threshold) - self.param_coordinators = {} self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) - self.__allgather_stream = get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream() + self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream( + ) if overlap_comm else get_accelerator().default_stream() + + if not hasattr(module, "ds_inflight_param_registry"): + module.ds_inflight_param_registry = InflightParamRegistry() + self.__inflight_param_registry = module.ds_inflight_param_registry + + self.fast_sharding_for_leaf_module = False + + if zero_module_granularity_threshold > 0: + self.min_granularity_value = sys.maxsize + self.min_granularity_layer = None + self.granularity_info = set() + self.z3_leaf_layers = [] + self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold) + self.fast_sharding_for_leaf_module = True + + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=self._prefetch_bucket_sz, + max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, + max_available_parameters_in_numel=self._max_available_parameters_in_numel, + allgather_stream=self.__allgather_stream, + inflight_param_registry=self.__inflight_param_registry, + prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, + timers=self.timers, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, + fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module, + log_trace_cache_warnings=self.log_trace_cache_warnings, + ) self.forward_hooks = [] self.backward_hooks = [] + self.setup_zero_stage3_hooks() print_rank_0( f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}', force=False) - see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True) + see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=False) @instrument_w_nvtx def partition_all_parameters(self): """Partitioning Parameters that were not partitioned usually if parameters of modules whose input parameters do not require grad computation do not trigger post call and will therefore will remain unpartitioned""" - self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module) + self.get_param_coordinator().release_and_reset_all(self.module) for param in iter_params(self.module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: raise RuntimeError(f"{param.ds_summary()} expected to be released") - def get_param_coordinator(self, training): - if not training in self.param_coordinators: - self.param_coordinators[training] = PartitionedParameterCoordinator( - prefetch_bucket_sz=self._prefetch_bucket_sz, - max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, - max_available_parameters_in_numel=self._max_available_parameters_in_numel, - allgather_stream=self.__allgather_stream, - prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, - ) - - return self.param_coordinators[training] + def get_param_coordinator(self): + return self.param_coordinator def empty_partition_cache(self): self.partition_all_parameters() @@ -286,7 +233,8 @@ def _convert_to_zero_parameters(self, ds_config, module, mpu): zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) else: group = None - if mpu: + # parallel_state_sp doesn't have get_data_parallel_group + if mpu and hasattr(mpu, "get_data_parallel_group"): group = mpu.get_data_parallel_group() Init(module=module, @@ -295,7 +243,10 @@ def _convert_to_zero_parameters(self, ds_config, module, mpu): config_dict_or_path=ds_config, remote_device=self.offload_device, pin_memory=self.offload_param_pin_memory, - mpu=mpu) + mpu=mpu, + zero_param_parallel_group=self.zero_param_parallel_group, + zero_quantized_weights=self.zero_quantized_weights, + zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights) def destroy(self): self._remove_module_hooks() @@ -310,6 +261,8 @@ def _remove_module_hooks(self): for hook in self.backward_hooks: hook.remove() + self.fwd_pre_hook.remove() + print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}', force=False) @@ -318,14 +271,14 @@ def setup_zero_stage3_hooks(self): #reset step if in inference mode @instrument_w_nvtx - def _end_of_forward_hook(module, *args): + def _start_of_forward_hook(module, *args): - if not torch._C.is_grad_enabled(): - self.get_param_coordinator(training=False).reset_step() + self.get_param_coordinator().reset_step() + + self.fwd_pre_hook = self.module.register_forward_pre_hook(_start_of_forward_hook) #likely one of them should be enough but just to be safe - self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) + self._register_deepspeed_module(self.module) # Add top module to stack trace global FWD_MODULE_STACK @@ -335,7 +288,7 @@ def mark_persistent_parameters(self, param_threshold, model_threshold): persistent_params = [] total_persistent_parameters = 0 params_count = 0 - for _, param in self.module.named_parameters(recurse=True): + for name, param in self.module.named_parameters(recurse=True): if param.ds_numel + total_persistent_parameters > model_threshold: continue @@ -346,27 +299,32 @@ def mark_persistent_parameters(self, param_threshold, model_threshold): total_persistent_parameters += param.ds_numel print_rank_0( - f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params", - force=True) + f"Parameter Offload - Persistent parameters statistics: param_count = {params_count}, numel = {total_persistent_parameters}", + force=False) return persistent_params - def _register_hooks_recursively(self, module, count=[0]): + def _register_deepspeed_module(self, module, count=[0]): my_count = count[0] - module.id = my_count + module.ds_id = my_count - #print(f"{module.__class__} : {module.id}") + #print(f"{module.__class__} : {module.ds_id}") - for child in module.children(): - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) + if z3_leaf_module(module): + for param in module.parameters(): + param.ds_z3_leaf_module = module + else: + for child in module.children(): + count[0] = count[0] + 1 + self._register_deepspeed_module(child, count=count) - @instrument_w_nvtx + @torch.compiler.disable def _pre_forward_module_hook(module, *args): self.pre_sub_module_forward_function(module) @instrument_w_nvtx def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK FWD_MODULE_STACK.pop() if output is None: @@ -407,20 +365,16 @@ def _post_forward_module_hook(module, input, output): self.post_sub_module_forward_function(module) - def _pre_backward_module_hook(module, inputs, output): + def _bwd_hook_unexpected_inputs_msg(value): + return f"A module has unknown inputs or outputs type ({type(value)}) and the tensors embedded in it cannot be detected. " \ + "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " \ + "output tensors and therefore may not get triggered properly." - @instrument_w_nvtx - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + def _pre_backward_module_hook(module, inputs, output): - return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output) + return apply_to_tensors_only(module.pre_bwd_fn.apply, + output, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) #This is an alternate to doing _post_backward_module_hook #it uses tensor.register_hook instead of using torch.autograd.Function @@ -442,15 +396,13 @@ def _run_before_forward_function(input): return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function, _run_after_backward_hook, inputs) + @torch.compiler.disable def _post_backward_module_hook(module, inputs): module.ds_grads_remaining = 0 - @instrument_w_nvtx - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs) + return apply_to_tensors_only(module.post_bwd_fn.apply, + inputs, + warning_msg_fn=_bwd_hook_unexpected_inputs_msg) # Pre forward hook self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) @@ -459,9 +411,81 @@ def _run_after_backward_function(sub_module): self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook)) # Pre backward hook + if not hasattr(module, "pre_bwd_fn"): + + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + class PreBackwardFunctionForModule(torch.autograd.Function): + + @staticmethod + def forward(outputs): + return outputs.detach() + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args + + module.pre_bwd_fn = PreBackwardFunctionForModule + self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook)) # post backward hook + if not hasattr(module, "post_bwd_fn"): + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(output): + return output.detach() + + @staticmethod + def setup_context(ctx, inputs, output): + (output_in, ) = inputs + ctx.module = module + if output_in.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args + + module.post_bwd_fn = PostBackwardFunctionModule + self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook)) @torch.no_grad() @@ -471,43 +495,144 @@ def pre_sub_module_forward_function(self, sub_module): global FWD_MODULE_STACK FWD_MODULE_STACK.append(sub_module) - param_coordinator = self.get_param_coordinator(training=sub_module.training) + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) - param_coordinator.fetch_sub_module(sub_module) + param_coordinator.fetch_sub_module(sub_module, forward=True) + + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False) @torch.no_grad() def post_sub_module_forward_function(self, sub_module): - see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release", + force=False) + + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data - param_coordinator = self.get_param_coordinator(training=sub_module.training) - param_coordinator.release_sub_module(sub_module) + param_coordinator = self.get_param_coordinator() + param_coordinator.release_sub_module(sub_module, forward=True) - see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release", + force=False) @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" - param_coordinator = self.get_param_coordinator(training=True) + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" + param_coordinator = self.get_param_coordinator() param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) - param_coordinator.fetch_sub_module(sub_module) + param_coordinator.fetch_sub_module(sub_module, forward=False) + + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data @torch.no_grad() def post_sub_module_backward_function(self, sub_module): - assert sub_module.training, "backward pass is invalid for module in evaluation mode" + # assert sub_module.training, "backward pass is invalid for module in evaluation mode" see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release", force=False) - self.get_param_coordinator(training=True).release_sub_module(sub_module) + if self.zenflow: + params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))) + for param in params_to_fetch: + param.data = param.data.t() if len(param.ds_shape) != 1 else param.data + + self.get_param_coordinator().release_sub_module(sub_module, forward=False) see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release", force=False) + + def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold): + + self._get_granularity_recursively(module) + print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=False) + for granularity in self.granularity_info: + print_rank_0(granularity, force=False) + + if self.min_granularity_value <= zero_module_granularity_threshold: + self._set_leaf_by_threshold_preorder(module, zero_module_granularity_threshold) + utils.logger.info( + f"z3_leaf_module was set by stage3_module_granularity_threshold:{zero_module_granularity_threshold}") + for layer in self.z3_leaf_layers: + print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=False) + else: + utils.logger.warning( + f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\ + f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}. "\ + f"Current Value:{zero_module_granularity_threshold}" + ) + + def _get_granularity_recursively(self, module): + """This function is used to recursively obtain the granularity of each module.""" + + # avoid setting as leaf for particularly large models, even if the granularity is very small + # an oversized leaf module increases the number of live parameters, introducing memory overhead + Z3_MAX_LEAF_SIZE = 1e9 + + if not list(module.parameters()): + # skip Modules without parameters, such as GELU, etc. + module.ds_model_granularity = sys.maxsize + return 0, 0 + + num_layers = 0 + num_params = 0 + num_params += sum(p.ds_numel for p in module.parameters(recurse=False)) + if not any(module.children()): + # torch leaf module + module.ds_model_granularity = sys.maxsize + return 1, num_params + + for child in module.children(): + layers_in_child, params_in_child = self._get_granularity_recursively(child) + num_layers += layers_in_child + num_params += params_in_child + + if module.__class__.__name__ in torch.nn.modules.container.__all__: + # Do not set container modules like ModuleList as leaf modules + # as this will prevent hooks from being set on their children + # and they may do not invoke the forward method + module.ds_model_granularity = sys.maxsize + return num_layers, num_params + + num_layers += 1 + ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize + module.ds_model_granularity = ds_model_granularity + # module.ds_model_num_layers = num_layers + # module.ds_model_num_params = num_params + if self.min_granularity_value > ds_model_granularity: + self.min_granularity_value = ds_model_granularity + self.min_granularity_layer = module.__class__.__name__ + self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(20)}") + + return num_layers, num_params + + def _set_leaf_by_threshold_preorder(self, module, granularity_treshhold): + '''Set modules as leaf modules based on the threshold, prioritizing parent nodes.''' + + num_params = sum(p.ds_numel for p in module.parameters()) + if num_params == 0: + # skip Modules without parameters, such as GELU, etc. + return + if module.ds_model_granularity <= granularity_treshhold: + set_z3_leaf_module(module, True) + self.z3_leaf_layers.append(module) + return + + for sub_module in module.children(): + self._set_leaf_by_threshold_preorder(sub_module, granularity_treshhold) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 84e628ef487c..2b23c0b340ee 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -6,11 +6,13 @@ import math import os import types -from typing import Callable, Iterable +from typing import Callable, Iterable, Union from enum import Enum import functools import itertools from typing import List +from collections import defaultdict +import logging import torch from torch import Tensor from deepspeed import comm as dist @@ -19,10 +21,11 @@ from .linear import zero3_linear_wrap +from deepspeed.utils import groups import deepspeed -from ..utils import get_only_unique_item, see_memory_usage +from ..utils import see_memory_usage, get_only_unique_item from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.config_utils import get_config_default from deepspeed.utils import instrument_w_nvtx, logger @@ -31,10 +34,25 @@ debug_param2name_id, debug_param2name_id_shape_status) from deepspeed.accelerator import get_accelerator from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus +from deepspeed.inference.quantization.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict +from deepspeed.runtime.torch_autocast import sort_dtypes, get_comm_dtype, has_comm_dtype -param_count = 0 partitioned_param_data_shape = [0] -zero_init_enabled = False +zero_init_context = 0 +top_level_context = None + + +class DeepSpeedTensorOverride(Enum): + dtype = 1 + device = 2 + + +DEFAULT_TENSOR_OVERRIDES = [DeepSpeedTensorOverride.dtype, DeepSpeedTensorOverride.device] + + +def get_allgather_dtype(param, param_ds_tensor): + autocast = has_comm_dtype(param) + return get_comm_dtype(param) if autocast else param_ds_tensor.dtype class NoGatherHandle: @@ -43,12 +61,17 @@ def __init__(self, param: Parameter) -> None: if param.ds_status != ZeroParamStatus.INFLIGHT: raise RuntimeError(f"expected param {param.ds_summary()} to be available") - param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(), - non_blocking=True).view(param.ds_shape) + if hasattr(param.ds_tensor, "ds_quant_scale"): + param.data = Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale).to( + device=get_accelerator().current_device_name(), non_blocking=True).view(param.ds_shape) + else: + param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(), + non_blocking=True).view(param.ds_shape) self.__param = param - def wait(self) -> None: - get_accelerator().current_stream().synchronize() + def wait(self, **kwargs) -> None: + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().synchronize() self.__param.ds_status = ZeroParamStatus.AVAILABLE @@ -61,15 +84,20 @@ def __init__(self, params: List[Parameter]) -> None: for param in self.__params: if param.ds_status != ZeroParamStatus.INFLIGHT: raise RuntimeError(f"expected param {param.ds_summary()} to not be available") - param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(), - non_blocking=True).view(param.ds_shape) + if hasattr(param.ds_tensor, "ds_quant_scale"): + param.data = Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale).to( + device=get_accelerator().current_device_name(), non_blocking=True).view(param.ds_shape) + else: + param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(), + non_blocking=True).view(param.ds_shape) @instrument_w_nvtx - def wait(self) -> None: + def wait(self, **kwargs) -> None: if self.__complete: return - get_accelerator().current_stream().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().synchronize() for param in self.__params: assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" param.ds_status = ZeroParamStatus.AVAILABLE @@ -97,12 +125,6 @@ def debug_rank0(msg: str) -> None: logger.debug(msg) -def is_zero_param(parameter): - if not torch.is_tensor(parameter): - return False - return hasattr(parameter, 'ds_id') - - def _init_external_params(module): if not hasattr(module, '_external_params'): module._external_params = {} @@ -214,33 +236,43 @@ class ZeroParamStatus(Enum): INFLIGHT = 3 +_orig_torch_tensor = torch.tensor _orig_torch_empty = torch.empty _orig_torch_zeros = torch.zeros _orig_torch_ones = torch.ones _orig_torch_full = torch.full +_orig_torch_arange = torch.arange +_orig_torch_eye = torch.eye +_orig_torch_randn = torch.randn -def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable: +def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype, + target_device: torch.device) -> Callable: def wrapped_fn(*args, **kwargs) -> Tensor: - if kwargs.get("device", None) is None: - kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + if kwargs.get("device", None) is None and target_device is not None: + kwargs['device'] = target_device tensor: Tensor = fn(*args, **kwargs) - if tensor.is_floating_point(): - tensor = tensor.to(target_fp_dtype) + if target_fp_dtype is not None and tensor.is_floating_point(): + tensor.data = tensor.data.to(target_fp_dtype) return tensor return wrapped_fn -def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable: +def get_new_tensor_fn_for_dtype(target_fp_dtype: torch.dtype, target_device: torch.device) -> Callable: - def new_tensor(cls, *args) -> Tensor: - device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) - tensor = _orig_torch_empty(0, device=device).new_empty(*args) - if tensor.is_floating_point(): - tensor = tensor.to(dtype) + def new_tensor(cls, *args, **kwargs) -> Tensor: + if not args: + args = (0, ) + if target_device is None: + tensor = _orig_torch_empty(0).new_empty(*args, **kwargs) + else: + tensor = _orig_torch_empty(0, device=target_device).new_empty(*args, **kwargs) + + if tensor.is_floating_point() and target_fp_dtype is not None: + tensor = tensor.to(target_fp_dtype) return tensor @@ -248,7 +280,7 @@ def new_tensor(cls, *args) -> Tensor: # https://stackoverflow.com/a/63851681/9201239 -def get_all_subclasses(cls): +def get_all_subclasses(cls, include_root=True): subclass_list = [] def recurse(cl): @@ -258,17 +290,27 @@ def recurse(cl): recurse(cls) - return set(subclass_list) + ret = set(subclass_list) + if include_root: + ret.add(cls) + return ret @instrument_w_nvtx def free_param(param: Parameter) -> None: """Free underlying storage of a parameter.""" - assert not param.ds_active_sub_modules, param.ds_summary() + if param.ds_active_sub_modules: + raise RuntimeError("Cannot free a ZeRO-3 parameter while it is still active in submodules. " + "This can happen if: (1) submodules have not released the parameter, or " + "(2) you modified parameters inside a `GatheredParameters` context with " + "`modifier_rank=None`. For case (2), use `modifier_rank=` to broadcast " + "updates consistently across ranks. " + f"param={param.ds_summary()}") if get_accelerator().on_accelerator(param.data): # need to make sure that we don't free the parameter while it is still # being used for computation - param.data.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + param.data.record_stream(get_accelerator().current_stream()) # param.data doesn't store anything meaningful in partitioned state param.data = torch.empty(0, dtype=param.dtype, device=param.device) param.ds_status = ZeroParamStatus.NOT_AVAILABLE @@ -282,6 +324,8 @@ def free_param(param: Parameter) -> None: # Inserts _post_init_method at the end of init method # for all sub classes of torch.nn.Module class InsertPostInitMethodToModuleSubClasses(object): + num_module_parameters = 0 + num_module_elements = 0 def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None): self.mem_efficient_linear = mem_efficient_linear @@ -290,12 +334,68 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp assert self.dtype in [ torch.half, torch.bfloat16, torch.float ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" + self.wrapped_cls = set() + self.skip_init_depth = 0 + + self.quantized_initialization = None + if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization: + self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization def __enter__(self): - global zero_init_enabled if not self.enabled: return - zero_init_enabled = True + + global zero_init_context + if zero_init_context == 0: + self.patch_init_and_builtins() + global top_level_context + top_level_context = self + + zero_init_context += 1 + + def __exit__(self, exc_type, exc_value, traceback): + if not self.enabled: + return + + global zero_init_context + zero_init_context -= 1 + + # Exiting the top level context + if zero_init_context == 0: + self.unpatch_init_and_builtins() + global top_level_context + top_level_context = None + + if dist.get_rank() == 0: + billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9 + num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters + logger.info( + f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B") + + # Now that we cleaned up the metaclass injection, raise the exception. + if exc_type is not None: + return False + + # To be implemented by inheriting classes + def _post_init_method(self, module): + pass + + def _set_dtype(self, ds_config, dtype): + if ds_config is not None and dtype is None: + if ds_config.bfloat16_config.enabled and ds_config.float16_config.enabled: + raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") + + if ds_config.bfloat16_config.enabled: + self.dtype = torch.bfloat16 + elif ds_config.float16_config.enabled: + self.dtype = torch.half + else: + self.dtype = torch.float + else: + self.dtype = dtype or torch.float16 if get_accelerator().is_fp16_supported( + ) else torch.bfloat16 if get_accelerator().is_bf16_supported else torch.float32 + + def patch_init_and_builtins(self): def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: """many models make use of child modules like Linear or Embedding which @@ -328,14 +428,16 @@ def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None: 3. broadcasts root rank's parameters to the other ranks 4. re-partitions the parameters """ - if not all(is_zero_param(p) for p in module_to_apply_fn_to.parameters(recurse=False)): - raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " - f"were zero params, is it possible that the parameters were " - f"overwritten after they were initialized? " - f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ") + + # TODO Delay error checking for dangling partitioned parameters to post module init + # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " + # f"were zero params, is it possible that the parameters were " + # f"overwritten after they were initialized? " + # f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ") params_to_apply_fn_to: Iterable[Parameter] = list( - sorted(module_to_apply_fn_to.parameters(recurse=False), key=lambda p: p.ds_id)) + sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)], + key=lambda p: p.ds_id)) for param in params_to_apply_fn_to: param.all_gather() @@ -358,6 +460,53 @@ def wrapped_apply(module: Module, fn_to_apply: Callable) -> None: return wrapped_apply + def hook_for_skip_init(module): + # this function is intended for handling the logic of torch.nn.utils.skip_init + # skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta' + # the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device). + def partition_after_empty_init(f): + + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + _module = f(module, *args, **kwargs) + # here is the post-hook for module.apply(empty_like...) + # after module.apply(empty_like...), the module has completed its empty init on real device + # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init + self._post_init_method(_module) + return _module + + return wrapper + + def post_wrapper_to_empty(f): + # append some wrapper restoration after to_empty() call + @functools.wraps(f) + def wrapper(*args, **kwargs): + res = f(*args, **kwargs) + # restore _apply hook + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class_apply(subclass) + # self restore + module.to_empty = f + return res + + return wrapper + + def _enable_class_apply(cls): + if '_apply' in cls.__dict__: + cls._old_apply_of_skip_init_hook = cls._apply + cls._apply = partition_after_empty_init(cls._apply) + + def _disable_class_apply(cls): + if hasattr(cls, '_old_apply_of_skip_init_hook'): + cls._apply = cls._old_apply_of_skip_init_hook + + # add hooks for to_empty: apply_(empty_like) + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class_apply(subclass) + + # add a restore hook when exiting skip_init + module.to_empty = post_wrapper_to_empty(module.to_empty) + def partition_after(f): @functools.wraps(f) @@ -379,29 +528,40 @@ def wrapper(module, *args, **kwargs): is_child_module = True setattr(module, "_ds_child_entered", True) - f(module, *args, **kwargs) + init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta' + if init_on_meta: + self.skip_init_depth += 1 + f(module, *args, **kwargs) + if init_on_meta and self.skip_init_depth == 1: + # check and handle the logic of empty_init + hook_for_skip_init(module) if is_child_module: # child's __init__ is done, now we can run a single post_init on the child object delattr(module, "_ds_child_entered") print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False) - self._post_init_method(module) + if self.skip_init_depth == 0: + self._post_init_method(module) print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False) + if init_on_meta: + self.skip_init_depth -= 1 return wrapper def _enable_class(cls): - cls._old_init = cls.__init__ - cls.__init__ = partition_after(cls.__init__) + if '__init__' in cls.__dict__: + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): - cls.__init__ = partition_after(cls.__init__) + if '__init__' in cls.__dict__: + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module recursively for subclass in get_all_subclasses(torch.nn.modules.module.Module): - # print(f"subclass={subclass.__module__}.{subclass.__qualname__}") _enable_class(subclass) # holding onto some methods so we can put them back the way they were in __exit__ @@ -411,151 +571,306 @@ def _init_subclass(cls, **kwargs): # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) - torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) - torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype) - torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype) - torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) - torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) + if self.tensor_overrides: + self._add_tensor_creation_wrappers() if self.mem_efficient_linear: print_rank_0( "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.", force=False) - self.linear_bk = torch.nn.functional.linear - torch.nn.functional.linear = zero3_linear_wrap + if not hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"): + InsertPostInitMethodToModuleSubClasses.linear_bk = torch.nn.functional.linear + torch.nn.functional.linear = zero3_linear_wrap - def __exit__(self, exc_type, exc_value, traceback): - if not self.enabled: - return + if self.quantized_initialization: + print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False) + torch.nn.functional.linear = wrap_quantized_functional(torch.nn.functional.linear) + torch.nn.functional.embedding = wrap_quantized_functional(torch.nn.functional.embedding) + for cls in WEIGHT_QUANTIZATION_LAYERS: + cls._load_from_state_dict = wrap_load_from_state_dict(cls._load_from_state_dict) - shutdown_init_context() + logger.info("Enable Zero3 engine with INT4 quantization.") - if dist.get_rank() == 0: - logger.info("finished initializing model with %.2fB parameters", param_count / 1e9) + self.patched = True - # Now that we cleaned up the metaclass injection, raise the exception. - if exc_type is not None: - return False + def unpatch_init_and_builtins(self): + if self.patched: - # To be implemented by inheriting classes - def _post_init_method(self, module): - pass + def _disable_class(cls): + if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'): + cls.__init__ = cls._old_init - def _set_dtype(self, ds_config, dtype): - if ds_config is not None and dtype is None: - if ds_config.bfloat16_enabled and ds_config.fp16_enabled: - raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class(subclass) - if ds_config.bfloat16_enabled: - self.dtype = torch.bfloat16 - elif ds_config.fp16_enabled: - self.dtype = torch.half - else: - self.dtype = torch.float - else: - self.dtype = dtype or torch.half + # putting methods back the way we found them + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + if self.tensor_overrides: + self._remove_tensor_creation_wrappers() -def shutdown_init_context(): - global zero_init_enabled - - if not zero_init_enabled: - return + self.patched = False - def _disable_class(cls): - cls.__init__ = cls._old_init - - # Replace .__init__() for all existing subclasses of torch.nn.Module - for subclass in get_all_subclasses(torch.nn.modules.module.Module): - _disable_class(subclass) + def _add_tensor_creation_wrappers(self): + if DeepSpeedTensorOverride.dtype in self.tensor_overrides: + target_fp_dtype = self.dtype + else: + target_fp_dtype = None + if DeepSpeedTensorOverride.device in self.tensor_overrides: + target_device = self.local_device + else: + target_device = None + + torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, + target_fp_dtype=target_fp_dtype, + target_device=target_device) + + def _remove_tensor_creation_wrappers(self): + torch.Tensor.__new__ = torch.Tensor.__old_new__ + torch.tensor = _orig_torch_tensor + torch.empty = _orig_torch_empty + torch.zeros = _orig_torch_zeros + torch.ones = _orig_torch_ones + torch.full = _orig_torch_full + torch.arange = _orig_torch_arange + torch.eye = _orig_torch_eye + torch.randn = _orig_torch_randn - # putting methods back the way we found them - torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass - torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply - torch.Tensor.__new__ = torch.Tensor.__old_new__ - torch.empty = _orig_torch_empty - torch.zeros = _orig_torch_zeros - torch.ones = _orig_torch_ones - torch.full = _orig_torch_full +def shutdown_init_context(): + """ + This function is used to initialize deepspeed engine inside the context of Init. + We need to remove the wrappers but keep the context. + """ + if top_level_context: + top_level_context.unpatch_init_and_builtins() - # un doing it here will undo it during training - # if self.mem_efficient_linear: - # torch.nn.functional.linear = self.linear_bk - # if self.mem_efficient_linear: - # torch.nn.functional.linear = self.linear_bk - zero_init_enabled = False +def restore_init_context(): + """ + This function is used to restore the wrappers after deepspeed engine is initialized. + """ + if top_level_context: + top_level_context.patch_init_and_builtins() class AllGatherHandle: - def __init__(self, handle, param: Parameter) -> None: + def __init__(self, handle, param: Parameter, quantization=None, param_buffer=None, original_dtype=None) -> None: if param.ds_status != ZeroParamStatus.INFLIGHT: raise RuntimeError(f"expected param {param.ds_summary()} to be available") self.__handle = handle self.__param = param + self.__quantization = quantization + self.__param_buffer = param_buffer + self.__original_dtype = original_dtype - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: instrument_w_nvtx(self.__handle.wait)() + + if self.__param_buffer is not None: + self.__param.data = self.__param_buffer.narrow(0, 0, self.__param.ds_numel).view(self.__param.ds_shape).to( + self.__original_dtype).to(self.__param.device) + elif self.__quantization: + instrument_w_nvtx(self.__quantization.quant_handle.wait)() + self.__param.data = self.__quantization.backend.dequantize( + self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device) self.__param.ds_status = ZeroParamStatus.AVAILABLE class AllGatherCoalescedHandle: + data_buffer = [] + def __init__( self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int, + use_secondary_tensor=False, + quantization=None, ) -> None: - self.__allgather_handle = allgather_handle - self.__params = params - self.__partitions = partitions - self.__world_size = world_size - self.__complete = False - - for param in self.__params: + self.allgather_handle = allgather_handle + self.params = params + self.partitions = partitions + self.world_size = world_size + self.use_secondary_tensor = use_secondary_tensor + self.complete = False + self.quantization = quantization + + for param in self.params: if param.ds_status != ZeroParamStatus.INFLIGHT: raise RuntimeError(f"expected param {param.ds_summary()} to not be available") @instrument_w_nvtx - def wait(self) -> None: - if self.__complete: + def wait(self, handle_dependency=True) -> None: + if self.complete: return - instrument_w_nvtx(self.__allgather_handle.wait)() + instrument_w_nvtx(self.allgather_handle.wait)() + + if self.quantization: + instrument_w_nvtx(self.quantization.quant_handle.wait)() + flat_tensor = self.quantization.backend.dequantize( + self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device) + + self.partitions: List[Parameter] = [] + for i in range(self.world_size): + self.partitions.append( + flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz)) # split the single tensor out into individual tensors param_offset = 0 - for param in self.__params: + for param in self.params: assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" partitions: List[Tensor] = [] - for rank in range(self.__world_size): - param_start = rank * param.ds_tensor.ds_numel + ds_tensor_numel = param.ds_tensor.ds_numel + if self.use_secondary_tensor: + ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups + for rank in range(self.world_size): + param_start = rank * ds_tensor_numel if param_start < param.ds_numel: - part_to_copy = self.__partitions[rank].narrow( - 0, param_offset, min(param.ds_numel - param_start, param.ds_tensor.ds_numel)) + part_to_copy = self.partitions[rank].narrow(0, param_offset, + min(param.ds_numel - param_start, ds_tensor_numel)) partitions.append(part_to_copy) - param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) + # Note that dtypes of param and partitions can be different (currently for torch.autocast support) + param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape).to(param.ds_tensor.dtype) param.ds_status = ZeroParamStatus.AVAILABLE + if not get_accelerator().is_synchronized_device() and handle_dependency: + for part_to_copy in partitions: + part_to_copy.record_stream(get_accelerator().current_stream()) - for part_to_copy in partitions: - part_to_copy.record_stream(get_accelerator().current_stream()) + param_offset += ds_tensor_numel - param_offset += param.ds_tensor.ds_numel + self.complete = True + if not get_accelerator().is_synchronized_device() and not handle_dependency: + # if the device needs to handle dependencies and opts for explicit processing outside the function. + AllGatherCoalescedHandle.data_buffer.append(partitions) - self.__complete = True + @staticmethod + def free_buffer(): + AllGatherCoalescedHandle.data_buffer = [] + + +class MultipleAllGatherHandles: + + def __init__(self, handles: List[Union[AllGatherHandle, AllGatherCoalescedHandle]]): + self.handles = handles + + def wait(self, handle_dependency=True) -> None: + for handle in self.handles: + handle.wait(handle_dependency) + + +class AllReduceCoalescedHandle: + + def __init__(self, handle, params: List[Parameter]) -> None: + self.handle = handle + self.params = params + self.complete = False + + for param in self.params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self, **kwargs) -> None: + if self.complete: + return + + instrument_w_nvtx(self.handle.wait)() + + for param in self.params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.complete = True + + +class QuantizationInfo: + # a placeholder object to store all quant related vars used in handles + def __init__(self) -> None: + self.quantized_param = None + self.backend = None + self.quant_handle = None + self.scale_buffer = None + + +class CUDAQuantizer: + async_flag = True + target_group_size = 8000 # the optimal size is 4k, so we set the target to be below 8k + group_size_cache = dict() + quantizer_cuda_module = None + + def __init__(self) -> None: + if CUDAQuantizer.quantizer_cuda_module is None: + CUDAQuantizer.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load() + + def quantize(self, param, groups=None): + if groups is None: + try: + groups = self.group_size_cache[param.numel()] + except KeyError: + groups = math.ceil(param.numel() / self.target_group_size) + while groups < param.numel(): + if param.numel() % (8 * groups) == 0: + break + groups += 1 + while True: + if param.numel() % (8 * groups * 2) == 0 and param.numel( + ) / groups > self.target_group_size: #hard limit of 16k group_size + groups *= 2 + else: + break + assert ( + param.numel() % (8 * groups) == 0 + ), f"Qantized weight requires the number of weights be a multiple of 8. Yet {param.numel()} cannot be divided by 8*{groups}" + assert (param.numel() / groups < 16000), f"{param.numel()} / {groups} is larger than 16k" + assert param.numel( + ) > groups, f"Adaptive grouping algorithm cannot find a group size for input tensor of size {param.numel()}" + self.group_size_cache[param.numel()] = groups + return self.quantizer_cuda_module.quantize(param.to(get_accelerator().device_name()), groups, 8, + self.quantizer_cuda_module.Symmetric) + + def dequantize(self, quantized_param, scale): + return self.quantizer_cuda_module.dequantize(quantized_param, scale, scale.numel(), 8, + self.quantizer_cuda_module.Symmetric) def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle: for param in params: if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError(param.ds_summary()) + raise RuntimeError(f"expect param.ds_status == ZeroParamStatus.NOT_AVAILABLE, got{param.ds_summary()}") param.ds_status = ZeroParamStatus.INFLIGHT params = sorted(params, key=lambda p: p.ds_id) @@ -573,6 +888,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): num_persisted_parameters = 0 num_persisted_elements = 0 apply_param_persistence = False + override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply") def __init__(self, module=None, @@ -584,7 +900,13 @@ def __init__(self, config=None, enabled=True, dtype=None, - mpu=None): + mpu=None, + zero_param_parallel_group=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + sequence_data_parallel_group=None, + param_swapper=None, + tensor_overrides=DEFAULT_TENSOR_OVERRIDES): """A context to enable massive model construction for training with ZeRO-3. Models are automatically partitioned (or, sharded) across the system and converted to half precision. @@ -594,6 +916,8 @@ def __init__(self, if it was constructed in the context. data_parallel_group (``deepspeed.comm`` process group, optional): The group of processes to partition among. Defaults to all processes. + Synonymous with sequence data parallel group for param partitioning + across both sequence and data parallel groups. mem_efficient_linear (bool, optional): Replace torch.nn.functional.linear with an implementation that allows DeepSpeed to partition parameters. Defaults to ``True``. @@ -606,13 +930,19 @@ def __init__(self, using pinned memory for model weights. ``remote_device`` must be ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``. config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration - for swapping fp16 params to NVMe. + for swapping fp16 params to NVMe and other things like ``dtype``. config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead. enabled (bool, optional): If ``False``, this context has no effect. Defaults to ``True``. dtype (``dtype``, optional): Can be used to change the data type of the parameters. Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None`` mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}. + zero_param_parallel_group(``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params. + zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights. Default is ``False`` + zero_quantized_nontrainable_weights (bool, optional): If ``True``, nontrainable weights will be stored in quantized format. Default is ``False`` + param_swapper (``deepspeed.runtime.swap_tensor.partitioned_param_swapper.AsyncPartitionedParameterSwapper``, optional): [Experimental] Use existing parameter swapper. Defaults to ``None``. + This argument will be removed in the near future. + tensor_overrides ([`deepspeed.runtime.zero.DeepSpeedTensorOverride`], optional): Tensor attributes to override. Defaults to overriding dtype and device. This context accelerates model initialization and enables models that are too large to allocate in their entirety in CPU memory. It has the @@ -644,15 +974,6 @@ def __init__(self, Initializes ``deepspeed.comm`` if it has not already been done so. See :meth:`deepspeed.init_distributed` for more information. - .. note:: - Can also be used as a decorator: - - .. code-block:: python - - @deepspeed.zero.Init() - def get_model(): - return MyLargeModel() - .. note:: Only applicable to training with ZeRO-3. @@ -685,31 +1006,77 @@ def get_model(): """ if config is not None: config_dict_or_path = config - logger.warning( - f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.') + logger.warning('zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.') _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) if config_dict_or_path is not None else None if _ds_config is not None: mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + + self.tensor_overrides = tensor_overrides super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) if not dist.is_initialized(): init_distributed() assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" + if data_parallel_group is None: self.ds_process_group = dist.get_world_group() else: self.ds_process_group = data_parallel_group + if sequence_data_parallel_group is not None: + logger.warning( + "sequence_data_parallel_group' is deprecated and will be removed. Use 'data_parallel_group' instead.") + if data_parallel_group is not None: + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + self.ds_process_group = sequence_data_parallel_group + self.rank = dist.get_rank(group=self.ds_process_group) - self.world_size = dist.get_world_size(group=self.ds_process_group) + self.dp_world_size = dist.get_world_size(group=self.ds_process_group) + + self.zero_param_process_group = zero_param_parallel_group + if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None: + groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size) + self.zero_param_process_group = groups._get_zero_param_intra_parallel_group() + + self.num_ranks_in_param_group = self.dp_world_size + self.rank_in_group = self.rank + self.num_param_groups = 1 + + if self.zero_param_process_group is not None: + self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size() + self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group) + self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup() + print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=False) + + logger.debug( + "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} " + .format(self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks())) # Local device is the device where the parameters are consumed, must be default device. # It is the device where parameters are fully instantiated using allgather self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) get_accelerator().set_device(self.local_device) + self.quantized_weights = zero_quantized_weights + if _ds_config is not None and _ds_config.zero_config.zero_quantized_weights and not self.quantized_weights: + self.quantized_weights = _ds_config.zero_config.zero_quantized_weights + self.quantized_nontrainable_weights = zero_quantized_nontrainable_weights + if _ds_config is not None and _ds_config.zero_config.zero_quantized_nontrainable_weights and not self.quantized_nontrainable_weights: + self.quantized_nontrainable_weights = _ds_config.zero_config.zero_quantized_nontrainable_weights + + self.enable_sanity_checks = get_config_default(DeepSpeedZeroConfig, "enable_sanity_checks") if _ds_config is not None: - self._update_persist_config(_ds_config) + self.enable_sanity_checks = _ds_config.zero_config.enable_sanity_checks + + self.module = module + if (self.quantized_weights or self.quantized_nontrainable_weights): + self.quantizer_module = CUDAQuantizer() + print_rank_0(f'Using quantizer for weights: {self.quantizer_module.__class__.__name__}', force=False) + + if _ds_config is not None: + Init.override_module_apply = _ds_config.zero_config.override_module_apply if _ds_config.zero_config.offload_param is not None: remote_device = _ds_config.zero_config.offload_param.device @@ -725,7 +1092,7 @@ def get_model(): # Enable fp16 param swapping to NVMe if self.remote_device == OffloadDeviceEnum.nvme: - self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype) + self.param_swapper = param_swapper or AsyncPartitionedParameterSwapper(_ds_config, self.dtype) else: self.param_swapper = None @@ -738,17 +1105,34 @@ def get_model(): if not self.use_all_gather_into_tensor: logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}") + self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig, + "use_all_reduce_for_fetch_params") + self.allgather_sequential = get_config_default(DeepSpeedZeroConfig, "allgather_sequential") + if _ds_config is not None: + self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params + self.allgather_sequential = _ds_config.zero_config.allgather_sequential + def _update_persist_config(self, ds_config): Init.apply_param_persistence = True Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold - Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.world_size + Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions + + def _zero_init_param(self, param): + self._convert_to_deepspeed_param(param) + if dist.get_world_group() == self.get_dp_process_group(): + dist.broadcast(param.data, 0, self.get_dp_process_group()) + else: + dist.broadcast(param.data, dist.get_global_rank(self.get_dp_process_group(), 0), + self.get_dp_process_group()) + param.partition() def _convert_to_zero_parameters(self, param_list): for param in param_list: if is_zero_param(param): continue - self._convert_to_deepspeed_param(param) - param.partition() + + param.data = param.data.to(self.local_device) + self._zero_init_param(param) def _validate_remote_device(self, remote_device, ds_config): if ds_config is not None: @@ -766,28 +1150,27 @@ def _validate_remote_device(self, remote_device, ds_config): f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}' def _post_init_method(self, module): - #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False) + #see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False) print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False) - see_memory_usage(f"Before converting and partitioning parmas in {module.__class__.__name__}", force=False) + see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False) - global param_count for name, param in module.named_parameters(recurse=False): - param_count += param.numel() + print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False) + InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1 + InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel() if not is_zero_param(param): - self._convert_to_deepspeed_param(param) + if not get_accelerator().on_accelerator(param): + param.data = param.data.to(self.local_device) + + if name == 'weight' and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS: + _quantize_param(param, self.quantized_initialization) + + self._zero_init_param(param) print_rank_0( f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}") - if get_accelerator().on_accelerator(param): - dist.broadcast(param, 0, self.ds_process_group) - else: - if dist.get_rank() == 0: - logger.warn(f"param `{name}` in {module.__class__.__name__} " - f"not on GPU so was not broadcasted from rank 0") - - param.partition() see_memory_usage( - f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}", + f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}", force=False) def _convert_to_deepspeed_param(self, param): @@ -823,6 +1206,15 @@ def _convert_to_deepspeed_param(self, param): # The group that the parameter is scattered across. param.ds_process_group = self.ds_process_group + param.ds_enable_sanity_checks = self.enable_sanity_checks + + # Stores the secondary partitioned copy of the tensor + param.ds_secondary_tensor = None + + #Process group for secondary partition all (group) gather + param.ds_zero_param_process_group = self.zero_param_process_group + param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group + param.ds_secondary_tensor_num_of_groups = self.num_param_groups # This is set to the Async Param swapper if remote device is nvme # else this is set to None @@ -838,13 +1230,223 @@ def all_gather(param_list=None, async_op=False, hierarchy=0): param_list = [cls] return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy) + def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allgather_dtype): + # make sure all params have the same dtype + dtype = params[0].dtype # we assume len(params) > 0 + assert all(p.dtype == dtype for p in params), "all params must have the same dtype" + + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + use_secondary_tensor = params[0].ds_secondary_tensor is not None + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=allgather_dtype, + device=get_accelerator().current_device_name(), + requires_grad=False) + + partitions: List[Parameter] = [] + for i in range(world_size): + partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) + + if use_secondary_tensor: + instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype) + for p in params + ], + out=partitions[rank_in_group]) + else: + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype) for p in params], + out=partitions[rank_in_group]) + handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group) + #Fix get_partition_dp_group(params[0])) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + ) + + def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_group, quantize): + handles = [] + for param in params: + buffer_size = math.ceil(param.ds_numel / world_size) * world_size + if use_secondary_tensor: + buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized + + param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor + + original_dtype = param_ds_tensor.dtype + if quantize: + allgather_dtype = torch.int8 + else: + allgather_dtype = get_allgather_dtype(param, param_ds_tensor) + + param_buffer = torch.empty( + buffer_size, + dtype=allgather_dtype, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + if not quantize: + handle = _dist_allgather_fn( + param_ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype), + param_buffer, + ds_process_group, + ) + + if original_dtype == allgather_dtype: + param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) + handles.append(AllGatherHandle(handle, param)) + else: + # This case is complicated: + # We use `register_post_accumulate_grad_hook` to set allgather hooks. Normally, the hook is + # called once per parameter, even if that parameter is tied to multiple layers. + # However, when the dtype changes, the hook may be triggered multiple times. + # If we directly do: + # param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) + # as above, the dtype may differ, causing the gradient-reduce hook + # to be invoked multiple times. + # To avoid this, we leave `param.data` in a partitioned state. + # This prevents duplicate gradient-reduce hook calls. + # In theory, this path could be consolidated with the case where + # (original_dtype == allgather_dtype), but because it changes the + # state transition of DeepSpeed parameters, we keep it separate for safety. + handles.append( + AllGatherHandle(handle, param, param_buffer=param_buffer, original_dtype=original_dtype)) + else: + if hasattr(param_ds_tensor, "ds_quant_scale"): + scales = param_ds_tensor.ds_quant_scale + quantized_param = param_ds_tensor.data + else: + quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor) + handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()), + param_buffer, ds_process_group) + + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=scales.dtype, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()), + quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to( + param.device) + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + handles.append(AllGatherHandle(handle, param, quantization=quant_info)) + return MultipleAllGatherHandles(handles) + + def _all_gather_coalesced(params, world_size, rank_in_group, use_secondary_tensor, ds_process_group, quantize): + if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor: + + # Use all_reduce instead of all_gather to fetch the module params + flat_buffer_size = sum(p.ds_numel_aligned for p in params) + flat_tensor = torch.zeros(flat_buffer_size, + dtype=get_only_unique_item(p.ds_tensor.dtype for p in params), + device=get_accelerator().current_device_name(), + requires_grad=False) + start_param = 0 + for param in params: + param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) + start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() + flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) + + start_param += param.ds_numel + + handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True) + + return AllReduceCoalescedHandle(handle=handle, params=params) + else: + if not quantize: + dtype_params = defaultdict(list) + for p in params: + allgather_dtype = get_allgather_dtype(p, p.ds_tensor) + dtype_params[allgather_dtype].append(p) + handles = [] + for dtype in sort_dtypes(dtype_params.keys()): + handles.append( + _all_gather_dtype(dtype_params[dtype], world_size, rank_in_group, ds_process_group, dtype)) + + return MultipleAllGatherHandles(handles) + + else: + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False) + + if use_secondary_tensor: + if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params + ]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params + ])) + else: + if hasattr(params[0].ds_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=torch.float32, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) + quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = flat_tensor + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + quant_info.partition_sz = partition_sz + quant_info.world_size = world_size + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=None, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + quantization=quant_info, + ) + @instrument_w_nvtx - def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -> AllGatherCoalescedHandle: + def all_gather_coalesced(params: Iterable[Parameter], + safe_mode: bool = False, + quantize: bool = False) -> AllGatherCoalescedHandle: # fetches from nvme if the partition is not available and in nvme self._ensure_availability_of_partitioned_params(params) - if self.world_size == 1: + if self.num_partitions == 1: return _no_gather_coalesced(params) for param in params: @@ -852,6 +1454,17 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) - raise RuntimeError(param.ds_summary()) param.ds_status = ZeroParamStatus.INFLIGHT + #use appropriate all gather process group + ds_process_group = self.ds_process_group + rank_in_group = self.rank + world_size = self.dp_world_size + use_secondary_tensor = params[0].ds_secondary_tensor is not None + if self.zero_param_process_group and use_secondary_tensor: + ds_process_group = self.zero_param_process_group #intragroup + rank_in_group = self.rank_in_group + world_size = self.num_ranks_in_param_group + + #pprint(dir(ds_process_group)) # ensure that each rank has params in same order. the allgather # is done by flattening the parameter list into a single tensor that # can be allgathered in a single call - this means that if each rank @@ -860,7 +1473,8 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) - # to debug correctness issues. params = sorted(params, key=lambda p: p.ds_id) - debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") if safe_mode: # ensure that same list (with same ordering) of parameters are @@ -871,46 +1485,19 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) - # otherwise could mix data between tensors. assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) - if len(params) == 1: - # have an opportunity to avoid some intermediate memory allocations - param, = params - param_buffer = torch.empty( - math.ceil(param.ds_numel / self.world_size) * self.world_size, - dtype=param.dtype, - device=get_accelerator().current_device_name(), - requires_grad=False, - ) - handle = _dist_allgather_fn(param.ds_tensor.to(get_accelerator().current_device_name()), param_buffer, - self.ds_process_group) - param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) - return AllGatherHandle(handle, param) + if self.allgather_sequential or len(params) == 1: + return _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_group, quantize) else: - partition_sz = sum(p.ds_tensor.ds_numel for p in params) - flat_tensor = torch.empty(partition_sz * self.world_size, - dtype=get_only_unique_item(p.dtype for p in params), - device=get_accelerator().current_device_name(), - requires_grad=False) - partitions: List[Parameter] = [] - for i in range(self.world_size): - partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) - - instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], - out=partitions[self.rank]) - handle = _dist_allgather_fn(partitions[self.rank], flat_tensor, self.ds_process_group) - - return AllGatherCoalescedHandle( - allgather_handle=handle, - params=params, - partitions=partitions, - world_size=self.world_size, - ) + return _all_gather_coalesced(params, world_size, rank_in_group, use_secondary_tensor, ds_process_group, + quantize) - def partition(param_list=None, hierarchy=0, has_been_updated=False): + def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True): cls = param - print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}") + print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}", + force=False) if param_list is None: param_list = [cls] - self._partition(param_list, has_been_updated=has_been_updated) + self._partition(param_list, has_been_updated=has_been_updated, free_data=True) def reduce_gradients_at_owner(param_list=None, hierarchy=0): cls = param @@ -957,6 +1544,7 @@ def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict: "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None, "persist": slf.ds_persist, "active_sub_modules": slf.ds_active_sub_modules, + "ds_tensor.shape": slf.ds_tensor.shape if slf.ds_tensor is not None else None } def convert_to_zero_parameters(param_list): @@ -993,8 +1581,8 @@ def _aligned_size(self, param): return param.ds_numel + self._padding_size(param) def _padding_size(self, param): - remainder = param.ds_numel % self.world_size - return (self.world_size - remainder) if remainder else 0 + remainder = param.ds_numel % self.num_partitions + return (self.num_partitions - remainder) if remainder else 0 def _partition_numel(self, param): return param.ds_tensor.ds_numel @@ -1030,37 +1618,46 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): handles.append(handle) else: all_gather_list.append(param) - + # note: param_list may contain params that are already in flight / aviailable. So we need to use all_gather_list if not async_op: - if len(param_list) == 1: - ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + if self.allgather_sequential or len(all_gather_list) == 1: + ret_value = self._allgather_params_sequential(all_gather_list, hierarchy=hierarchy) else: - ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy) - + all_gather_quantize_list = [] + all_gather_nonquantize_list = [] + for param in all_gather_list: + if hasattr(param.ds_tensor, + "ds_quant_scale") or (hasattr(param, "ds_secondary_tensor") + and hasattr(param.ds_secondary_tensor, "ds_quant_scale")): + all_gather_quantize_list.append(param) + else: + all_gather_nonquantize_list.append(param) + # _allgather_params_coalesced always return None + self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False) + self._allgather_params_coalesced(all_gather_quantize_list, hierarchy, quantize=True) for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE - return ret_value + return None return handles - def _partition(self, param_list, force=False, has_been_updated=False): + def _partition(self, param_list, force=False, has_been_updated=False, free_data=True): for param in param_list: - #print_rank_0(f"Before Partitioning Param {param.ds_id}") - # self._param_status(param) - self._partition_param(param, has_been_updated=has_been_updated) + print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False) + if self.zero_param_process_group is not None: + self._partition_param_sec(param, has_been_updated=has_been_updated) + self._partition_param(param, has_been_updated=has_been_updated, free_data=True) + param.ds_status = ZeroParamStatus.NOT_AVAILABLE # if param.ds_tensor is not None: # assert id(param.data) == id(param.ds_tensor.data), \ # "After the parameters are initially partitioned, make sure we are not recreating the partition." - #print_rank_0(f"After Partitioning Param {param.ds_id}") - # self._param_status(param) - + #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) @instrument_w_nvtx - def _partition_param(self, param, buffer=None, has_been_updated=False): + def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" - global reuse_buffers - #print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}") + print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) if param.ds_status is ZeroParamStatus.AVAILABLE: print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False) # if reuse_buffers and False: @@ -1074,24 +1671,29 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): # if deepspeed.comm.get_rank(): # print(f"Releasing {param.data.numel()}") - if param.ds_tensor is not None and not has_been_updated: + if param.ds_tensor is not None and not has_been_updated: ##param already partitioned + + #print_rank_0(f"Param {param.ds_id} pri {param.ds_tensor.size()} loc? {param.ds_tensor.final_location}", force=False) #param.data = param.ds_tensor.data see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - free_param(param) + if free_data: + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False) param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0( + f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme", + force=False) return tensor_size = self._aligned_size(param) - partition_size = tensor_size // self.world_size - + partition_size = tensor_size // self.num_partitions if param.ds_tensor is None: final_location = None if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor( @@ -1111,6 +1713,10 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): device = self.remote_device partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=device) + # quantize the tensor if it's not trainable + if not param.requires_grad and self.quantized_nontrainable_weights: + partitioned_tensor, partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize( + partitioned_tensor) if device == OffloadDeviceEnum.cpu and self.pin_memory: partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor) @@ -1120,8 +1726,9 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): param.ds_tensor.ds_numel = partition_size param.ds_tensor.status = PartitionedParamStatus.AVAILABLE param.ds_tensor.final_location = final_location + param.ds_numel_aligned = tensor_size - start = partition_size * self.rank + start = partition_size * self.get_partition_rank() end = start + partition_size one_dim_param = param.contiguous().view(-1) @@ -1129,7 +1736,11 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): if start < param.ds_numel and end <= param.ds_numel: src_tensor = one_dim_param.narrow(0, start, partition_size) - param.ds_tensor.copy_(src_tensor) + with torch.no_grad(): + # make sure param.ds_tensor requires_grad always be false, + # otherwise, torch tracer will complain. + param.ds_tensor.copy_(src_tensor) + #partitioned_tensor = src_tensor.clone().detach().to(self.remote_device) else: @@ -1138,9 +1749,12 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): # device=self.remote_device ) if start < param.ds_numel: - elements_to_copy = param.ds_numel - start - param.ds_tensor.narrow(0, 0, - elements_to_copy).copy_(one_dim_param.narrow(0, start, elements_to_copy)) + elems_to_copy = param.ds_numel - start + with torch.no_grad(): + # make sure param.ds_tensor requires_grad always be false, + # otherwise, torch tracer will complain. + param.ds_tensor.narrow(0, 0, + elems_to_copy).copy_(one_dim_param.narrow(0, start, elems_to_copy)) #print(f"Remote device {self.remote_device}") @@ -1161,6 +1775,61 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}") + @instrument_w_nvtx + def _partition_param_sec(self, param, buffer=None, has_been_updated=False): + assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" + global reuse_buffers + ##support for NVME secondary param offload + #print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=False) + if param.ds_status is ZeroParamStatus.AVAILABLE: + if param.ds_secondary_tensor is not None and not has_been_updated: ##param already partitioned + return + #check padding + tensor_size = self._aligned_size(param) + partition_size = tensor_size // self.dp_world_size + + secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group) + if param.ds_secondary_tensor is None: + final_location = None + secondary_partitioned_tensor = torch.empty(secondary_partition_size, + dtype=param.dtype, + device=self.remote_device) + + if self.pin_memory: + secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory() + # quantize the tensor if it's not trainable + if not param.requires_grad and self.quantized_nontrainable_weights: + secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize( + secondary_partitioned_tensor) + secondary_partitioned_tensor.requires_grad = False + param.ds_secondary_tensor = secondary_partitioned_tensor + param.ds_secondary_tensor.ds_numel = secondary_partition_size + param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE + param.ds_secondary_tensor.final_location = final_location + + #use rank in group for secondary tensor + secondary_start = secondary_partition_size * self.rank_in_group + + secondary_end = secondary_start + secondary_partition_size + + one_dim_param = param.contiguous().view(-1) + + # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size + sec_numel = max(0, min(param.ds_numel - secondary_start, secondary_partition_size)) + + # copy from full tensor to secondary tensor + with torch.no_grad(): + # make sure param.ds_secondary_tensor requires_grad always be false + param.ds_secondary_tensor.narrow(0, 0, + sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel)) + + # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().synchronize() + + print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}", + force=False) + def _param_status(self, param): if param.ds_tensor is not None: print_rank_0( @@ -1175,7 +1844,7 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): partition_size = param.ds_tensor.ds_numel - tensor_size = partition_size * self.world_size + tensor_size = partition_size * self.num_partitions aligned_param_size = self._aligned_size(param) assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}' @@ -1191,7 +1860,8 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ', force=False) - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() print_rank_0( f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}" @@ -1203,33 +1873,35 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): # param.data = replicated_tensor.data # return None if self.use_all_gather_into_tensor: - # try the all_gather_into_tensor on PyTorch master branch handle = dist.all_gather_into_tensor(flat_tensor, param.ds_tensor.to(get_accelerator().device_name()), - group=self.ds_process_group, + group=self.get_partition_dp_group(param), async_op=async_op) else: partitions = [] - for i in range(self.world_size): + for i in range(self.num_partitions): partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) - if i == dist.get_rank(group=self.ds_process_group): + if i == dist.get_rank(group=self.get_partition_dp_group(param)): partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) - handle = dist.all_gather(partitions, partitions[self.rank], group=self.ds_process_group, async_op=async_op) + handle = dist.all_gather(partitions, + partitions[self.get_partition_rank()], + group=self.get_partition_dp_group(param), + async_op=async_op) replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) param.data = replicated_tensor.data return handle - def _allgather_params_coalesced(self, param_list, hierarchy=0): + def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): """ blocking call avoid explicit memory copy in _allgather_params """ if len(param_list) == 0: return - if self.world_size == 1: + if self.num_partitions == 1: handle = _no_gather_coalesced(param_list) handle.wait() return None @@ -1237,34 +1909,55 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): # collect local tensors and partition sizes partition_sizes = [] local_tensors = [] + if quantize: + quantize_scale_sizes = [] + quantize_scale_tensors = [] for param in param_list: partition_sizes.append(param.ds_tensor.ds_numel) local_tensors.append(param.ds_tensor.to(get_accelerator().device_name())) - + if quantize: + quantize_scale_sizes.append(param.ds_tensor.ds_quant_scale.numel()) + quantize_scale_tensors.append(param.ds_tensor.ds_quant_scale.to(get_accelerator().device_name())) # allocate memory for allgather params allgather_params = [] + if quantize: + allgather_quantize_scale = [] for psize in partition_sizes: - tensor_size = psize * self.world_size - flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device).view(-1) + tensor_size = psize * self.num_partitions + flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, + device=self.local_device).view(-1) flat_tensor.requires_grad = False allgather_params.append(flat_tensor) + if quantize: + for psize in quantize_scale_sizes: + tensor_size = psize * self.num_partitions + flat_tensor = torch.empty(tensor_size, + dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, + device=self.local_device).view(-1) + flat_tensor.requires_grad = False + allgather_quantize_scale.append(flat_tensor) # launch launch_handles = [] - # backend = get_backend(self.ds_process_group) - # with _batch_p2p_manager(backend): + launch_quantize_handles = [] for param_idx, param in enumerate(param_list): input_tensor = local_tensors[param_idx].view(-1) if self.use_all_gather_into_tensor: - # try the all_gather_into_tensor from Pytorch master + # try the _all_gather_base from Pytorch master h = dist.all_gather_into_tensor(allgather_params[param_idx], input_tensor, - group=self.ds_process_group, + group=self.get_partition_dp_group(param), async_op=True) + if quantize: + quantize_handle = dist.all_gather_into_tensor(allgather_quantize_scale[param_idx], + quantize_scale_tensors[param_idx], + group=self.get_partition_dp_group(param), + async_op=True) + launch_quantize_handles.append(quantize_handle) else: output_list = [] - for i in range(self.world_size): + for i in range(self.num_partitions): psize = partition_sizes[param_idx] partition = allgather_params[param_idx].narrow(0, i * psize, psize) output_list.append(partition) @@ -1272,71 +1965,95 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): logger.warning( f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}') - # back to old all_gather function signature - h = dist.all_gather(output_list, input_tensor, group=self.ds_process_group, async_op=True) + # back to old all_gather function + h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True) + if quantize: + output_scale_list = [] + for i in range(self.num_partitions): + psize = quantize_scale_sizes[param_idx] + partition = allgather_quantize_scale[param_idx].narrow(0, i * psize, psize) + output_scale_list.append(partition) + quant_handle = dist.all_gather(output_scale_list, + quantize_scale_tensors[param_idx], + group=self.get_partition_dp_group(param), + async_op=True) + launch_quantize_handles.append(quant_handle) launch_handles.append(h) # Wait ensures the operation is enqueued, but not necessarily complete. launch_handles[-1].wait() + if quantize: + for quant_handle in launch_quantize_handles: + quant_handle.wait() # assign to param.data (not copy) for i, param in enumerate(param_list): gathered_tensor = allgather_params[i] + if quantize: + gathered_tensor = self.quantizer_module.dequantize(gathered_tensor, allgather_quantize_scale[i]) param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data # guarantee the communication to be completed - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() return None - def _allgather_params(self, param_list, hierarchy=0): + def _allgather_params_sequential(self, param_list, hierarchy=0): if len(param_list) == 0: return - partition_size = sum([param.ds_tensor.ds_numel for param in param_list]) - - tensor_size = partition_size * self.world_size - flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device) - flat_tensor.requires_grad = False - partitions = [] - for i in range(self.world_size): - start = partition_size * i - - partitions.append(flat_tensor.narrow(0, start, partition_size)) - - if i == self.rank: - offset = 0 - for param in param_list: - param_numel = param.ds_tensor.ds_numel - - partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data) - - offset += param_numel - - dist.all_gather(partitions, partitions[self.rank], group=self.ds_process_group, async_op=False) - param_offset = 0 - for param in param_list: - param_partition_size = param.ds_tensor.ds_numel - param_size = param.ds_numel - replicated_tensor = torch.empty(param.ds_shape, dtype=param.dtype, device=self.local_device) - - for i in range(self.world_size): - - start = i * partition_size - - param_start = i * param_partition_size - - if param_start < param_size: - numel_to_copy = min(param_size - param_start, param_partition_size) - - part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy) + partition_size = param.ds_tensor.ds_numel + tensor_size = partition_size * self.num_partitions - replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy) - #param_offset += param.data.numel() - param_offset += param.ds_tensor.ds_numel + flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device) + flat_tensor.requires_grad = False + if self.use_all_gather_into_tensor: + dist.all_gather_into_tensor(flat_tensor, + param.ds_tensor.to(get_accelerator().device_name()), + group=self.get_partition_dp_group(param), + async_op=False) + else: + partitions = [] + for i in range(self.num_partitions): + partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) + if i == self.get_partition_rank(): + partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) + dist.all_gather(partitions, + partitions[self.get_partition_rank()], + group=self.get_partition_dp_group(param), + async_op=False) + + if hasattr(param.ds_tensor, 'ds_quant_scale'): + scale_size = param.ds_tensor.ds_quant_scale.numel() + scale_tensor_size = scale_size * self.num_partitions + flat_scale_tensor = torch.empty(scale_tensor_size, + dtype=param.ds_tensor.ds_quant_scale.dtype, + device=self.local_device) + flat_scale_tensor.requires_grad = False + if self.use_all_gather_into_tensor: + dist.all_gather_into_tensor(flat_scale_tensor, + param.ds_tensor.ds_quant_scale.to(get_accelerator().device_name()), + group=self.get_partition_dp_group(param), + async_op=False) + else: + scale_partitions = [] + for i in range(self.num_partitions): + scale_partitions.append(flat_scale_tensor.narrow(0, scale_size * i, scale_size)) + if i == self.get_partition_rank(): + scale_partitions[i].data.copy_(param.ds_tensor.ds_quant_scale.data, non_blocking=True) + dist.all_gather(scale_partitions, + scale_partitions[self.get_partition_rank()], + group=self.get_partition_dp_group(param), + async_op=False) + flat_tensor = self.quantizer_module.dequantize(flat_tensor, flat_scale_tensor) + + param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) - param.data = replicated_tensor.data + # guarantee the communication to be completed + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() return None @@ -1359,10 +2076,10 @@ def _reduce_scatter_gradients(self, param_list): # For these ranks the output of reduce scatter is a separate buffer and needs # to be copied in partition_size = param.ds_tensor.ds_numel - start = self.rank * partition_size + start = self.get_partition_rank() * partition_size end = start + partition_size - #print_rank_0("REduce scatter was executed for praam {param.ds_id}") - if start < param.ds_numel and end > param.ds_numel: + #print_rank_0("REduce scatter was executed for param {param.ds_id}") + if start < param.ds_numel < end: elements = param.ds_numel - start param.grad.view(-1).narrow(0, start, elements).copy_(reduced_partition.narrow(0, 0, elements)) @@ -1371,10 +2088,10 @@ def _reduce_scatter_gradient(self, param): partition_size = param.ds_tensor.ds_numel #output = torch.empty(partition_size, dtype=param.dtype, device=param.device) - total_size = partition_size * self.world_size + total_size = partition_size * self.num_partitions input_list = [] - for i in range(self.world_size): + for i in range(self.num_partitions): start = i * partition_size end = start + partition_size @@ -1391,8 +2108,11 @@ def _reduce_scatter_gradient(self, param): #print("after reduce scatter gradients") input_list.append(input) - rank = dist.get_rank(group=self.ds_process_group) - handle = dist.reduce_scatter(input_list[rank], input_list, group=self.ds_process_group, async_op=True) + rank = dist.get_rank(group=self.get_partition_dp_group(param)) + handle = dist.reduce_scatter(input_list[rank], + input_list, + group=self.get_partition_dp_group(param), + async_op=True) return handle, input_list[rank] @@ -1404,6 +2124,7 @@ def _partition_gradients(self, param_list, partition_buffers=None, accumulate=Fa self._partition_gradient(param, partition_buffer=partition_buffer, accumulate=accumulate) def _partition_gradient(self, param, partition_buffer=None, accumulate=False): + #import pdb;pdb.set_trace() # param.grad=None # param.grad.test() @@ -1420,7 +2141,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): assert partition_buffer.numel( ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}" - rank = dist.get_rank(group=self.ds_process_group) + rank = dist.get_rank(group=self.get_partition_dp_group(param)) start = partition_size * rank end = start + partition_size @@ -1464,6 +2185,22 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): param.grad.data = dest_tensor_full_buffer.data see_memory_usage("After partitioning gradients", force=False) + def get_partition_dp_group(self, param): + return param.ds_process_group + + def get_partition_rank(self): + """subclass can overload to specify different relative rank in + parameter partition group""" + return self.rank + + @property + def num_partitions(self): + return self.dp_world_size + + def get_dp_process_group(self): + """ Return the communication group with all data-parallel ranks """ + return self.ds_process_group + class GatheredParameters: @@ -1554,6 +2291,7 @@ def load(module: nn.Module, prefix=""): """ self.enabled = enabled + self._param_versions = None if not enabled: return @@ -1564,13 +2302,16 @@ def load(module: nn.Module, prefix=""): else: # single param params = [params] - # enable if at least one is zero-param, otherwise a noop if not any(is_zero_param(p) for p in params): self.enabled = False return self.params = [p for p in params if hasattr(p, "ds_id")] + self.params = sorted( + set(self.params), key=lambda x: x.ds_id + ) # remove the duplicates to prevent racing condition, we must also make sure the order is the same on all ranks otherwise we'll get deadlocks + self.enable_sanity_checks = getattr(self.params[0], "ds_enable_sanity_checks", False) self.src_rank = None if modifier_rank is not None: if self.params[0].ds_process_group == dist.get_world_group(): @@ -1588,15 +2329,48 @@ def __enter__(self): if not self.enabled: return self.params[0].all_gather(param_list=self.params) + if self.src_rank is None and self.enable_sanity_checks: + self._param_versions = [(p, p.data.data_ptr(), p._version) for p in self.params] def __exit__(self, *exc): if not self.enabled: return if self.src_rank is None: + if self._param_versions: + modified_params = [ + p for p, data_ptr, version in self._param_versions + if p.data.data_ptr() != data_ptr or p._version != version + ] + modified_local = bool(modified_params) + modified_global = modified_local + if dist.is_initialized(): + modified_flag = torch.tensor( + int(modified_local), + device=get_accelerator().current_device_name(), + ) + dist.all_reduce(modified_flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group) + modified_global = bool(modified_flag.item()) + if modified_global: + self.params[0].partition(param_list=self.params, has_been_updated=False) + raise RuntimeError( + "Detected in-place modification of ZeRO-3 parameters inside GatheredParameters with " + "modifier_rank=None. Use modifier_rank= to broadcast updates across ranks.") self.params[0].partition(param_list=self.params, has_been_updated=False) return - handles = [dist.broadcast(p, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] + # Broadcast parameters from modifier_rank to all other ranks. + # NCCL backend requires tensors to be on GPU. If parameters have been moved to a different + # device (e.g., CPU) inside the context, broadcasting will fail. Users should use + # modifier_rank=None if they don't need to broadcast updates across ranks. + expected_device = torch.device(get_accelerator().current_device_name()) + for p in self.params: + if p.data.device != expected_device: + raise RuntimeError( + f"Parameter {p.ds_id} is on {p.data.device} but broadcast requires it to be on {expected_device}. " + f"When using GatheredParameters with modifier_rank set, parameters must remain on " + f"the accelerator device. If you don't need to broadcast updates, use modifier_rank=None.") + + handles = [dist.broadcast(p.data, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] for h in handles: h.wait() self.params[0].partition(param_list=self.params, has_been_updated=True) diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 949c54f5e806..4877b44c8934 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -6,15 +6,24 @@ from dataclasses import dataclass import collections from collections import UserDict +import threading from typing import Deque, Set from deepspeed import comm as dist +from deepspeed.utils import z3_leaf_module from deepspeed.utils.logging import logger from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partitioned_param_profiler import PartitionedParameterProfiler from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus -from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id +from deepspeed.utils.debug import debug_param2name_id_shape from deepspeed.accelerator import get_accelerator +import deepspeed.runtime.compiler as compiler +from deepspeed.runtime.compiler import is_compiling + +import logging + +ENABLE_PROFILER = False def debug_rank0(message: str) -> None: @@ -27,6 +36,7 @@ def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) +@compiler.enable(min_version="2.7.0") def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: return map(lambda pair: pair[1], get_all_parameters(module, recurse)) @@ -40,18 +50,27 @@ class ZeRoTraceMode(Enum): INVALID = 3 -class PartitionedParameterCoordinator: - """Handles partitioning and gathering of parameters.""" +class InflightParamRegistry(UserDict): + """registry for parameters in flight""" - class __InflightParamRegistry(UserDict): - """registry for parameters in flight""" + def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}") + self.data[param] = handle - def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None: - if param in self.data: - raise RuntimeError(f"{param.ds_summary()} already in registry") - if param.ds_status != ZeroParamStatus.INFLIGHT: - raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}") - self.data[param] = handle + +class PartitionedParameterCoordinator: + FORWARD_FETCH_SUBMIT = 'forward_fetch_submit' + FORWARD_FETCH_WAIT = 'forward_fetch_wait' + FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit' + BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit' + BACKWARD_FETCH_WAIT = 'backward_fetch_wait' + BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_submit' + FORWARD_ALL_GATHER = 'forward_all_gather' + BACKWARD_ALL_GATHER = 'backward_all_gather' + """Handles partitioning and gathering of parameters.""" @dataclass class __ParamInTrace: @@ -64,14 +83,20 @@ def __init__( max_reuse_distance_in_numel: int, max_available_parameters_in_numel: int, allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, prefetch_nvme: bool = False, + timers=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + fast_sharding_for_leaf_module=False, + log_trace_cache_warnings=False, ) -> None: # mapping of param -> handle for each param that is currently in flight - self.__inflight_param_registry = __class__.__InflightParamRegistry() + self.__inflight_param_registry = inflight_param_registry # keeps track of the number of submodules invoked so far. self.__step_id: int = 0 # network tracing mode - self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD + self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID # sequence of submodules/parameters in forward pass + backward pass self.__submodule_order: Iterable[Module] = [] self.__param_order: Iterable[__class__.__ParamInTrace] = [] @@ -88,6 +113,8 @@ def __init__( self.__prefetch_bucket_sz: int = prefetch_bucket_sz self.__prefetch_nvme: bool = prefetch_nvme self.hierarchy: int = 0 + self.zero_quantized_weights = zero_quantized_weights + self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights # stream that will be used for allgather operations self.__allgather_stream: get_accelerator().Stream = allgather_stream @@ -104,6 +131,21 @@ def __init__( self.__ongoing_fetch_events: Deque[get_accelerator().Event] = collections.deque() # TODO. make this configurable via JSON self.__max_ongoing_fetch_events: int = 2 + self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None) + + # Whether to log trace cache warnings, e.g. invalidation events + self.__log_trace_cache_warnings = log_trace_cache_warnings + + # whether to enable fast fetch for the z3 leaf module. + # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure. + self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module + + # Thread synchronization for leaf module fetches during backward pass. + # When autograd executes hooks in multiple threads (e.g., for modules returning multiple tensors), + # we need to ensure only one thread fetches parameters for a given leaf module at a time. + # This is only needed during backward pass; forward pass is single-threaded. + self.__ongoing_fetch_leaf_module_events = collections.defaultdict(threading.Event) + self.__leaf_module_lock = threading.Lock() """Tracing and Tracking TODO. consider performing trace before initializing PartitionedParameterCoordinator @@ -129,46 +171,59 @@ def is_invalid_trace(self) -> bool: def is_record_trace(self) -> bool: return self.__trace_mode == ZeRoTraceMode.RECORD + def _clean_inflight_param_registry(self) -> None: + for param, handle in self.__inflight_param_registry.items(): + handle.wait() + self.__release_param(param) + self.__inflight_param_registry.clear() + def _invalidate_trace(self) -> None: if self.is_invalid_trace(): raise RuntimeError("attempted to invalidate already invalid trace") self.__trace_mode = ZeRoTraceMode.INVALID self._clear_trace_structures() + self._clean_inflight_param_registry() def trace_prologue(self, sub_module: Module) -> None: if self.is_complete_trace(): # sub_module must match expectation else invalidate trace cache if len(self.__submodule_order) <= self.__step_id: print_rank_0( - f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: " + f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: " f"cache has only {len(self.__submodule_order)} modules", - force=True) + force=self.__log_trace_cache_warnings) self._invalidate_trace() return if sub_module != self.__submodule_order[self.__step_id]: - expected_module_id = self.__submodule_order[self.__step_id].id + expected_module_id = self.__submodule_order[self.__step_id].ds_id print_rank_0( f"Invalidate trace cache @ step {self.__step_id}: " - f"expected module {expected_module_id}, but got module {sub_module.id}", - force=True) + f"expected module {expected_module_id}, but got module {sub_module.ds_id}", + force=self.__log_trace_cache_warnings) self._invalidate_trace() + @compiler.enable(min_version="2.7.0") def record_module(self, sub_module: Module) -> None: """adds sub module to trace""" + if is_compiling(): + return + if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") self.__submodule_order.append(sub_module) - self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id) + self.__step_id_module_fetched_for[sub_module.ds_id].append(self.__step_id) def record_parameters(self, sub_module: Module) -> None: + if is_compiling(): + return """adds sub module to trace""" if not self.is_record_trace(): raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}") - step_id = self.__step_id_module_fetched_for[sub_module.id].popleft() - for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): + step_id = self.__step_id_module_fetched_for[sub_module.ds_id].popleft() + for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id): self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id)) def construct_parameter_trace_from_module_trace(self): @@ -177,15 +232,17 @@ def construct_parameter_trace_from_module_trace(self): for sub_module in self.__submodule_order: self.record_parameters(sub_module) + @compiler.disable def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" - if self.__inflight_param_registry: - raise RuntimeError(f"still have inflight params " - f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") + if is_compiling(): + return + + self._clean_inflight_param_registry() if not self.is_complete_trace(): # not self.trace_complete: # Make sure that recorded submodule orders are identical across ranks - assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + assert_ints_same_as_other_ranks([m.ds_id for m in self.__submodule_order]) if self.is_record_trace(): # Successfully recorded a trace @@ -198,75 +255,143 @@ def reset_step(self) -> None: self.__param_order = tuple(self.__param_order) # freeze self.__trace_mode = ZeRoTraceMode.COMPLETE print_rank_0( - f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}", + f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.ds_id for m in self.__submodule_order]}", force=False) else: # Enable trace recording for next forward/backward pass self.__trace_mode = ZeRoTraceMode.RECORD + else: + if self.__profiler is not None: + self.__profiler.log_events() + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) self.__step_id = 0 self.__n_available_params = 0 + self.__profiler.reset_events() + # Clear leaf module fetch events for clean state + self.__ongoing_fetch_leaf_module_events.clear() def _dump_params(self, tag, sub_module, params, step_id=None): if step_id is None: step_id = self.__step_id - param_names = [debug_param2name_id(p) for p in params] - print(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}') + param_names = [debug_param2name_id_shape(p) for p in params] + print_rank_0(f'{tag} step = {step_id} p_names = {param_names}', force=False) def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None): if step_id is None: step_id = self.__step_id - print(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}') + print_rank_0(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}', force=False) """Fetch and Release Fetching, prefetching, and releasing parameters """ + @compiler.disable @instrument_w_nvtx @torch.no_grad() - def fetch_sub_module(self, current_submodule: Module) -> None: + def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: """This method does the following (in order): 1. kick off fetch for parameters in immediately required sub module 2. kick off fetch for next few parameters we will need later (prefetch) 3. block on parameters in immediately required sub module """ - debug_rank0( - f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " - + str({ - "avail": f"{self.__n_available_params:.1e}", - "queue_sz": f"{len(self.__param_queue or [])}", - "inflight": [p.ds_id for p in self.__inflight_param_registry], - })) - - params_to_fetch = frozenset(iter_params(current_submodule)) - - # kick off all gather for params in the immediately required submodule - for param in params_to_fetch: - debug_rank0(f"-fetch: {param.ds_summary()}") - self.__all_gather_params(params_to_fetch) + # For leaf modules during backward pass, autograd may trigger hooks from multiple + # threads concurrently (e.g., when a module returns multiple tensors). We need to + # serialize access to prevent race conditions in parameter state management. + # Forward pass is single-threaded, so no synchronization is needed there. + is_leaf = z3_leaf_module(current_submodule) + needs_sync = is_leaf and not forward + if needs_sync: + event_to_wait = None + with self.__leaf_module_lock: + event = self.__ongoing_fetch_leaf_module_events.get(current_submodule.ds_id) + if event is not None: + # Another thread is already fetching this leaf module, wait for it + event_to_wait = event + else: + # Mark that we're starting a fetch for this leaf module + new_event = threading.Event() + self.__ongoing_fetch_leaf_module_events[current_submodule.ds_id] = new_event + + if event_to_wait is not None: + # Wait outside the lock to avoid deadlock + event_to_wait.wait() + return + try: + self._fetch_sub_module_impl(current_submodule, forward, is_leaf) + finally: + if needs_sync: + # Signal that we're done fetching this leaf module and remove the event + with self.__leaf_module_lock: + event = self.__ongoing_fetch_leaf_module_events.pop(current_submodule.ds_id, None) + if event is not None: + event.set() + + def _fetch_sub_module_impl(self, current_submodule: Module, forward: bool, is_leaf: bool) -> None: + """Implementation of fetch_sub_module, separated for thread synchronization.""" + if logger.isEnabledFor(logging.DEBUG): + debug_rank0( + f"{self.__step_id}: M{current_submodule.ds_id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=is_leaf)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + })) + + params_to_fetch = set(iter_params(current_submodule, recurse=is_leaf)) + fetch_numel = sum( + [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + + if fetch_numel > 0: + event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT + self._dump_param_ids(event_name, current_submodule.ds_id, + [(p.ds_id, p.ds_shape) + for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + # self._dump_params(event_name, current_submodule, [p for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]) + + self.__profiler.start_event(event_name) + # kick off all gather for params in the immediately required submodule + #for param in params_to_fetch: + if logger.isEnabledFor(logging.DEBUG): + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch, forward) + self.__profiler.stop_event(event_name, fetch_numel) + + wait_numel = 0 + wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT + self.__profiler.start_event(wait_event_name) + fast_fetch = self.fast_sharding_for_leaf_module and is_leaf # wait for parameters in the immediately needed submodule to become available for param in params_to_fetch: - param.ds_active_sub_modules.add(current_submodule.id) - debug_rank0(f"-wait: {param.ds_summary()}") + param.ds_active_sub_modules.add(current_submodule.ds_id) + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-wait: {param.ds_summary()}") if param in self.__inflight_param_registry: + wait_numel += param.partition_numel() with get_accelerator().stream(self.__allgather_stream): while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query(): self.__ongoing_fetch_events.popleft() if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: self.__ongoing_fetch_events.popleft().synchronize() - self.__inflight_param_registry.pop(param).wait() + self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch) - event = get_accelerator().Event() - event.record() - self.__ongoing_fetch_events.append(event) + if not get_accelerator().handles_memory_backpressure() and not fast_fetch: + event = get_accelerator().Event() + event.record() + self.__ongoing_fetch_events.append(event) assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - get_accelerator().current_stream().wait_stream(self.__allgather_stream) + if not get_accelerator().resolves_data_dependency(): + get_accelerator().current_stream().wait_stream(self.__allgather_stream) + if fast_fetch: + AllGatherCoalescedHandle.free_buffer() + self.__profiler.stop_event(wait_event_name, wait_numel) # kick off parameter prefetches for upcoming modules # don't prefetch if we dont have a completed model trace @@ -287,7 +412,7 @@ def fetch_sub_module(self, current_submodule: Module) -> None: if discarded_from_prefetch_queue != params_not_already_fetched: raise RuntimeError( f"tracing error at step {self.__step_id}: \n" - f"module id: {current_submodule.id}, training: {current_submodule.training}\n" + f"module id: {current_submodule.ds_id}, training: {current_submodule.training}\n" f"expected the next {len(params_not_already_fetched)} parameters in the " f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n" f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.") @@ -326,9 +451,14 @@ def _is_currently_on_nvme(param): params_to_prefetch.add(param_in_trace.param) numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") - self.__all_gather_params(params_to_prefetch) + if numel_prefetching > 0: + event_name = __class__.FORWARD_PREFETCH_SUBMIT if forward else __class__.BACKWARD_PREFETCH_SUBMIT + self.__profiler.start_event(event_name) + if logger.isEnabledFor(logging.DEBUG): + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch, forward) + self.__profiler.stop_event(event_name, numel_prefetching) if self.__prefetch_nvme: self.__prefetch_nvme_param_partitions() @@ -337,15 +467,26 @@ def _is_currently_on_nvme(param): @instrument_w_nvtx @torch.no_grad() - def release_sub_module(self, submodule: Module) -> None: + def release_sub_module(self, submodule: Module, forward=False) -> None: """release the parameters of a sub module, assuming they meet conditions to be released.""" + #print_rank_0(f"release_sub_module {'fwd' if forward else 'bwd'}: {debug_module2name_id(submodule)}", force=False) params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set( - p.ds_id for p in iter_params(submodule))) - for param in iter_params(submodule): - param.ds_active_sub_modules.discard(submodule.id) + p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule)))) + + free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module + if not free_data: + # wait for the computation to finish and launch as early as possible. + empty_buffer = torch.empty(1, device=torch.device(get_accelerator().current_device_name())) + + for param in iter_params(submodule, recurse=z3_leaf_module(submodule)): + param.ds_active_sub_modules.discard(submodule.ds_id) if param.ds_id in params_to_release and not param.is_external_param: - self.__release_param(param) + self.__release_param(param, free_data) + if not free_data: + if param.ds_id in params_to_release and not param.is_external_param: + # empty buffer ensures that all computations are complete + param.data = empty_buffer @instrument_w_nvtx @torch.no_grad() @@ -353,7 +494,7 @@ def release_and_reset_all(self, module: Module) -> None: """release all module parameters""" for param in iter_params(module, recurse=True): if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") + self.__inflight_param_registry.pop(param).wait() # TODO. make this throw if if there are still active submodules. currently # there's a hook execution issue @@ -365,22 +506,51 @@ def release_and_reset_all(self, module: Module) -> None: raise RuntimeError(f"{param.ds_summary()} expected to be released") @instrument_w_nvtx - def __all_gather_params(self, params: Set[Parameter]) -> None: + def __all_gather_params(self, params: Set[Parameter], forward: bool) -> None: + quantized_params = [] + nonquantized_params = [] + for param in params: + if hasattr(param.ds_tensor, 'ds_quant_scale'): + quantized_params.append(param) + else: + nonquantized_params.append(param) + if quantized_params: + self.__all_gather_params_(quantized_params, forward, quantize=True) + if nonquantized_params: + self.__all_gather_params_(nonquantized_params, forward, quantize=self.zero_quantized_weights) + + def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: bool = False) -> None: """for each partitioned parameter, kick off an async allgather and store the work handle for the in flight parameters.""" partitioned_params = [] + all_gather_numel = 0 # numel = num of elements for param in params: if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: partitioned_params.append(param) - self.__n_available_params += param.ds_numel + all_gather_numel += param.ds_numel if partitioned_params: - with get_accelerator().stream(self.__allgather_stream): - handle = partitioned_params[0].all_gather_coalesced(partitioned_params) - - for param in partitioned_params: - assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() - self.__inflight_param_registry[param] = handle + self.__n_available_params += all_gather_numel + # here we need to handle a special case where some of the parameters have a valid hpz secondary tensor (e.g. they are not trainable so their secondary tensor never expire) but others do not. + partitioned_params_with_secondary_tensors = [ + p for p in partitioned_params if p.ds_secondary_tensor is not None + ] + partitioned_params_without_secondary_tensors = [ + p for p in partitioned_params if p.ds_secondary_tensor is None + ] + for param_group in [ + partitioned_params_with_secondary_tensors, partitioned_params_without_secondary_tensors + ]: + if not param_group: + continue + with get_accelerator().stream(self.__allgather_stream): + event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER + self.__profiler.start_event(event_name) + handle = param_group[0].all_gather_coalesced(param_group, quantize=quantize) + self.__profiler.stop_event(event_name, all_gather_numel) + for param in param_group: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle # Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU swap_persisted_params = [ @@ -389,11 +559,14 @@ def __all_gather_params(self, params: Set[Parameter]) -> None: if swap_persisted_params: swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params) + @compiler.disable @instrument_w_nvtx - def __release_param(self, param: Parameter) -> None: + def __release_param(self, param: Parameter, free_data: bool = True) -> None: if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: - debug_rank0(f"-release: {param.ds_summary()}") - param.partition() + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-release: {param.ds_summary()}") + print_rank_0(f"release: {debug_param2name_id_shape(param)}", force=False) + param.partition(free_data=free_data) self.__n_available_params -= param.ds_numel @instrument_w_nvtx @@ -402,7 +575,9 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set if not self.is_complete_trace(): raise RuntimeError("expected trace to be complete") - params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) if not p.ds_persist) + params_to_release = set( + p.ds_id for p in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)) + if not p.ds_persist) # Problem: When prefetcher scans the param trace, it skips AVAILABLE params. # This creates issues if those params are released before the skipped uses: @@ -411,7 +586,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set # diverges from the trace. # Solution: Don't release params whose reuse was skipped by prefetch. This is # possible because we detect such skips during prefetch and mark those params. - for param in iter_params(submodule_to_release): + for param in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)): if self.__most_recent_step_id_param_fetched_for[param] > step_id: params_to_release.discard(param.ds_id) @@ -422,7 +597,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set for module in self.__submodule_order[step_id:]: if params_traversed >= self.__max_reuse_dist_in_numel: break - for param in iter_params(module): + for param in iter_params(module, recurse=z3_leaf_module(submodule_to_release)): params_to_release.discard(param.ds_id) params_traversed += param.ds_numel diff --git a/deepspeed/runtime/zero/partitioned_param_profiler.py b/deepspeed/runtime/zero/partitioned_param_profiler.py new file mode 100644 index 000000000000..b4ea11f3b836 --- /dev/null +++ b/deepspeed/runtime/zero/partitioned_param_profiler.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass +from deepspeed.utils import log_dist + + +class PartitionedParameterProfiler(object): + + @dataclass + class EventCounter: + name: str + count: int + num_elem: int + + def reset(self): + self.count = 0 + self.num_elem = 0 + + def increment(self, numel): + self.count += 1 + self.num_elem += numel + + def __init__(self, timers): + self.timers = timers + self.event_counters = {} + + def reset_events(self): + for event_ctr in self.event_counters.values(): + event_ctr.reset() + + def start_event(self, name): + if self.timers is None: + return + + if name not in self.event_counters: + self.event_counters[name] = __class__.EventCounter(name=name, count=0, num_elem=0) + self.timers(name).start() + + def stop_event(self, name, num_elem): + if self.timers is None: + return + assert name in self.event_counters, f'unknown event {name}' + self.event_counters[name].increment(num_elem) + self.timers(name).stop() + + def _log_timers(self): + if self.timers is None: + return + self.timers.log(names=list(self.event_counters.keys())) + + def _log_event_counters(self): + for event_ctr in self.event_counters.values(): + log_dist( + f'{event_ctr.name}: count = {event_ctr.count}, numel = {event_ctr.num_elem}', + #f'{event_ctr.name}: time = {self._log_timers()},count = {event_ctr.count}, numel = {event_ctr.num_elem}', + ranks=[0]) + + def log_events(self): + self._log_event_counters() + self._log_timers() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e3b6be65ed2b..aeab5d5734a9 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -6,29 +6,48 @@ import sys import gc import collections -from typing import Deque, Dict, Tuple +import itertools +from typing import Deque, Dict, Set, List, Container, Optional +from contextlib import contextmanager +from dataclasses import dataclass, field -from deepspeed.runtime import ZeROOptimizer +from deepspeed import comm as dist +from deepspeed.utils import groups, z3_leaf_parameter + +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook, required_torch_version from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler -from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced -from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter +from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum -from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3 +from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer, defragment +from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.runtime.zero.muon.original_muon import muon_update +from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam # Toggle this to true to enable correctness test # with gradient partitioning and without pg_correctness_test = False +OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state' +INIT_OPTIMIZER_TIMER = 'init_optimizer_state' +OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state' +OPTIMIZER_STEP_TIMER = 'optimizer_step' + def print_rank_0(message, debug=False, force=False): rank = dist.get_rank() @@ -50,7 +69,7 @@ def isclose(a, b, rtol=1e-09, atol=0.0): def lcm(x, y): - from fractions import gcd # or can import gcd from `math` in Python 3 + from math import gcd return x * y // gcd(x, y) @@ -59,6 +78,58 @@ def move_to_cpu(tensor_list): tensor.data = tensor.data.cpu() +@contextmanager +def unwrap_model_for_generation(model): + """ + For ZeRO-3 models, we gather the weights once to speed up generation. + """ + with GatheredParameters(model.parameters()): + # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model. + + # Remove hooks + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + yield model + + # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model. + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + optimizer_offload._register_deepspeed_module(optimizer_offload.module) + return + + +@dataclass +class IPGBucketZ3: + buffer: Optional[torch.Tensor] = None + buffer_meta: Optional[torch.Tensor] = None + params: List[torch.Tensor] = field(default_factory=list) + elements: int = 0 + + def clear(self): + self.buffer = None + self.buffer_meta = None + self.params.clear() + self.elements = 0 + + def clear_params(self): + """Clear params and elements but keep buffer for reuse.""" + self.params.clear() + self.elements = 0 + + INITIAL_MICRO_STEP_ID = -1 @@ -74,40 +145,59 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): """ - def __init__(self, - module, - init_optimizer, - timers, - ds_config, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True, - contiguous_gradients=True, - reduce_bucket_size=500000000, - prefetch_bucket_size=50000000, - max_reuse_distance=1000000000, - max_live_parameters=1000000000, - param_persistence_threshold=100000, - model_persistence_threshold=sys.maxsize, - dp_process_group=None, - reduce_scatter=True, - overlap_comm=False, - offload_optimizer_config=None, - offload_param_config=None, - sub_group_size=1000000000000, - mpu=None, - clip_grad=0.0, - communication_data_type=torch.float16, - postscale_gradients=True, - gradient_predivide_factor=1.0, - gradient_accumulation_steps=1, - elastic_checkpoint=False, - aio_config=None): - - see_memory_usage("Stage 3 initialize beginning", force=True) + def __init__( + self, + module, + init_optimizer, + param_names, + timers, + ds_config, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + model_persistence_threshold=sys.maxsize, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + zenflow_config=None, + sub_group_size=1000000000000, + offload_ratio=0.0, + mpu=None, + clip_grad=0.0, + gradient_accumulation_dtype=torch.float32, + communication_data_type=torch.float16, + fp16_master_weights_and_gradients=False, + bf16_master_weights_and_gradients=False, + bf16_optimizer_states=False, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None, + all2all_process_group=None, + zero_hpz_partition_size=1, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + zero_module_granularity_threshold=0, + zeropp_loco_param=None, + log_trace_cache_warnings=False, + enable_sanity_checks=False, + cpuadam_cores_perc=0.8, + save_muon_momentum_buffer_in_memory=False, + ): + see_memory_usage("Stage 3 initialize beginning", force=False) print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False) + super().__init__() if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") @@ -125,12 +215,13 @@ def __init__(self, raise SystemError("Cannot use fp16 without accelerator.") self.optimizer = init_optimizer + self.param_names = param_names - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten + # Use torch (un)flatten ops + self.flatten = _flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self.gradient_accumulation_dtype = gradient_accumulation_dtype self._global_grad_norm = 0. self.custom_loss_scaler = False @@ -146,52 +237,118 @@ def __init__(self, self.offload_param_pin_memory = False self.params_in_nvme_and_cpu = False self.max_params_in_cpu = 0 + self.partial_offload = offload_ratio + self.enable_sanity_checks = enable_sanity_checks + + self.create_zenflow_hooks() + self._initialize_zenflow_stage3_prologue(module, zenflow_config) - self.parameter_offload = DeepSpeedZeRoOffload(module=module, - timers=timers, - ds_config=ds_config, - overlap_comm=overlap_comm, - prefetch_bucket_size=prefetch_bucket_size, - max_reuse_distance=max_reuse_distance, - max_live_parameters=max_live_parameters, - param_persistence_threshold=param_persistence_threshold, - model_persistence_threshold=model_persistence_threshold, - offload_param_config=offload_optimizer_config, - mpu=mpu) + #num of ranks in a ZeRO param partitioning group + self.zero_hpz_partition_size = zero_hpz_partition_size + + zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + print_rank_0( + f"ZeRO Stage 3 param partitioning group {self.zero_hpz_partition_size} {zero_param_parallel_group}", + force=False) + if self.zero_hpz_partition_size > 1 and zero_param_parallel_group is None: + self._set_zero_group_parallelism() + zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + + self.parameter_offload = self.initialize_ds_offload( + module=module, + timers=timers, + ds_config=ds_config, + zenflow=self.zenflow, + overlap_comm=overlap_comm, + prefetch_bucket_size=prefetch_bucket_size, + max_reuse_distance=max_reuse_distance, + max_live_parameters=max_live_parameters, + param_persistence_threshold=param_persistence_threshold, + model_persistence_threshold=model_persistence_threshold, + dp_process_group=dp_process_group, + offload_param_config=offload_param_config, + mpu=mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=zero_quantized_weights, + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, + zero_module_granularity_threshold=zero_module_granularity_threshold, + log_trace_cache_warnings=log_trace_cache_warnings, + ) self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) + def _enforce_optimizer_offload(): + assert self.offload_optimizer and type(self.optimizer) == DeepSpeedCPUAdam, \ + "Master weights feature requires ZeRO-3 Offload with DeepSpeedCPUAdam. " \ + f"Current ZeRO-3 Offload:{self.offload_optimizer} optimizer type {type(self.optimizer)}." + + self.master_weights_and_grads_dtype = self._configure_master_weights( + fp16_master_weights_and_gradients=fp16_master_weights_and_gradients, + bf16_master_weights_and_gradients=bf16_master_weights_and_gradients, + bf16_optimizer_states=bf16_optimizer_states, + offload_enabled=self.offload_optimizer, + fp16_offload_validator=_enforce_optimizer_offload, + bf16_offload_validator=_enforce_optimizer_offload) + + # backup fused_adam optimizer init + if self.offload_optimizer and self.partial_offload != 1.0: + backup_gpu_tensor = torch.randn(1, device=get_accelerator().device_name()).to(self.dtype) + backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor) + assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam' + self.backup_optimizer = torch.optim.AdamW([backup_gpu_param], + lr=self.optimizer.param_groups[0]["lr"], + betas=self.optimizer.param_groups[0]["betas"], + eps=self.optimizer.param_groups[0]["eps"], + weight_decay=self.optimizer.param_groups[0]["weight_decay"], + amsgrad=self.optimizer.param_groups[0]["amsgrad"]) + # Multiple param_groups configs for back-up optimizer + if len(self.optimizer.param_groups) > 1: + for i in range(1, len(self.optimizer.param_groups)): + self.backup_optimizer.add_param_group(self.optimizer.param_groups[i]) + + self._initialize_zenflow_stage3_epilogue(zenflow_config, overlap_comm) + self.module = module self.elastic_checkpoint = elastic_checkpoint - self.__inf_or_nan_tracker: Tensor = torch.zeros(1, - dtype=torch.bool, - device=get_accelerator().current_device_name(), - requires_grad=False) + self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu + + self.inf_or_nan_tracker: Tensor = torch.zeros(1, dtype=torch.bool, device=self.device, requires_grad=False) self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam) - self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu ### streams used for overlapping computation with communication - self.__reduce_and_partition_stream = get_accelerator().Stream() if overlap_comm else get_accelerator( - ).default_stream() + self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator( + ).Stream() if overlap_comm else get_accelerator().default_stream() ############################################################################ - self.__n_caching_allocator_flushes = 0 + self.n_caching_allocator_flushes = 0 #-------------Stage 3 Setup-------------------# self.timers = timers + self.all2all_process_group = all2all_process_group + self.reduce_scatter = reduce_scatter + self.use_muon = isinstance(self.optimizer, MuonWithAuxAdam) + self.save_muon_momentum_buffer_in_memory = save_muon_momentum_buffer_in_memory + if self.use_muon and self.reduce_scatter: + raise ValueError("Muon and reduce scatter cannot be used together") + if self.use_muon and self.all2all_process_group is not None: + raise ValueError("Muon and all2all process group cannot be used together") + self.dp_process_group = self.parameter_offload.dp_process_group + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() - self.dp_process_group = dp_process_group + self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights self.partition_count = dist.get_world_size(group=self.dp_process_group) - if mpu is None: + self.zeropp_loco_param = zeropp_loco_param + + if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'): self.model_parallel_group = None self.model_parallel_rank = 0 else: @@ -207,6 +364,9 @@ def __init__(self, self.micro_step_id = 0 self.reduce_bucket_size = int(reduce_bucket_size) + if self.all2all_process_group is not None: + assert self.all2all_process_group is not None and self.reduce_scatter == True, "when enable all_to_all_reduce, reduce_scatter should also be enabled for data type checks." + if self.reduce_scatter: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"ZeRO-3 supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" @@ -224,6 +384,7 @@ def __init__(self, # Holds a fused and flattened copy of the parameters self.fp16_partitioned_groups_flat = [] self.fp16_partitioned_groups_flat_numel = [] + self.fp16_partitioned_groups_flat_id = [] #defragmented pinned memory self.param_groups_fp16_flat_cpu_memory = [] @@ -231,6 +392,8 @@ def __init__(self, #a single 32-bit partition of the parallel partitioned parameters #that this process will update self.fp32_partitioned_groups_flat = [] + if self.use_muon and self.save_muon_momentum_buffer_in_memory: + self.muon_momentum_buffer_partitioned_groups_flat = {} self.next_swappable_fp32_partitioned_groups = [] # number of elements per partition in each group @@ -252,53 +415,54 @@ def __init__(self, # Trainable parameters self.trainable_param_groups = self._get_trainable_parameter_groups() - see_memory_usage("Before creating fp16 partitions", force=True) + see_memory_usage("Before creating fp16 partitions", force=False) self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups) num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) - see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", force=True) + see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", force=False) # Optimizer tensor swapping if self.swap_optimizer: self._configure_tensor_swapping(offload_optimizer_config, aio_config) - self.__params_in_ipg_bucket: List[Parameter] = [] self.is_gradient_accumulation_boundary: bool = True - self.__param_reduce_events: Deque[get_accelerator().Event] = collections.deque() + self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque() # TODO. make this configurable via JSON - self.__max_param_reduce_events: int = 2 + self.max_param_reduce_events: int = 2 self.param_dict = {} # map between param_id and bool to specify if a param is in this partition self.is_param_in_current_partition = {} - self.extra_large_param_to_reduce = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] + self.torch_autocast_gradscaler = None + if is_autocast_initialized(): + comm_dtypes = get_all_comm_dtypes([p for params in self.fp16_groups for p in params]) + if get_autocast_dtype() == torch.float16: + self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name()) + else: + comm_dtypes = {self.communication_data_type} - self.params_already_reduced = [] - self.is_gradient_accumulation_boundary = True - self._release_ipg_buffers() - self.previous_reduced_grads = None + self.ipg_buckets: Dict[torch.dtype, IPGBucketZ3] = {dtype: IPGBucketZ3() for dtype in comm_dtypes} - # simplified param id - self.param_id = {} + self.params_already_reduced = {} + self.previous_reduced_grads = None - count = 0 - for i, params_group in enumerate(self.fp16_groups): + # model parameter traversal-based param id that's stable across runs + for params_group in self.fp16_groups: for param in params_group: - unique_id = id(param) - self.param_id[unique_id] = count - self.param_dict[count] = param - self.params_already_reduced.append(False) - count = count + 1 + param_id = self.get_param_id(param) + self.param_dict[param_id] = param + self.params_already_reduced[param_id] = False #Largest partitioned param - largest_partitioned_param_numel = max([ - max([max(tensor.numel(), tensor.ds_numel) for tensor in fp16_partitioned_group]) - for fp16_partitioned_group in self.fp16_partitioned_groups - ]) + largest_partitioned_param_numel = 0 + for fp16_partitioned_group in self.fp16_partitioned_groups: + if len(fp16_partitioned_group) > 0: + largest_partitioned_param_numel = max( + largest_partitioned_param_numel, + max([max(tensor.numel(), tensor.ds_numel) for tensor in fp16_partitioned_group])) + print_rank_0(f'Largest partitioned param numel = {largest_partitioned_param_numel}', force=False) self._setup_for_real_optimizer() @@ -307,7 +471,6 @@ def __init__(self, if self.offload_optimizer: self.norm_for_param_grads = {} - self.local_overflow = False # stores if a partition has been reduced in this step self.is_partition_reduced = {} @@ -319,6 +482,9 @@ def __init__(self, self.averaged_gradients = {} #creates backward hooks for gradient partitioning + ###Calls all gather param + self._grad_acc_hooks = [] + self._leaf_module_hooks = [] self.create_reduce_and_remove_grad_hooks() #exit(0) @@ -334,57 +500,146 @@ def __init__(self, self._link_all_hp_params() + self.offloaded_states: Set[OffloadDeviceEnum] = set() + if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After initializing ZeRO optimizer", force=True) + see_memory_usage("After initializing ZeRO optimizer", force=False) def destroy(self): self.parameter_offload.destroy() + for hook in self._grad_acc_hooks: + hook.remove() + for hook in self._leaf_module_hooks: + hook.remove() + print_rank_0("Removed grad acc hooks", force=False) + self.ipg_buckets.clear() + + def create_zenflow_hooks(self): + from functools import partial + hook_names = [ + "_initialize_zenflow_stage3_prologue", + "_initialize_zenflow_stage3_epilogue", + "zenflow_cpu_optimizer_step", + "_sync_selective_optimizer_lr", + "selective_optimizer_step", + "is_zenflow_select_boundary", + "update_selected_channels", + "_process_selected_fp32_groups_grad", + "zenflow_backward_prologue", + "zenflow_backward_epilogue", + "log_selective_optimizer_timers", + ] + + for name in hook_names: + fn = getattr(zf_engine_stage3, name) + setattr(self, name, partial(fn, self)) + + def initialize_ds_offload( + self, + module, + timers, + ds_config, + zenflow, + overlap_comm, + prefetch_bucket_size, + max_reuse_distance, + max_live_parameters, + param_persistence_threshold, + model_persistence_threshold, + dp_process_group, + offload_param_config, + mpu, + zero_param_parallel_group, + zero_quantized_weights, + zero_quantized_nontrainable_weights, + zero_module_granularity_threshold, + log_trace_cache_warnings, + ): + return DeepSpeedZeRoOffload(module=module, + timers=timers, + ds_config=ds_config, + zenflow=zenflow, + overlap_comm=overlap_comm, + prefetch_bucket_size=prefetch_bucket_size, + max_reuse_distance=max_reuse_distance, + max_live_parameters=max_live_parameters, + param_persistence_threshold=param_persistence_threshold, + model_persistence_threshold=model_persistence_threshold, + dp_process_group=dp_process_group, + offload_param_config=offload_param_config, + mpu=mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=zero_quantized_weights, + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, + zero_module_granularity_threshold=zero_module_granularity_threshold, + log_trace_cache_warnings=log_trace_cache_warnings) def _get_trainable_parameter_groups(self): param_groups = [] + PARAMS_KEY = "params" for param_group in self.optimizer.param_groups: - trainable_params = {"params": [p for p in param_group["params"] if p.requires_grad]} - param_groups.append(trainable_params) + trainable_params = [p for p in param_group[PARAMS_KEY] if p.requires_grad] + if len(trainable_params) == 0: + continue + + trainable_param_group = {} + for key in param_group.keys(): + if key == PARAMS_KEY: + trainable_param_group[PARAMS_KEY] = trainable_params + else: + trainable_param_group[key] = param_group[key] + param_groups.append(trainable_param_group) + return param_groups + def _set_zero_group_parallelism(self): + groups._create_zero_param_parallel_group(self.zero_hpz_partition_size) + + def invalidate_secondary_tensor(self): + for fpg in self.fp16_groups: + for param in fpg: + if param.ds_secondary_tensor is not None: + param.ds_secondary_tensor = None + def _setup_for_real_optimizer(self): - see_memory_usage("Before creating fp32 partitions", force=True) + see_memory_usage("Before creating fp32 partitions", force=False) self._create_fp32_partitions() - see_memory_usage("After creating fp32 partitions", force=True) + see_memory_usage("After creating fp32 partitions", force=False) dist.barrier() # To support pipelined optimizer swapping self._create_next_swappable_fp32_groups() - see_memory_usage("Before initializing optimizer states", force=True) + see_memory_usage("Before initializing optimizer states", force=False) self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=True) + see_memory_usage("After initializing optimizer states", force=False) dist.barrier() if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") + logger.info("optimizer state initialized") # IPG if self.contiguous_gradients: - self.__ipg_bucket_flat_buffer: Tensor = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=get_accelerator().current_device_name()) + for dtype, bucket in self.ipg_buckets.items(): + bucket.buffer = torch.empty(self.reduce_bucket_size, + dtype=dtype, + device=get_accelerator().current_device_name()) - grad_partitions_flat_buffer = None + self.grad_partitions_flat_buffer = None self.__param_id_to_grad_partition: Dict[int, Tensor] = {} all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), - dtype=self.dtype, - device=self.device) + self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), + dtype=self.gradient_accumulation_dtype, + device=self.device) if self.offload_optimizer_pin_memory: - grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer) + self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer) offset = 0 for param in all_params: - self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow( + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( 0, offset, param.partition_numel()) offset += param.partition_numel() @@ -401,46 +656,8 @@ def get_lr(self): """Return the current learning rate.""" return self.optimizer.param_groups[0]["lr"] - # TODO. factor out to a utility outside of stage3 - @staticmethod - def defragment(tensors: List[Tensor]) -> Tensor: - """move provided tensors into a contiguous flat buffer, with some additional - measures taken to reduce memory fragmentation""" - assert len(set(t.dtype for t in tensors)) == 1 - assert len(set(t.device for t in tensors)) == 1 - - cpu_buffer = torch.empty(sum(p.numel() for p in tensors), - dtype=get_only_unique_item(t.dtype for t in tensors), - device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = [] - orig_device = get_only_unique_item(t.device for t in tensors) - - offset = 0 - for tensor in tensors: - tensor_numel = tensor.numel() - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - - # record some data so we can restore the device tensor later - tensor_infos.append((tensor, offset, tensor_numel)) - - offset += tensor_numel - - gc.collect() - get_accelerator().empty_cache() - - # copy tensors (now flattened and contiguous) back to GPU - device_buffer = cpu_buffer.to(orig_device) - - # restore device tensors - for tensor, offset, tensor_numel in tensor_infos: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) - - return device_buffer - - def _get_param_coordinator(self, training): - return self.parameter_offload.get_param_coordinator(training) + def _get_param_coordinator(self): + return self.parameter_offload.get_param_coordinator() def _configure_offloading(self, offload_optimizer_config, offload_param_config): ###################### offload optimizer setup ################################## @@ -464,7 +681,7 @@ def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): nvme_swap_folder = os.path.join(offload_optimizer_config.nvme_path, 'zero_stage_3') os.makedirs(nvme_swap_folder, exist_ok=True) if dist.get_rank() == 0: - logger.info(f'Tensor Swapping: Adding optimizer tensors') + logger.info('Tensor Swapping: Adding optimizer tensors') swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config.pipeline else PartitionedOptimizerSwapper @@ -474,13 +691,9 @@ def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): optimizer=self.optimizer, largest_numel=max(self.fp16_partitioned_groups_flat_numel), device=self.device, - dtype=torch.float32, + dtype=self.master_weights_and_grads_dtype, timers=self.timers) - @property - def elements_in_ipg_bucket(self): - return sum(p.ds_numel for p in self.__params_in_ipg_bucket) - def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): '''If flat buffer is None then the parameters in the param_list are not copied to the flat buffer. This is because they exceed the number of max_params_in_cpu @@ -544,6 +757,21 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): param_groups: List[List[Parameter]] = tuple( self._create_fp16_sub_groups(param_group["params"]) for param_group in fp16_param_groups) + if self.use_muon: + self.sub_groups_using_muon = [] + self.muon_beta = None + self.muon_ns_method = None + for idx, param_group in enumerate(fp16_param_groups): + if getattr(param_group['params'][0], 'use_muon', False): + self.sub_groups_using_muon.extend([True] * len(param_groups[idx])) + group_beta = param_group['momentum'] + if self.muon_beta is not None and self.muon_beta != group_beta: + raise ValueError(f"All Muon parameter groups must have the same momentum (beta). " + f"Found {self.muon_beta} and {group_beta}.") + self.muon_beta = group_beta + self.muon_ns_method = param_group.get('ns_method', 'gram') + else: + self.sub_groups_using_muon.extend([False] * len(param_groups[idx])) # bookkeeping related to param groups for param_group_idx, param_group in enumerate(param_groups): for sub_group in param_group: @@ -553,12 +781,19 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): self.fp16_groups.append(sub_group) self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group]) + if self.zenflow: + for param in sub_group: + param.group_id = param_group_idx + # record sub group -> group mapping self.sub_group_to_group_id[sub_group_idx] = param_group_idx # record total elements of parameter partitions in sub group self.fp16_partitioned_groups_flat_numel.append(sum(p.partition_numel() for p in sub_group)) + # record ds_ids of parameter partitions in sub group + self.fp16_partitioned_groups_flat_id.append([p.ds_id for p in sub_group]) + # record padding required to align group to world size (only applies to last rank) rank_requires_padding = dist.get_rank( self.dp_process_group) == dist.get_world_size(self.dp_process_group) - 1 @@ -567,19 +802,12 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): # move parameters to flattened buffer if not self.offload_param: # partitioned params remain in GPU during training # move parameter partitions into a single contiguous flat buffer - parameter_partitions: List[Tensor] = [] - for sub_group in self.fp16_groups: - for param in sub_group: - parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) + parameter_partitions = self._get_parameter_partitions() + + # We need to keep the reference to this buffer to make sure you can free it in `offload_states` + self.lp_param_buffer = defragment(parameter_partitions) + self._set_fp16_partitioned_groups_flat() - # setup flat buffers per subgroup, these are each just sections of the - # contiguous flat buffer for all parameters that we created earlier - offset = 0 - for sub_group in self.fp16_groups: - sub_group_numel = sum(param.partition_numel() for param in sub_group) - self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel)) - offset += sub_group_numel else: # partitioned params offloaded to CPU when not in use # create a flat CPU memory allocation for each param group self._create_param_groups_fp16_flat_cpu_memory() @@ -622,9 +850,12 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): largest_partition_numel = [t.ds_numel for t in sub_group] max_partition_numel = total_elements - assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' + assert len(largest_partition_numel) > 0, 'Unexpected that largest partition is empty' self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(largest_partition_numel) + def _get_parameter_partitions(self) -> List[Tensor]: + return [param.ds_tensor for sub_group in self.fp16_groups for param in sub_group] + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): offset = 0 elements_in_sub_group = sum([t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) @@ -668,6 +899,20 @@ def _get_sub_group_partitions(self, sub_group_id): return sub_group_partitions + def _create_momentum_buffer(self, num_elements, i, ds_id): + if self.use_muon and self.sub_groups_using_muon[i]: + unpinned_fp32_buffer_momentum = torch.zeros(num_elements, + device=self.device, + dtype=self.communication_data_type) + unpinned_fp32_buffer_momentum.requires_grad = False + if self.fp32_partitioned_groups_flat[i] not in self.optimizer.state: + self.optimizer.state[self.fp32_partitioned_groups_flat[i]] = {} + self.optimizer.state[ + self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] = unpinned_fp32_buffer_momentum + if self.save_muon_momentum_buffer_in_memory: + self.muon_momentum_buffer_partitioned_groups_flat[i] = unpinned_fp32_buffer_momentum + self.muon_momentum_buffer_partitioned_groups_flat[i].ds_id = ds_id + def _create_fp32_partitions(self): cpu_memory_usage = 0 cpu_memory_sub_groups = 0 @@ -684,16 +929,34 @@ def _create_fp32_partitions(self): nvme_fp16_partitions_info = [] nvme_fp16_num_elems = [] nvme_fp32_dest_tensors = [] - fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + fp32_element_size = torch.tensor([], dtype=self.master_weights_and_grads_dtype).element_size() + + # Assign portion of subgroup to cpu, the other to gpu. + if self.offload_optimizer: + self.subgroup_to_device = {} + sub_group_size = len(self.fp16_partitioned_groups_flat) + # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n") + for i in range(sub_group_size): + if i >= int((1 - self.partial_offload) * sub_group_size): + self.subgroup_to_device[i] = 'cpu' + else: + self.subgroup_to_device[i] = get_accelerator()._name for i, tensor in enumerate(self.fp16_partitioned_groups_flat): num_elements = self.fp16_partitioned_groups_flat_numel[i] + ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0]) + ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1]) + ds_id = ds_id_begin + '_' + ds_id_end # a partition of the fp32 master weights that will be updated by this process if self._swappable_optimizer_subgroup(i): - self.fp32_partitioned_groups_flat.append(torch.Tensor()) + self.fp32_partitioned_groups_flat.append(torch.empty(0, dtype=self.master_weights_and_grads_dtype)) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id nvme_memory_usage += (fp32_element_size * num_elements) num_swappable_partitions += 1 + if not (self.use_muon and self.sub_groups_using_muon[i] + and not self.save_muon_momentum_buffer_in_memory): + self._create_momentum_buffer(num_elements, i, ds_id) if self.params_in_nvme_and_cpu and tensor is None: num_swap_from_nvme_partitions += 1 @@ -704,7 +967,9 @@ def _create_fp32_partitions(self): nvme_fp16_num_elems.append(num_elements) nvme_fp32_dest_tensors.append(self.fp32_partitioned_groups_flat[i]) else: - unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float) + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=self.master_weights_and_grads_dtype) self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) self.optimizer_swapper.initialize_parameters(parameters=[self.fp32_partitioned_groups_flat[i]], src_tensors=[unpinned_fp32_buffer]) @@ -718,12 +983,30 @@ def _create_fp32_partitions(self): cpu_memory_sub_groups += 1 if self.params_in_nvme_and_cpu and tensor is None: - unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float) + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=self.master_weights_and_grads_dtype) self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + self._create_momentum_buffer(num_elements, i, ds_id) + elif self.offload_optimizer: + converted = self.fp16_partitioned_groups_flat[i].to(self.subgroup_to_device[i], + dtype=self.master_weights_and_grads_dtype) + self.fp32_partitioned_groups_flat.append(converted.clone().detach()) + self._create_momentum_buffer(num_elements, i, ds_id) + elif self.fp16_partitioned_groups_flat[i].dtype == self.master_weights_and_grads_dtype and \ + self.fp16_partitioned_groups_flat[i].device == self.device: + # When torch autocast is enabled, weights in the provided model (and thus groups in the so-called + # "fp16" partitioned groups) are already in and updated using fp32. In such cases we don't need + # another copy of the weights. + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i]) + self._create_momentum_buffer(num_elements, i, ds_id) else: - self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) + converted = self.fp16_partitioned_groups_flat[i].to(self.device, + dtype=self.master_weights_and_grads_dtype) + self.fp32_partitioned_groups_flat.append(converted.clone().detach()) + self._create_momentum_buffer(num_elements, i, ds_id) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it @@ -784,24 +1067,41 @@ def _create_fp16_sub_groups(self, params_group): return sub_groups - def _release_ipg_buffers(self): - if self.contiguous_gradients: - self.ipg_buffer = None - def _optimizer_step(self, sub_group_id): param_group_id = self.sub_group_to_group_id[sub_group_id] fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] + def step_with_gradscaler(optimizer): + if self.torch_autocast_gradscaler: + self.torch_autocast_gradscaler.step(optimizer) + self.torch_autocast_gradscaler.update() + else: + if not self.zenflow: + optimizer.step() + else: + self.zenflow_cpu_optimizer_step() + + if self.offload_optimizer: + cur_device = self.subgroup_to_device[sub_group_id] + if cur_device == 'cpu': + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + step_with_gradscaler(self.optimizer) + self.optimizer.param_groups[param_group_id]['params'] = [] + else: + self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param] + step_with_gradscaler(self.backup_optimizer) + self.backup_optimizer.param_groups[param_group_id]['params'] = [] + else: + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + step_with_gradscaler(self.optimizer) + self.optimizer.param_groups[param_group_id]['params'] = [] def _swappable_optimizer_subgroup(self, sub_group_id): if not self.swap_optimizer: return False - return self.optimizer_swapper.swappable_tensor(None, - numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) + return self.optimizer_swapper.is_swappable_tensor(None, + numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) def _partitioned_params_swap_out(self, i): offset = 0 @@ -824,6 +1124,15 @@ def _partitioned_params_swap_out(self, i): swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=swap_fp16_params, src_fp32_params=swap_fp32_params) + def _set_fp16_partitioned_groups_flat(self): + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.partition_numel() for param in sub_group) + self.fp16_partitioned_groups_flat.append(self.lp_param_buffer.narrow(0, offset, sub_group_numel)) + offset += sub_group_numel + def initialize_optimizer_states(self): num_subgroups = len(self.fp16_groups) @@ -833,12 +1142,15 @@ def initialize_optimizer_states(self): timer_names = set() + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. + is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad) + if self.swap_optimizer: self.optimizer_swapper.init_timers() - INIT_OPTIMIZER_TIMER = 'init_optimizer_state' timer_names.add(INIT_OPTIMIZER_TIMER) - self.start_timers([INIT_OPTIMIZER_TIMER]) + self.timers(INIT_OPTIMIZER_TIMER).start() for i, group in enumerate(self.fp16_groups): swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) @@ -852,18 +1164,20 @@ def initialize_optimizer_states(self): if swappable_optimizer_subgroup: self._optimizer_states_and_gradient_swap_in(i, timer_names) + if self.use_muon and self.sub_groups_using_muon[i] and not self.save_muon_momentum_buffer_in_memory: + # Create momentum buffer after swap-in so swap files can be created on swap-out. + if "momentum_buffer" not in self.optimizer.state.get(self.fp32_partitioned_groups_flat[i], {}): + self._create_momentum_buffer(num_elements, i, self.fp32_partitioned_groups_flat[i].ds_id) if self.offload_optimizer and not swappable_optimizer_subgroup: subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=self.device) if self.offload_optimizer_pin_memory: subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer) - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer.to(self.subgroup_to_device[i]) else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) - self._optimizer_step(i) - if swappable_param_subgroup: self._partitioned_params_swap_out(i) @@ -874,8 +1188,12 @@ def initialize_optimizer_states(self): f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', force=False) - self.stop_timers([INIT_OPTIMIZER_TIMER]) - self.log_timers(timer_names) + # Initialize the optimizer states with the flattened fp32 partition. + if is_adagrad: + self.optimizer = torch.optim.Adagrad(self.fp32_partitioned_groups_flat, **self.optimizer.defaults) + + self.timers(INIT_OPTIMIZER_TIMER).stop() + self.timers.log(timer_names) if self.swap_optimizer: self.optimizer_swapper.log_timers() @@ -924,84 +1242,120 @@ def initialize_gradient_partitioning_data_structures(self): @instrument_w_nvtx def independent_gradient_partition_epilogue(self): - self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.__reduce_and_partition_ipg_grads() - self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + self.report_ipg_memory_usage("In ipg_epilogue before reduce_ipg_grads", 0) + for comm_dtype in sort_dtypes(self.ipg_buckets.keys()): + self.__reduce_and_partition_ipg_grads(comm_dtype) + self.report_ipg_memory_usage("In ipg_epilogue after reduce_ipg_grads", 0) - self.__reduce_and_partition_stream.synchronize() + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() - # if dist.get_rank() == 0: - # logger.info("Params already reduced %s", self.params_already_reduced) - for i in range(len(self.params_already_reduced)): - self.params_already_reduced[i] = False + for param_id in self.params_already_reduced.keys(): + self.params_already_reduced[param_id] = False #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad #TODO: use a similar code path for both cpu_offload and non-cpu offload if not self.offload_optimizer: for i, sub_group in enumerate(self.fp16_groups): + #TODO: This is redundant self.averaged_gradients[i] = [ self.__param_id_to_grad_partition[param.ds_id] if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group ] - # self.averaged_gradients[i] = self.get_flat_partition( - # self.fp16_groups[i], - # 0, - # self.fp32_partitioned_groups_flat[i].numel(), - # return_tensor_list=True) - - # this method gets called after every backward. need to increment - # here because if it gets incremented in backward() the micro step - # id will be off by one when we do the reduce and partition at the. - # start of this method. - # TODO. make this less error prone - self.micro_step_id += 1 + # This method gets called after every backward. With reentrant gradient + # checkpointing, it may be called multiple times per backward pass (once per phase). + # We track that the epilogue ran this backward so we can increment micro_step_id + # at the start of the NEXT forward pass. This ensures all phases within a backward + # use the same micro_step_id value (copy semantics for all, not accumulate). + # The increment is deferred to clear_backward_seen_flag() which runs in forward(). + self._epilogue_ran_this_backward = True def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() def create_reduce_and_remove_grad_hooks(self): - print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] + print_rank_0('[Begin] Create gradient reduction hooks') + self.leaf_parameters = defaultdict(list) + non_leaf_params_requiring_grad = [] + for i, param_group in enumerate(self.fp16_groups): for param in param_group: - if param.requires_grad: - #print_rank_0(f" Before all gather {param.device}, {param.shape}") + if z3_leaf_parameter(param): + self.leaf_parameters[param.ds_z3_leaf_module].append(param) + elif param.requires_grad: + non_leaf_params_requiring_grad.append(param) - # The hook must be created in un-partitioned parameter + leaf_module_count = len(self.leaf_parameters) + + for i, param_group in enumerate(self.fp16_groups): + for param in param_group: + if param.requires_grad: param.all_gather() - #print(f"After all gather {param.device}, {param.shape}") - def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] + def wrapper(param): @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) + # Evaluate refresh condition before reenter_backward_if_needed() + refresh_expected = self.should_refresh_expected_hook_count() + # Re-enter backward for subsequent phases in reentrant checkpointing + self.reenter_backward_if_needed() + + self.reduce_ready_partitions_and_remove_grads(param) - grad_acc.register_hook(reduce_partition_and_remove_grads) - self.grad_accs.append(grad_acc) + # Update hook state and run epilogue if all expected hooks have fired + if refresh_expected: + current_expected = count_used_parameters_in_backward( + non_leaf_params_requiring_grad) + leaf_module_count + else: + current_expected = self._max_expected_hooks_seen + self.update_hook_state_and_maybe_run_epilogue(current_expected) - #print(f"param grad fn {param.expand_as(param).grad_fn}") - wrapper(param, i) + self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) + + if not z3_leaf_parameter(param): + wrapper(param) # Partition the parameter after creating the hook param.partition() - print_rank_0(f'[End] Create gradient reduction hooks') + + # We delay reduce for all gradients in the leaf modules until the backward pass of the leaf module is done + for leaf_module, leaf_parameters in self.leaf_parameters.items(): + + def make_hook(params): + + def reduce_leaf_module_grads(module, grad_input, grad_output): + # Evaluate refresh condition before reenter_backward_if_needed() + refresh_expected = self.should_refresh_expected_hook_count() + self.reenter_backward_if_needed() + + for param in params: + # this takes care of grads for MoE experts that didn't participate in the current iteration/layer + if param.grad is None: + param.grad = torch.zeros_like(param) + self.reduce_ready_partitions_and_remove_grads(param) + + if refresh_expected: + current_expected = count_used_parameters_in_backward( + non_leaf_params_requiring_grad) + leaf_module_count + else: + current_expected = self._max_expected_hooks_seen + self.update_hook_state_and_maybe_run_epilogue(current_expected) + + return reduce_leaf_module_grads + + assert required_torch_version(min_version=1.8), "Leaf module requires PyTorch >= 1.8" + self._leaf_module_hooks.append(leaf_module.register_full_backward_hook(make_hook(leaf_parameters))) + + self._remaining_grad_acc_hooks = 0 + + print_rank_0('[End] Create gradient reduction hooks') def get_param_id(self, param): - unique_id = id(param) - return self.param_id[unique_id] - - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", - force=False) + return OptimizerSwapper.parameter_id(param) - ###############Idependent Partition Gradient ######################## - def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + ###############Independent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param): #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) # Because the ipg bucket is initialized with a random place holder tensor, we must @@ -1009,87 +1363,277 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be # empty, while reduction_list will have that garbage data. - if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: + comm_dtype = self.get_param_comm_dtype(param) + bucket = self.ipg_buckets[comm_dtype] + if bucket.elements + param.ds_numel > self.reduce_bucket_size and bucket.elements > 0: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel) + self.__reduce_and_partition_ipg_grads(comm_dtype) - self.__reduce_and_partition_ipg_grads() - - param_id = self.get_param_id(param) - - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - self.__add_grad_to_ipg_bucket(param) + # deal with a use-case of transient grads that will be generated in a loop for the same computation involving some model params - e.g. when performing a tiled memory calculation that shards the normal single sub-module call into a loop over a shards. + if getattr(param, "ds_grad_is_ready", True): + self.__add_grad_to_ipg_bucket(param) @instrument_w_nvtx @torch.no_grad() def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - self.__reduce_and_partition_stream.wait_stream(get_accelerator().default_stream()) + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.wait_stream(get_accelerator().current_stream()) - if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel() < self.reduce_bucket_size: + bucket = self.ipg_buckets[self.get_param_comm_dtype(param)] + if self.contiguous_gradients and bucket.elements + param.grad.numel() <= self.reduce_bucket_size: # move the gradient to a contiguous buffer - with get_accelerator().stream(self.__reduce_and_partition_stream): + with get_accelerator().stream(self.reduce_and_partition_stream): # move the parameter's gradient to the contiguous flat buffer - new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow(0, self.elements_in_ipg_bucket, - param.grad.numel()).view_as(param.grad) - new_grad_tensor.copy_(param.grad, non_blocking=True) - param.grad.record_stream(get_accelerator().current_stream()) + if self.zenflow and len(param.ds_shape) != 1: + transposed_shape = param.grad.t().shape + new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, + param.grad.numel()).view(transposed_shape) + new_grad_tensor.copy_(param.grad.t().contiguous(), non_blocking=True) + else: + new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + if not get_accelerator().is_synchronized_device(): + param.grad.record_stream(get_accelerator().current_stream()) param.grad.data = new_grad_tensor - self.__params_in_ipg_bucket.append(param) + bucket.params.append(param) + bucket.elements += param.grad.numel() @instrument_w_nvtx @torch.no_grad() - def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: - if not self.__params_in_ipg_bucket: + def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype) -> None: + bucket = self.ipg_buckets[communication_data_type] + params_in_bucket = bucket.params + + if not params_in_bucket: return - for param in self.__params_in_ipg_bucket: + for param in params_in_bucket: if param.grad.numel() != param.ds_numel: raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " f"gradients whose size is not same as the params") - self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) + assert len(set(p.ds_id for p in params_in_bucket)) == len(params_in_bucket) + + while self.param_reduce_events and self.param_reduce_events[0].query(): + self.param_reduce_events.popleft() + if len(self.param_reduce_events) > self.max_param_reduce_events: + self.param_reduce_events.popleft().synchronize() + + with get_accelerator().stream(self.reduce_and_partition_stream): + if self.enable_sanity_checks: + assert_ints_same_as_other_ranks([p.ds_id for p in params_in_bucket]) + + if self.contiguous_gradients and bucket.elements <= self.reduce_bucket_size and not self.reduce_scatter: + grad_bucket = bucket.buffer.narrow(0, 0, bucket.elements) + grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket, communication_data_type) + else: + params_in_bucket.sort(key=lambda p: p.ds_id) + grad_partitions = self.__avg_scatter_grads(params_in_bucket, communication_data_type) + + if self.is_zenflow_select_boundary(): + self.update_selected_channels(params_in_bucket, grad_partitions) + + if self.zenflow and self.micro_step >= self.full_warm_up_rounds: + self._process_selected_fp32_groups_grad(params_in_bucket, grad_partitions) + + self.partition_grads(params_in_bucket, grad_partitions) + + params_in_bucket.clear() + bucket.elements = 0 + + if not get_accelerator().handles_memory_backpressure(): + event = get_accelerator().Event() + event.record() + self.param_reduce_events.append(event) + + def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, buffer_to_reduce: Tensor): + """ + Update the momentum buffer of the parameters using muon. + Args: + communication_data_type: torch.dtype + buffer_to_reduce: Tensor + Returns: + None + """ + if not self.use_muon: + return + + params_by_group = {} + params_size_offset = 0 + for param in self.ipg_buckets[communication_data_type].params: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + if self.sub_groups_using_muon[i]: + # copy the gradients back to the params in the ipg bucket for the muon update + param.grad.data.copy_(buffer_to_reduce.narrow(0, params_size_offset, + param.grad.numel()).view_as(param.grad), + non_blocking=False) + if i not in params_by_group: + params_by_group[i] = [] + params_by_group[i].append((param, dest_offset, params_size_offset)) + params_size_offset += param.grad.numel() + + # process muon updates per subgroup to avoid holding all parameters and states at once + for i, group_items in params_by_group.items(): + params = [param for param, _, _ in group_items] + if not params: + continue + + momentum_buffer = [] + if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory: + # swap-in once, keep resident through update + writeback + self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i]) + if "momentum_buffer" not in self.optimizer.state.get(self.fp32_partitioned_groups_flat[i], {}): + self._create_momentum_buffer(self.fp16_partitioned_groups_flat_numel[i], i, + self.fp32_partitioned_groups_flat[i].ds_id) + state_buffer = self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] + for param, dest_offset, _ in group_items: + momentum_buffer.append(state_buffer.narrow(0, dest_offset, param.partition_numel()).clone()) + elif self.save_muon_momentum_buffer_in_memory: + state_buffer = self.muon_momentum_buffer_partitioned_groups_flat[i] + for param, dest_offset, _ in group_items: + momentum_buffer.append(state_buffer.narrow(0, dest_offset, param.partition_numel()).clone()) + else: + # Non-swappable optimizer (GPU/CPU): momentum buffer lives in optimizer state + if "momentum_buffer" not in self.optimizer.state.get(self.fp32_partitioned_groups_flat[i], {}): + self._create_momentum_buffer(self.fp16_partitioned_groups_flat_numel[i], i, + self.fp32_partitioned_groups_flat[i].ds_id) + state_buffer = self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] + for param, dest_offset, _ in group_items: + momentum_buffer.append(state_buffer.narrow(0, dest_offset, param.partition_numel()).clone()) + + gathered_params_momentums = self._partitioned_buffers_all_gather(params, momentum_buffer, + communication_data_type) + + world_sz = dist.get_world_size(self.dp_process_group) + rank = dist.get_rank(self.dp_process_group) + grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * ( + (world_sz - len(params) % world_sz) % world_sz) + gathered_momentums_pad = gathered_params_momentums + [torch.empty_like(gathered_params_momentums[-1])] * ( + (world_sz - len(gathered_params_momentums) % world_sz) % world_sz) + grad_handles = [] + momentum_handles = [] + for base_i in range(len(params))[::world_sz]: + if base_i + rank < len(params): + param = params[base_i + rank] + g = param.grad + m = gathered_momentums_pad[base_i + rank] + update = muon_update(g, m, beta=self.muon_beta, ns_method=getattr(self, 'muon_ns_method', 'gram')) + g.data.copy_(update, non_blocking=False) + grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz], + grads_pad[base_i + rank], + async_op=True) + grad_handles.append(grad_handle) + momentum_handle = dist.all_gather(gathered_momentums_pad[base_i:base_i + world_sz], + gathered_momentums_pad[base_i + rank], + async_op=True) + momentum_handles.append(momentum_handle) + + for handle in momentum_handles: + handle.wait() + for idx, (param, dest_offset, _) in enumerate(group_items): + gathered_momentum = gathered_params_momentums[idx] + chunk_sz = math.ceil(param.grad.numel() / world_sz) + start_offset = rank * chunk_sz + end_offset = start_offset + chunk_sz + if end_offset > param.grad.numel(): + buffer_to_update = torch.zeros(chunk_sz, device=param.grad.device, dtype=param.grad.dtype) + buffer_to_update[:param.grad.numel() - + start_offset] = gathered_momentum.view(-1).data[start_offset:param.grad.numel()] + else: + buffer_to_update = gathered_momentum.view(-1).data[start_offset:end_offset] + if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory: + self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow( + 0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False) + elif self.save_muon_momentum_buffer_in_memory: + self.muon_momentum_buffer_partitioned_groups_flat[i].narrow( + 0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False) + # update the momentum buffer in the optimizer state + self.optimizer.state[self.fp32_partitioned_groups_flat[i]][ + "momentum_buffer"] = self.muon_momentum_buffer_partitioned_groups_flat[i] + else: + # Non-swappable optimizer (GPU/CPU): write directly to optimizer state + self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow( + 0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False) + if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory: + self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i]) + for handle in grad_handles: + handle.wait() + for param, _, params_size_offset in group_items: + buffer_to_reduce.narrow(0, params_size_offset, param.grad.numel()).data.copy_(param.grad.view(-1), + non_blocking=False) + + @instrument_w_nvtx + def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, + communication_data_type: torch.dtype) -> List[Tensor]: + dtype = buffer_to_reduce.dtype + if communication_data_type != dtype: + buffer_to_reduce = buffer_to_reduce.to(communication_data_type) + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor) + + world_sz = dist.get_world_size(self.dp_process_group) + rank = dist.get_rank(self.dp_process_group) + buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size)) - assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len(self.__params_in_ipg_bucket) + dist.all_reduce(buffer_to_reduce, group=self.dp_process_group) - while self.__param_reduce_events and self.__param_reduce_events[0].query(): - self.__param_reduce_events.popleft() - if len(self.__param_reduce_events) > self.__max_param_reduce_events: - self.__param_reduce_events.popleft().synchronize() + if self.postscale_gradients and self.gradient_predivide_factor != world_sz: + buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor) - with get_accelerator().stream(self.__reduce_and_partition_stream): - if safe_mode: - assert_ints_same_as_other_ranks([p.ds_id for p in self.__params_in_ipg_bucket]) + if communication_data_type != self.dtype: + buffer_to_reduce = buffer_to_reduce.to(self.dtype) - grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) - self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) + grad_partitions = [] + grad_offset_in_buffer = 0 + self._apply_distributed_muon_update(communication_data_type, buffer_to_reduce) + for param in self.ipg_buckets[communication_data_type].params: + grad = param.grad + chunk_sz = math.ceil(grad.numel() / world_sz) - self.__params_in_ipg_bucket.clear() + start_offset = grad_offset_in_buffer + min(rank * chunk_sz, grad.numel()) + end_offset = grad_offset_in_buffer + min(rank * chunk_sz + chunk_sz, grad.numel()) - event = get_accelerator().Event() - event.record() - self.__param_reduce_events.append(event) + partition = buffer_to_reduce[start_offset:end_offset] + if param.partition_numel() != partition.numel(): + padded_partition = torch.zeros(param.partition_numel(), device=grad.device, dtype=grad.dtype) + if partition.numel() > 0: + padded_partition[:partition.numel()] = partition + grad_partitions.append(padded_partition) + else: + grad_partitions.append(partition) + grad_offset_in_buffer += grad.numel() + + return grad_partitions @instrument_w_nvtx - def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + def __avg_scatter_grads(self, params_to_reduce: List[Parameter], + communication_data_type: torch.dtype) -> List[Tensor]: """average gradients and scatter partitions across ranks""" full_grads_for_rank = [p.grad for p in params_to_reduce] - if self.communication_data_type != self.dtype: - full_grads_for_rank = [g.to(self.communication_data_type) for g in full_grads_for_rank] + if communication_data_type != self.dtype: + full_grads_for_rank = [g.to(communication_data_type) for g in full_grads_for_rank] if self.postscale_gradients and self.gradient_predivide_factor != 1.0: full_grads_for_rank = [g.div(self.gradient_predivide_factor) for g in full_grads_for_rank] - grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group) + local_world_size = get_accelerator().device_count() + global_world_size = dist.get_world_size() + num_nodes = global_world_size // local_world_size + if self.all2all_process_group is not None and num_nodes > 1: + grad_partitions_for_rank = (all_to_all_loco_quant_reduce(params_to_reduce, self.all2all_process_group, + self.zeropp_loco_param) + if self.zeropp_loco_param is not None else all_to_all_quant_reduce( + full_grads_for_rank, self.all2all_process_group)) + else: + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group) - if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size(self.dp_process_group): + if self.postscale_gradients and self.gradient_predivide_factor != 1.0 and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): grad_partitions_for_rank = [g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank] - if self.communication_data_type != self.dtype: + if communication_data_type != self.dtype: grad_partitions_for_rank = [g.to(self.dtype) for g in grad_partitions_for_rank] return grad_partitions_for_rank @@ -1104,7 +1648,7 @@ def set_grad_positions(self): self.grad_position[param_id] = [int(i), int(current_offset), int(num_elements)] #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") current_offset += num_elements - see_memory_usage(f"After Set Grad positions", force=False) + see_memory_usage("After Set Grad positions", force=False) def _constant_buffered_norm2(self, input, buffer_size=250000000): norm = None @@ -1124,7 +1668,7 @@ def set_norm_for_param_grad_in_gpu(self, param): def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): with get_accelerator().stream(self.copy_grad_stream): param_id = self.get_param_id(param) - src_tensor = param.grad.view(-1).float() + src_tensor = param.grad.view(-1).to(dtype=self.master_weights_and_grads_dtype) #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") fp32_grad_tensor.copy_(src_tensor, non_blocking=True) param.grad = None @@ -1137,7 +1681,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): param_id = self.get_param_id(p) if param_id in self.norm_for_param_grads.keys(): param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + total_norm += param_norm**2 # Sum across all model parallel GPUs. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) @@ -1146,17 +1690,17 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0]**(1. / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device) - return total_norm + return total_norm.cpu() @instrument_w_nvtx - def __partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: offload_fp32_gradients = {} offload_fp32_offsets = {} + buffers = [] for param, grad_partition in zip(params_to_release, grad_partitions): contains_real_data = param.partition_numel() * dist.get_rank(self.dp_process_group) < param.ds_numel @@ -1167,32 +1711,24 @@ def __partition_grads(self, params_to_release: List[Parameter], grad_partitions: # move or accumulate gradient partition to target buffer grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel()) + buffers.append(grad_buffer) if self.micro_step_id == 0: # don't accumulate grad_buffer.copy_(grad_partition, non_blocking=True) # ensure grad buffer is a CUDA buffer to speed up the next few # operations and so it can be used asynchronously grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) elif get_accelerator().on_accelerator(grad_buffer): - grad_buffer.add_(grad_partition) + grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(grad_buffer.shape)) else: # if dst is CPU, copy first to src device, do the addition # there, then move back to dst. adding directly to cpu is very slow cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) - cuda_grad_buffer.add_(grad_partition) + cuda_grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(cuda_grad_buffer.shape)) grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) # ensure grad buffer is a CUDA buffer to speed up the next few # operations and so it can be used asynchronously grad_buffer = cuda_grad_buffer - if hasattr(self.__inf_or_nan_tracker, "logical_or_"): - self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) - else: - # logical_or_ not available in older versions of pytorch - self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() - self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() - self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 - # offload the gradient partition if applicable if self.offload_optimizer: i, dest_offset, _ = self.grad_position[self.get_param_id(param)] @@ -1201,19 +1737,21 @@ def __partition_grads(self, params_to_release: List[Parameter], grad_partitions: self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_buffer) if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): + if i not in offload_fp32_gradients.keys(): offload_fp32_gradients[i] = [] offload_fp32_offsets[i] = [] - offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_gradients[i].append(grad_buffer.to(dtype=self.master_weights_and_grads_dtype)) offload_fp32_offsets[i].append(dest_offset) else: fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow( 0, dest_offset, grad_buffer.numel()) - fp32_grad_tensor.copy_(grad_buffer) + fp32_grad_tensor.copy_(grad_buffer.to(dtype=self.master_weights_and_grads_dtype)) # free the gradient - param.grad.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + if param.grad is not None: + param.grad.record_stream(get_accelerator().current_stream()) param.grad = None if self.offload_optimizer and self.swap_optimizer: @@ -1221,10 +1759,61 @@ def __partition_grads(self, params_to_release: List[Parameter], grad_partitions: self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i], gradient_offsets=offload_fp32_offsets[i], gradient_tensors=offload_fp32_gradients[i]) + return buffers + + def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_allgather: List[Tensor], + communication_data_type: torch.dtype): + """ + Allgather the partitioned buffers of the parameters to the global buffer. + Args: + params: List[Parameter] + buffers_to_allgather: List[Tensor] + communication_data_type: torch.dtype + Returns: + List[Tensor] + """ - def reduce_ready_partitions_and_remove_grads(self, param, i): + assert len(params) == len(buffers_to_allgather), "params and buffers_to_allgather must have the same length" + assert all(param.partition_numel() == buffer.numel() + for param, + buffer in zip(params, buffers_to_allgather)), \ + "params and buffers_to_allgather must have the same numel" + coalesced_buffer = instrument_w_nvtx(torch.cat)(buffers_to_allgather) + buffer_numel = coalesced_buffer.numel() + reduce_buffer = torch.empty(self.partition_count * buffer_numel, + dtype=communication_data_type, + device=params[0].device) + rearrange_buffer = torch.empty(self.partition_count * buffer_numel, + dtype=communication_data_type, + device=params[0].device) + my_rank = dist.get_rank(group=self.dp_process_group) + partition = reduce_buffer.narrow(0, buffer_numel * my_rank, buffer_numel) + partition.data.copy_(coalesced_buffer.data, non_blocking=False) + dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group) + param_partition_offsets = [0] + rearranged_offset = 0 + for idx, param in enumerate(params): + param_partition_offsets.append(param_partition_offsets[idx] + param.partition_numel()) + for idx, param in enumerate(params): + num_elements = param.partition_numel() + for partition_idx in range(self.partition_count): + sliced = reduce_buffer.narrow(0, buffer_numel * partition_idx + param_partition_offsets[idx], + num_elements) + rearrange_buffer.narrow(0, rearranged_offset, num_elements).copy_(sliced.data, non_blocking=False) + rearranged_offset += num_elements + param_full_offsets = [0] + for idx, param in enumerate(params): + # the offset is the sum of the numel of all the partitions of the parameter including padding + param_full_offsets.append(param_full_offsets[idx] + + buffers_to_allgather[idx].numel() * self.partition_count) + output = [] + for idx, param in enumerate(params): + output.append(rearrange_buffer.narrow(0, param_full_offsets[idx], param.ds_numel).view(param.ds_shape)) + return output + + def reduce_ready_partitions_and_remove_grads(self, param): #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) - self.reduce_independent_p_g_buckets_and_remove_grads(param, i) + self.reduce_independent_p_g_buckets_and_remove_grads(param) def zero_reduced_gradients(self, partition_id, i): @@ -1238,6 +1827,40 @@ def are_all_related_partitions_reduced(params_id): if are_all_related_partitions_reduced(params_id): self.param_dict[params_id].grad = None + def quantize_nontrainable_params(self): + """ In ZeRO-3, when the zero_quantized_nontrainable_weights flag is set, we quantize the non-trainable weights and also store them in quantized format. However, this check for trainable/non-trainable is done when deepspeed initializes the partitioning. So, if the user changes the trainable/non-trainable status of a parameter after the partitioning is done (e.g. LoRA), the user needs to re-quantize the non-trainable weights by calling this function. + """ + if not self.zero_quantized_nontrainable_weights: + print_rank_0( + "Warning: quantize_nontrainable_params() called with zero_quantized_nontrainable_weights disabled, return without doing anything", + force=True) + return + quantizer_module = CUDAQuantizer() + + def quantize_dstensor(tensor): + assert tensor.dtype == torch.float16, f"quantize_dstensor() expects tensor.dtype == torch.float16, got {tensor.dtype}" + partition_size = tensor.ds_numel + ds_status = tensor.status + final_location = tensor.final_location + tensor, tensor.ds_quant_scale = quantizer_module.quantize(tensor) + tensor.ds_numel = partition_size + tensor.status = ds_status + tensor.final_location = final_location + tensor.requires_grad = False + return tensor + + for param in self.module.parameters(): + if hasattr(param, "ds_tensor") and (param.ds_tensor.numel() <= 2048 or param.ds_numel <= 500000): + # skip small parameters + continue + if hasattr(param, + "ds_tensor") and not param.requires_grad and not hasattr(param.ds_tensor, "ds_quant_scale"): + param.ds_tensor = quantize_dstensor(param.ds_tensor) + if hasattr(param, "ds_secondary_tensor") and not param.requires_grad and not hasattr( + param.ds_secondary_tensor, "ds_quant_scale") and param.ds_secondary_tensor is not None: + param.ds_secondary_tensor = quantize_dstensor(param.ds_secondary_tensor) + get_accelerator().synchronize() + def flatten_and_print(self, message, tensors, start=0, n=5): flatten_tensor = self.flatten(tensors) @@ -1285,7 +1908,7 @@ def set_none_gradients_to_zero(self, i, partition_id): for param_id in self.is_grad_computed[i][partition_id]: param = self.param_dict[param_id] if param.grad is None: - param.grad = torch.zero_like(param) + param.grad = torch.zeros_like(param) ######################Reduction Related Methods############################## @@ -1303,18 +1926,13 @@ def allreduce_bucket(self, bucket, rank=None, log=None): if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) - if rank is None: - # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - else: - global_rank = dist.get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - tensor.copy_(tensor_to_allreduce) + tensor.copy_(tensor_to_allreduce) return tensor @@ -1378,10 +1996,10 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): tensor_size = tensor.numel() - if (current_index >= start_index and current_index < end_index): + if start_index <= current_index < end_index: params_in_partition.append(tensor) - elif start_index > current_index and start_index < (current_index + tensor_size): + elif current_index < start_index < (current_index + tensor_size): params_in_partition.append(tensor) assert (first_offset == 0 @@ -1396,11 +2014,16 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset @instrument_w_nvtx - def zero_grad(self, set_to_none=False): + def zero_grad(self, set_to_none=True): """ Zero FP16 parameter grads. """ self.micro_step_id = 0 + # Reset the epilogue flag so the next forward doesn't increment micro_step_id. + # Without this, calling zero_grad() between backward and forward would cause + # micro_step_id to be incremented at the next forward, leading to incorrect + # gradient accumulation behavior. + self._epilogue_ran_this_backward = False # FP32 grad should never exist. # For speed, set model fp16 grad to None by default @@ -1415,6 +2038,24 @@ def zero_grad(self, set_to_none=False): p.grad.detach_() p.grad.zero_() + def clear_backward_seen_flag(self): + """Clear the backward seen flag and increment micro_step_id if epilogue ran. + + This override defers the micro_step_id increment from the epilogue to here. + With reentrant gradient checkpointing, the epilogue may be called multiple + times per backward pass, but we only want to increment micro_step_id once + after the backward is complete. By incrementing here at the start of the + NEXT forward, all phases within a backward use the same micro_step_id value. + """ + # Increment micro_step_id if the epilogue ran during the previous backward. + # This is deferred from independent_gradient_partition_epilogue() to ensure + # all phases within a backward use the same micro_step_id (copy semantics). + if self._epilogue_ran_this_backward: + self.micro_step_id += 1 + + # Call base class to reset flags (including _epilogue_ran_this_backward) + super().clear_backward_seen_flag() + def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. """ @@ -1449,7 +2090,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): # Take max across all GPUs. self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() + total_norm = total_norm_cuda[0] else: # if dist.get_rank() == 0: # logger.info(f"Total Norm beginning {total_norm}") @@ -1459,16 +2100,25 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): grad_norms.append(g.to(get_accelerator().device_name(), non_blocking=True).double().norm(2)) # Sum across all model parallel GPUs. - total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) + if len(grad_norms) == 0: + # FIX https://github.com/deepspeedai/DeepSpeed/issues/3564 + total_norm_cuda = torch.tensor(0, + dtype=gradients[0].dtype).to(get_accelerator().device_name()).double() + else: + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda.item()**(1. / norm_type) + total_norm = total_norm_cuda**(1. / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = torch.where(inf_or_nan, err, total_norm) return total_norm @@ -1522,39 +2172,29 @@ def free_grad_in_param_list(self, param_list): def reset_cpu_buffers(self): self.norm_for_param_grads = {} - self.local_overflow = False - - def log_timers(self, timer_names): - if self.timers is None: - return - - self.timers.log(names=list(timer_names)) - - def start_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).start() - - def stop_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).stop() def _pre_step(self): self.micro_step_id = 0 + # Also reset the epilogue flag so the next iteration starts fresh. + # Without this, the flag from the last backward before step() would cause + # an increment in the next forward(), which is wrong. + self._epilogue_ran_this_backward = False - print_rank_0(f"Inside Step function") - see_memory_usage(f"In step before checking overflow", force=False) + print_rank_0("Inside Step function") + see_memory_usage("In step before checking overflow", force=False) print_rank_0("Finished Tracing at Beginning of Step") - self._get_param_coordinator(training=True).hierarchy = 0 + self._get_param_coordinator().hierarchy = 0 print_rank_0("Finished Tracing at Beginning of Step") + # Clear any stale params from ipg_buckets. This is needed because with + # reentrant checkpointing (use_reentrant=True), the backward pass can + # leave params in the buckets that weren't properly processed, causing + # errors in the next iteration. + for bucket in self.ipg_buckets.values(): + bucket.clear_params() + @instrument_w_nvtx def _get_norm_groups(self): norm_groups = [] @@ -1581,13 +2221,14 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id): # release all the gradient since we have already created a necessary copy in dp_grad_partition self.zero_grad(set_to_none=True) - for grad in filter(lambda g: get_accelerator().on_accelerator(g), self.averaged_gradients[sub_group_id]): - grad.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + for grad in filter(lambda g: get_accelerator().on_accelerator(g), self.averaged_gradients[sub_group_id]): + grad.record_stream(get_accelerator().current_stream()) self.averaged_gradients[sub_group_id] = None @instrument_w_nvtx - def _prepare_sub_group(self, sub_group_id, timer_names=set()): + def _prepare_sub_group(self, sub_group_id, timer_names): see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', force=False) if self._swappable_optimizer_subgroup(sub_group_id): self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) @@ -1595,26 +2236,27 @@ def _prepare_sub_group(self, sub_group_id, timer_names=set()): self._prepare_fp32_grad_for_sub_group(sub_group_id) see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', force=False) - def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): + def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=None): param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + fp32_param_id = self.get_param_id(self.fp32_partitioned_groups_flat[sub_group_id]) assert self._swappable_optimizer_subgroup(sub_group_id), \ f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', force=False) - self.start_timers([OPTIMIZER_SWAP_IN_STATE]) + if timer_names is not None: + timer_names.add(OPTIMIZER_SWAP_IN_STATE_TIMER) + self.timers(OPTIMIZER_SWAP_IN_STATE_TIMER).start() self.optimizer_swapper.swap_in_optimizer_state( parameter=self.fp32_partitioned_groups_flat[sub_group_id], async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) - self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) - timer_names.add(OPTIMIZER_SWAP_IN_STATE) + if timer_names is not None: + self.timers(OPTIMIZER_SWAP_IN_STATE_TIMER).stop() see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', force=False) @instrument_w_nvtx - def _release_sub_group(self, sub_group_id, timer_names=set()): + def _release_sub_group(self, sub_group_id, timer_names): see_memory_usage(f'Before release optimizer sub group {sub_group_id}', force=False) # get rid of the fp32 gradients. Not needed anymore if not self.offload_optimizer: @@ -1644,27 +2286,37 @@ def flatten_dense_tensors_aligned(self, tensor_list, alignment): return self.flatten(padded_tensor_list) - def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): + def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=None): param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + fp32_param_id = self.get_param_id(self.fp32_partitioned_groups_flat[sub_group_id]) assert self._swappable_optimizer_subgroup(sub_group_id), \ f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' see_memory_usage(f'post-step Before swapping out optimizer tensors {sub_group_id}', force=False) - self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) + if timer_names is not None: + timer_names.add(OPTIMIZER_SWAP_OUT_STATE_TIMER) + self.timers(OPTIMIZER_SWAP_OUT_STATE_TIMER).start() self.optimizer_swapper.swap_out_optimizer_state( parameter=self.fp32_partitioned_groups_flat[sub_group_id], async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is not None) - self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) + if timer_names is not None: + self.timers(OPTIMIZER_SWAP_OUT_STATE_TIMER).stop() see_memory_usage(f'post-step After swapping out optimizer tensors {sub_group_id}', force=False) - timer_names.add(OPTIMIZER_SWAP_OUT_STATE) # get rid of the fp32 gradients. Not needed anymore self.fp32_partitioned_groups_flat[sub_group_id].grad = None + def _release_swap_buffers(self, sub_group_id): + self.optimizer_swapper.release_swap_buffers(parameter=self.fp32_partitioned_groups_flat[sub_group_id]) + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + def _writeback_swap_state(self, sub_group_id, write_opt_state, write_gradients): + self.optimizer_swapper.writeback_optimizer_state_and_gradients(self.fp32_partitioned_groups_flat[sub_group_id], + write_opt_state, write_gradients) + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + def _unflatten_partitioned_parameters(self, sub_group_id): updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], self.fp16_partitioned_groups[sub_group_id]) @@ -1683,11 +2335,31 @@ def _overflow_clean_up(self, prev_scale): see_memory_usage('After overflow after clearing gradients', force=False) + def _loco_err_buf_update(self, overflow: bool, scale=1.0): + """ + Loco Error Buffer update. + """ + if not overflow and scale == 1.0: return + if dist.get_rank() == 0: + logger.info(f"update loco-zero++ error buffer with overflow: {overflow}") + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + for group in self.fp16_groups: + for p in group: + if hasattr(p, 'intra_ef_buf'): + if overflow: + del p.intra_ef_buf + del p.inter_ef_buf + else: + p.intra_ef_buf[1] *= scale + p.inter_ef_buf[1] *= scale + @instrument_w_nvtx def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow - self.check_overflow() + if self.dtype == torch.float16: + self.check_overflow() #loss scaling related computation prev_scale = self.loss_scale @@ -1696,10 +2368,13 @@ def _overflow_check_and_loss_scale_update(self): if self.overflow: self._overflow_clean_up(prev_scale) + #update loco error buf + self._loco_err_buf_update(self.overflow, self.loss_scale / prev_scale) + return self.overflow @instrument_w_nvtx - def _post_step(self, timer_names=set()): + def _post_step(self, timer_names): if self.offload_optimizer: self.reset_cpu_buffers() @@ -1710,14 +2385,18 @@ def _post_step(self, timer_names=set()): if self.swap_optimizer: self.optimizer_swapper.log_timers() - self.log_timers(timer_names) + self.invalidate_secondary_tensor() + + self.timers.log(timer_names) see_memory_usage('After zero_optimizer step', force=False) - print_rank_0(f"------------------Finishing Step-----------------------") + print_rank_0("------------------Finishing Step-----------------------") @instrument_w_nvtx def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + # When torch autocast is enabled, groups in fp16_partitioned_groups are in fp32 already and those in + # fp32_partitioned_groups are aliases. Calling tensor.data.copy_ will not trigger any copy in that case. self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( self.fp32_partitioned_groups_flat[sub_group_id].data) @@ -1736,7 +2415,7 @@ def override_loss_scale(self, loss_scale): def step(self, closure=None): """ Not supporting closure. - """ + """ self._pre_step() self._partition_all_parameters() @@ -1747,15 +2426,15 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) + scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups)) # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.loss_scale timer_names = set() - timer_names.add('optimizer_step') - self.start_timers(['optimizer_step']) + timer_names.add(OPTIMIZER_STEP_TIMER) + self.timers(OPTIMIZER_STEP_TIMER).start() #update parameters one sub group at a time for sub_group_id, group in enumerate(self.fp16_groups): @@ -1775,14 +2454,16 @@ def step(self, closure=None): #release memory or swap out optimizer states of fp32 parameters self._release_sub_group(sub_group_id, timer_names) - self.stop_timers(['optimizer_step']) + self.timers(OPTIMIZER_STEP_TIMER).stop() self._post_step(timer_names) # warn user about caching allocator flushes memory_stats = get_accelerator().memory_stats() - alloc_retries = memory_stats["num_alloc_retries"] if memory_stats != None else 0 - if alloc_retries > self.__n_caching_allocator_flushes: + alloc_retries = memory_stats.get("num_alloc_retries") + if alloc_retries is None: + alloc_retries = 0 + if alloc_retries > self.n_caching_allocator_flushes: if dist.get_rank() == 0: logger.warning( "%d pytorch allocator cache flushes since last step. this happens " @@ -1792,8 +2473,8 @@ def step(self, closure=None): "make the cache flushes go away consider adding " "get_accelerator().empty_cache() calls in your training loop to ensure " "that all ranks flush their caches at the same time", - alloc_retries - self.__n_caching_allocator_flushes) - self.__n_caching_allocator_flushes = alloc_retries + alloc_retries - self.n_caching_allocator_flushes) + self.n_caching_allocator_flushes = alloc_retries def dump_pre_step_gradients(self, debug_fp32_grads): # Dump gradient norms for debugging @@ -1829,8 +2510,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm): if self.clip_grad > 0.: # norm is in fact norm*scale clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale) @@ -1855,13 +2536,22 @@ def has_overflow_partitioned_grads_serial(self): @instrument_w_nvtx def has_overflow(self, partition_gradients=True): if partition_gradients: - with get_accelerator().stream(self.__reduce_and_partition_stream): - self.local_overflow = bool(self.__inf_or_nan_tracker.item()) - self.__inf_or_nan_tracker.zero_() + with get_accelerator().stream(self.reduce_and_partition_stream): + if hasattr(self.inf_or_nan_tracker, "logical_or_"): + self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any()) + self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 - overflow = self.local_overflow - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = get_accelerator().ByteTensor([overflow]) + overflow_gpu = self.inf_or_nan_tracker.clone().to(get_accelerator().current_device_name()).to( + torch.uint8) + self.inf_or_nan_tracker.zero_() + + if not get_accelerator().resolves_data_dependency(): + get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) else: @@ -1902,27 +2592,18 @@ def _has_inf_or_nan(x, j=None): return True return False - @instrument_w_nvtx - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ + def backward_prologue(self): if self.swap_optimizer: self.optimizer_swapper.pre_backward() - see_memory_usage(f"Before backward", force=False) + if self.zenflow: + self.zenflow_backward_prologue() - if self.custom_loss_scaler: - scaled_loss = self.external_loss_scale * loss - scaled_loss.backward() - else: - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + see_memory_usage("Before backward", force=False) - self._get_param_coordinator(training=True).reset_step() + def backward_epilogue(self): + if self.zenflow: + self.zenflow_backward_epilogue() if self.swap_optimizer: self.optimizer_swapper.post_backward() @@ -1931,7 +2612,8 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: """get fp32 gradient partition dictionary accessed as grad_dict[parameter_group_index][parameter_index] """ - self.__reduce_and_partition_stream.synchronize() + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() grad_dict = collections.defaultdict(dict) if self.offload_optimizer: for group in self.fp16_groups: @@ -1946,41 +2628,68 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: return grad_dict - def _fp32_state_allgather(self, param, fp32_state): - reduce_buffer = torch.zeros(self.partition_count * fp32_state.numel(), - dtype=torch.float32, - device=param.device).flatten() + def _fp32_state_allgather(self, param, fp32_state_partition): + reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(), + dtype=self.master_weights_and_grads_dtype, + device=param.device) my_rank = dist.get_rank(group=self.dp_process_group) - partitions = [ - reduce_buffer.narrow(0, - fp32_state.numel() * i, fp32_state.numel()) for i in range(self.partition_count) - ] - partitions[my_rank].data.copy_(fp32_state.data, non_blocking=False) + partition = reduce_buffer.narrow(0, fp32_state_partition.numel() * my_rank, fp32_state_partition.numel()) + partition.data.copy_(fp32_state_partition.data, non_blocking=False) + dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group) + return reduce_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape) - dist.all_gather(partitions, partitions[my_rank], group=self.dp_process_group) + def _get_fp32_grad_state_partition(self, param, release_swap_buffers): + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() - return reduce_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape) + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + if self.offload_optimizer: + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_in(group_idx) + + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + + if self._swappable_optimizer_subgroup(group_idx) and release_swap_buffers: + self._release_swap_buffers(group_idx) + else: + fp32_grad = self.__param_id_to_grad_partition[param.ds_id] + + return fp32_grad, group_idx def get_fp32_grad_for_param(self, param) -> Tensor: if not param.requires_grad: return None - self.__reduce_and_partition_stream.synchronize() - - if self.offload_optimizer: - group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] - fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, - num_elements).to(device=param.device) - else: - fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float() - + fp32_grad, _ = self._get_fp32_grad_state_partition(param=param, release_swap_buffers=True) + fp32_grad = fp32_grad.to(get_accelerator().current_device_name()).float() return self._fp32_state_allgather(param, fp32_grad) - def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: + def set_fp32_grad_for_param(self, value, param): if not param.requires_grad: - return None + return + + # if not get_accelerator().resolves_data_dependency(): + # self.reduce_and_partition_stream.synchronize() + + # if self.offload_optimizer: + # group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + # fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + # else: + # fp32_grad = self.__param_id_to_grad_partition[param.ds_id] + + fp32_grad, group_idx = self._get_fp32_grad_state_partition(param=param, release_swap_buffers=False) + # import pdb; pdb.set_trace() + my_rank = dist.get_rank(group=self.dp_process_group) + value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel()) + fp32_grad.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._writeback_swap_state(group_idx, write_opt_state=False, write_gradients=True) + + def _get_fp32_opt_state_partition(self, param, release_swap_buffers, optim_state_key=None): + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() - self.__reduce_and_partition_stream.synchronize() group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] if self._swappable_optimizer_subgroup(group_idx): @@ -1988,15 +2697,165 @@ def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: fp32_param = self.fp32_partitioned_groups_flat[group_idx] if optim_state_key is None: - fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements).to(device=param.device) + fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements) else: - fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow( - 0, dest_offset, num_elements).to(device=param.device) + fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements) + + if self._swappable_optimizer_subgroup(group_idx) and release_swap_buffers: + self._release_swap_buffers(group_idx) + + return fp32_opt_state, group_idx + + def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: + if not param.requires_grad: + return None + # import pdb; pdb.set_trace() + fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, + release_swap_buffers=True, + optim_state_key=optim_state_key) + fp32_opt_state = fp32_opt_state.to(get_accelerator().current_device_name()) hp_param = self._fp32_state_allgather(param, fp32_opt_state) + + return hp_param + + def set_full_hp_param(self, value, param, optim_state_key=None): + if not param.requires_grad: + return + + assert value.numel( + ) == param.ds_numel, f" Number of elements do not match: {value.numel()} != {param.ds_numel}" + + fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, + release_swap_buffers=False, + optim_state_key=optim_state_key) + # print(f'{dist.get_rank()=} {fp32_opt_state_partition.shape=} -------- {value.shape=}') + # import pdb; pdb.set_trace() + my_rank = dist.get_rank(group=self.dp_process_group) + value_partition = value.flatten().narrow(0, + fp32_opt_state_partition.numel() * my_rank, + fp32_opt_state_partition.numel()) + fp32_opt_state_partition.data.copy_(value_partition.data) + if self._swappable_optimizer_subgroup(group_idx): self._optimizer_states_and_gradient_swap_out(group_idx) - return hp_param + + ### Local API START ### + def get_local_fp32_grad_for_param(self, param) -> Tensor: + if not param.requires_grad: + return None + + fp32_grad, _ = self._get_fp32_grad_state_partition(param=param, release_swap_buffers=True) + fp32_grad = fp32_grad.to(get_accelerator().current_device_name()).float() + return fp32_grad + + def set_local_grad_for_param(self, value, param): + if not param.requires_grad: + return + + assert value.numel() == param.ds_tensor.numel( + ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}" + + # if not get_accelerator().resolves_data_dependency(): + # self.reduce_and_partition_stream.synchronize() + + # if self.offload_optimizer: + # group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + # fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + # else: + # fp32_grad = self.__param_id_to_grad_partition[param.ds_id] + + if self.offload_optimizer: + self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(value) + + fp32_grad, group_idx = self._get_fp32_grad_state_partition(param=param, release_swap_buffers=False) + fp32_grad.data.copy_(value.flatten().data) + + if self._swappable_optimizer_subgroup(group_idx): + self._writeback_swap_state(group_idx, write_opt_state=False, write_gradients=True) + + def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor: + if not param.requires_grad: + return None + fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, + release_swap_buffers=True, + optim_state_key=optim_state_key) + fp32_opt_state = fp32_opt_state.to(get_accelerator().current_device_name()) + return fp32_opt_state + + def set_local_hp_param(self, value, param, optim_state_key=None): + if not param.requires_grad: + return + + assert hasattr(param, "ds_tensor"), " The parameter does not contain the partitioned copy of the tensor." + assert value.numel() == param.ds_tensor.numel( + ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}" + + fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, + release_swap_buffers=False, + optim_state_key=optim_state_key) + value_partition = value.flatten() + fp32_opt_state_partition.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_out(group_idx) + # logger.info(f"[set_local_hp_param][update the params' value successfully]") + + ### Local API END ### + + ### Vectorized API BEGIN ### + def update_fp32_grad_for_param_vectorized(self, update_func, param_list): + params_with_grad = [p for p in param_list if p.requires_grad] + if not params_with_grad: + return + + if not get_accelerator().resolves_data_dependency(): + self.reduce_and_partition_stream.synchronize() + + subgroups = {} + for p in params_with_grad: + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(p)] + param_entry = (p, dest_offset, num_elements) + if group_idx in subgroups.keys(): + subgroups[group_idx].append(param_entry) + else: + subgroups[group_idx] = [param_entry] + + for group_idx, group_params in subgroups.items(): + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_in(group_idx) + + for param, dest_offset, num_elements in group_params: + if self.offload_optimizer: + fp32_grad_part = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( + 0, dest_offset, num_elements) + else: + fp32_grad_part = self.__param_id_to_grad_partition[param.ds_id] + + fp32_grad_full = self._fp32_state_allgather(param, fp32_grad_part) + new_fp32_grad_full = update_func(fp32_grad_full, param) + my_rank = dist.get_rank(group=self.dp_process_group) + value_partition = new_fp32_grad_full.flatten().narrow(0, + fp32_grad_part.numel() * my_rank, + fp32_grad_part.numel()) + fp32_grad_part.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._writeback_swap_state(sub_group_id=group_idx, write_opt_state=False, write_gradients=True) + + ### Vectorized API END ### + + ### Device API BEGIN ### + def get_hp_param_device(self, param, optim_state_key=None) -> torch.device: + if not param.requires_grad: + return None + + fp32_opt_state, _ = self._get_fp32_opt_state_partition(param, + release_swap_buffers=True, + optim_state_key=optim_state_key) + return fp32_opt_state.device + + ### Device API END ### @instrument_w_nvtx def _partition_all_parameters(self): @@ -2093,7 +2952,7 @@ def _clear_fp32_optimizer_param_groups(self): def _rigid_state_dict(self): state_dict = {} state_dict[ZERO_STAGE] = ZeroStageEnum.weights - state_dict['loss_scaler'] = self.loss_scaler + state_dict[LOSS_SCALER] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict[PARTITION_COUNT] = self.partition_count @@ -2119,10 +2978,6 @@ def state_dict(self): if self.elastic_checkpoint: raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.") - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now.") - return self._rigid_state_dict() @@ -2197,7 +3052,7 @@ def _restore_base_optimizer_state(self, all_state_dict): def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] + self.loss_scaler = state_dict[LOSS_SCALER] self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] self.overflow = state_dict['overflow'] @@ -2206,6 +3061,20 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT]) self._clear_fp32_optimizer_param_groups() + if self.swap_optimizer: + # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint + self.optimizer_swapper.purge_state() + + if self.swap_optimizer: + # Touch all parameters to synchronize all buffers + timer_names = set() + self._partition_all_parameters() + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + self._post_step(timer_names) + # restore fp32 partitions for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict[FP32_FLAT_GROUPS]): curr_param.data.copy_(saved_param.data) @@ -2213,8 +3082,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): # restore fp16 partitions from fp32 for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - fp16_param.data.copy_(fp32_param.data) + if sum(fp32_param.size()) > 0: + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) # update fp16 unflattened params for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): @@ -2229,7 +3099,9 @@ def load_state_dict(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False, - checkpoint_folder=None): + checkpoint_folder=None, + load_serial=None, + param_shapes=None): r"""Loading a ZeRO checkpoint Arguments: state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. @@ -2258,27 +3130,287 @@ def load_state_dict(self, if self.elastic_checkpoint: raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.") - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now.") + if checkpoint_folder: + self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) + else: + self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)], + load_optimizer_states=load_optimizer_states) + + # when use loading checkpoint serial, after finish loading, we need to + # delete the temp state_dict_list variable to save memory, then trigger + # the next rank's loading + if load_serial is not None: + load_serial += 1 + rank = dist.get_rank(group=self.dp_process_group) + local_rank = dist.get_local_rank() + del state_dict_list[rank] + rank_end = dist.get_world_size() - 1 + if local_rank != rank_end: + dist.send(tensor=load_serial, dst=rank + 1) + + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].partition(self.persistent_parameters) + # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather + + def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): + self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder) + + def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): + """ Load optimizer and model states from the checkpoint directory. """ + checkpoint_dir = os.path.join(checkpoint_dir, "zero") + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' + + optim_sd = torch.load(optim_state_path, weights_only=False) + self._load_global_state_stage3(optim_sd) + + # Generally the step of each optimizer file should be the same, we can obtain from any parameter. + state_step = optim_sd[OPTIMIZER_STATE_DICT]['state'][0]['step'] + for key in ["fp32", "exp_avg", "exp_avg_sq"]: + for sub_group_id, fp16_group in enumerate(self.fp16_groups): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + key_tensor = torch.zeros_like(fp32_param) + offset = 0 + for param in fp16_group: + if param not in self.param_names: + raise ValueError(f"failed to find optimizer param in named params") + param_name = self.param_names[param] + key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name), + key) + key_tensor.narrow(0, offset, key_layer_state_partition.numel()).copy_(key_layer_state_partition) + offset += key_layer_state_partition.numel() + if key == "fp32": + self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(key_tensor) + self.optimizer.state[fp32_param]['step'] = state_step + else: + self.optimizer.state[fp32_param][key] = key_tensor + + for param_group in self.optimizer.param_groups: + # Generally, the hyperparameters of each parameter should be the same, we can obtain from any parameter. + for key, value in optim_sd[OPTIMIZER_STATE_DICT]["param_groups"][0].items(): + if key == 'params': + param_group['params'] = [] + else: + param_group[key] = value - self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)], - load_optimizer_states=load_optimizer_states) + if self.swap_optimizer: + # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint + self.optimizer_swapper.purge_state() - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].partition(self.persistent_parameters) - self.persistent_parameters[0].all_gather(self.persistent_parameters) + if self.swap_optimizer: + # Touch all parameters to synchronize all buffers + timer_names = set() + self._partition_all_parameters() + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) + self._post_step(timer_names) + + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + if sum(fp32_param.size()) > 0: + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _load_global_state_stage3(self, sd): + self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) + self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) + self.overflow = sd.get('overflow', self.overflow) + + def load_hp_checkpoint_state(self, folder, key): + rank = dist.get_rank(group=self.dp_process_group) + + # Load tensors from files and reshape them to flat vectors + loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1) + + # Partition the loaded data according to the local rank + world_size = dist.get_world_size(group=self.dp_process_group) + unpartitioned_numel = loaded_checkpoint_state.numel() + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + + if world_size * partitioned_numel != unpartitioned_numel: + padding_size = world_size * partitioned_numel - unpartitioned_numel + padding_tensor = torch.zeros(padding_size, dtype=loaded_checkpoint_state.dtype) + loaded_checkpoint_state = torch.cat([loaded_checkpoint_state, padding_tensor]) + checkpoint_state_partition = loaded_checkpoint_state.narrow(0, rank * partitioned_numel, partitioned_numel) + + return checkpoint_state_partition + + def reset_swap_buffers(self): + timer_names = set() + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) def checkpoint_event_prologue(self): self._partition_all_parameters() def checkpoint_event_epilogue(self): + self.invalidate_secondary_tensor() if len(self.persistent_parameters) > 0: self.persistent_parameters[0].all_gather(self.persistent_parameters) def empty_partition_cache(self): self.parameter_offload.empty_partition_cache() + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False): + device = device.value + + self.empty_partition_cache() + + def needs_offload(target): + # return True + return target not in self.offloaded_states and (include == None or target in include) + + if needs_offload(OffloadStateTypeEnum.optim_states) or needs_offload(OffloadStateTypeEnum.hp_params): + assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, "Offloading is supported only for DeepSpeed FusedAdam." + + # HP param + if needs_offload(OffloadStateTypeEnum.hp_params): + if pin_memory: + if not hasattr(self, "hp_params_pin_buffers"): + self.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device=device)) + for t in self.fp32_partitioned_groups_flat + ] + + for src_tensor, dest_buf in zip(self.fp32_partitioned_groups_flat, self.hp_params_pin_buffers): + dest_buf.copy_(src_tensor, non_blocking=non_blocking) + src_tensor.data = dest_buf + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.hp_params) + + # LP param + if needs_offload(OffloadStateTypeEnum.lp_params): + if pin_memory: + if not hasattr(self, "lp_param_contiguous_pin_buffer"): + self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( + torch.empty_like(self.lp_param_buffer, device=device)) + self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking) + cpu_buffer = self.lp_param_contiguous_pin_buffer + else: + cpu_buffer = self.lp_param_buffer.to(device, non_blocking=non_blocking) + + self.lp_param_buffer.data = cpu_buffer + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer( + [p.ds_tensor for p in self.module.parameters()]): + tensor.data = cpu_buffer.narrow(0, offset, tensor_numel) + + self.fp16_partitioned_groups_flat.clear() + self.offloaded_states.add(OffloadStateTypeEnum.lp_params) + + # LP grad + if needs_offload(OffloadStateTypeEnum.lp_grads): + if pin_memory: + if not hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.lp_grad_partitions_flat_pin_buffers = get_accelerator().pin_memory( + torch.empty_like(self.grad_partitions_flat_buffer, device=device)) + self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer, + non_blocking=non_blocking) + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device) + self.averaged_gradients = {} + + self.__param_id_to_grad_partition = {} + + self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer): + for bucket in self.ipg_buckets.values(): + if bucket.buffer is not None: + # Record properties like shape, strides, etc. as a meta tensor + bucket.buffer_meta = bucket.buffer.to("meta") + bucket.buffer = None + self.offloaded_states.add(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if needs_offload(OffloadStateTypeEnum.optim_states): + offload_adam_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.optim_states) + + gc.collect() + get_accelerator().empty_cache() + + def reload_states(self, non_blocking: bool = False): + + device = get_accelerator().current_device_name() + + # HP param + if OffloadStateTypeEnum.hp_params in self.offloaded_states: + if hasattr(self, "hp_params_pin_buffers"): + for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat): + dest.data = src.to(device, non_blocking=non_blocking) + else: + for buf in self.fp32_partitioned_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) + + # LP Param + if OffloadStateTypeEnum.lp_params in self.offloaded_states: + cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr( + self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer + self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking) + self._set_fp16_partitioned_groups_flat() + + parameter_partitions = self._get_parameter_partitions() + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(parameter_partitions): + tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel) + self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) + + # LP grad + if OffloadStateTypeEnum.lp_grads in self.offloaded_states: + if hasattr(self, "lp_grad_partitions_flat_pin_buffers"): + self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers.to( + device, non_blocking=non_blocking) + else: + self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to( + device, non_blocking=non_blocking) + self.averaged_gradients = {} + + offset = 0 + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + for param in all_params: + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( + 0, offset, param.partition_numel()) + offset += param.partition_numel() + + self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads) + + # contiguous bucket + if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states: + for bucket in self.ipg_buckets.values(): + if bucket.buffer_meta is not None: + # We don't restore the data + bucket.buffer = torch.empty_like(bucket.buffer_meta, device=device) + + self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer) + + # Adam + if OffloadStateTypeEnum.optim_states in self.offloaded_states: + reload_adam_states(self.optimizer, device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) + + if non_blocking: + get_accelerator().synchronize() + def _handle_overflow(cpu_sum, x, i): import math diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 0e7a6115b091..20f17e59db29 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -4,37 +4,56 @@ # DeepSpeed Team import torch -import os from deepspeed import comm as dist from packaging import version as pkg_version -from collections import OrderedDict - -from deepspeed.runtime import ZeROOptimizer +from collections import OrderedDict, defaultdict +from dataclasses import dataclass, field +from typing import List, Dict, Set + +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.runtime.zenflow import zenflow_utils + +import gc +import math +from typing import Container +from deepspeed.runtime.zero.offload_states import offload_optimizer_states, reload_optimizer_states +from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler -from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage, - inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) - +from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes +from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, + align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace, + count_used_parameters_in_backward) from deepspeed.runtime.zero.config import ZeroStageEnum -from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import UtilsBuilder - -from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, - SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, CLIP_GRAD, - ZERO_STAGE, PARAM_SLICE_MAPPINGS) -from deepspeed.utils import link_hp_params +from deepspeed.runtime.zero.muon.original_muon import muon_update +from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, + SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, + BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.checkpoint import enable_universal_checkpoint +from deepspeed.checkpoint.constants import UNIVERSAL_CHECKPOINT_INFO +from deepspeed.utils import groups +from deepspeed.utils.debug import debug_param2name # Toggle this to true to enable correctness test # with gradient partitioning and without pg_correctness_test = False +OPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather' +OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients' +OPTIMIZER_STEP_TIMER = 'optimizer_step' +OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER] +INITIAL_MICRO_STEP_ID = -1 + def input(msg): return @@ -59,7 +78,7 @@ def isclose(a, b, rtol=1e-09, atol=0.0): def lcm(x, y): - from fractions import gcd # or can import gcd from `math` in Python 3 + from math import gcd return x * y // gcd(x, y) @@ -69,11 +88,6 @@ def get_alignment_padding(tensor_list, alignment): return (alignment - remainder) if remainder else remainder -def move_to_cpu(tensor_list): - for tensor in tensor_list: - tensor.data = tensor.data.cpu() - - def print_rank_msg(msg): print(f"rank {dist.get_rank()} - {msg}") @@ -87,6 +101,28 @@ def _get_padded_tensor(src_tensor, size): return padded_tensor +def _pad_tensor_by_size(src_tensor, pad_size, dtype, device): + padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device) + padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data) + return padded_tensor + + +@dataclass +class IPGBucket: + buffer: List[torch.Tensor] = field(default_factory=list) + params: List[torch.Tensor] = field(default_factory=list) + grads: List[torch.Tensor] = field(default_factory=list) + elements: int = 0 + index: int = 0 + has_moe_params: bool = False + + def clear(self): + self.params.clear() + self.grads.clear() + self.elements = 0 + self.has_moe_params = False + + class DeepSpeedZeroOptimizer(ZeROOptimizer): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -103,21 +139,25 @@ def __init__(self, init_optimizer, param_names, timers, + optimizer_params, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, verbose=True, contiguous_gradients=True, reduce_bucket_size=500000000, + use_multi_rank_bucket_allreduce=True, allgather_bucket_size=5000000000, dp_process_group=None, expert_parallel_group=None, expert_data_parallel_group=None, reduce_scatter=True, overlap_comm=False, - cpu_offload=False, + offload_optimizer_config=None, + zenflow_config=None, mpu=None, clip_grad=0.0, + gradient_accumulation_dtype=torch.float32, communication_data_type=torch.float16, postscale_gradients=True, gradient_predivide_factor=1.0, @@ -127,18 +167,34 @@ def __init__(self, round_robin_gradients=False, has_moe_layers=False, fp16_master_weights_and_gradients=False, - elastic_checkpoint=False): + bf16_master_weights_and_gradients=False, + bf16_optimizer_states=False, + elastic_checkpoint=False, + check_grad_overflow=True): + + super().__init__() + + if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none: + self.cpu_offload = True + self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory + else: + self.cpu_offload = False + self.cpu_offload_pin_memory = False + + # TODO: Remove zenflow-specific call from vanilla ZeroOptimizer, try to isolate zenflow-specific code into sub-class zenflow_zero_optimizer + self.zenflow = True if zenflow_config is not None else False if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") - logger.info(f"CPU Offload: {cpu_offload}") + logger.info(f"CPU Offload: {self.cpu_offload}") logger.info(f'Round robin gradient partitioning: {round_robin_gradients}') # The fused optimizer does all the work. We need this layer for two reason: # 1. maintain same user API from apex.fp16_utils # 2. keep common stuff here in case we need to add ne552w fused optimizer later self.elastic_checkpoint = elastic_checkpoint + self.check_grad_overflow = check_grad_overflow self.param_names = param_names self.mpu = mpu # differences from apex.fp16_utils: @@ -147,13 +203,12 @@ def __init__(self, # - flat by groups, not keeping state. TODO: remove state explicitly? # - master grad and unflat master weight never exist. TODO: a way to save out unflat master? if not get_accelerator().is_available(): - raise SystemError("Cannot use fp16 without accelerator.") + raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).") self.optimizer = init_optimizer - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten + # Use torch or zenflow (un)flatten ops + self.flatten = _flatten_dense_tensors if not self.zenflow else zenflow_utils._flatten_dense_tensors + self.unflatten = _unflatten_dense_tensors if not self.zenflow else zenflow_utils._unflatten_dense_tensors # ZeRO stage 1 (False) or 2 (True) self.partition_gradients = partition_grads @@ -165,14 +220,12 @@ def __init__(self, self.overlap_comm = overlap_comm - self.cpu_offload = cpu_offload - - self.deepspeed_adam_offload = cpu_offload + self.deepspeed_adam_offload = self.cpu_offload self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu' self.dp_process_group = dp_process_group - + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() #expert parallel group self.ep_process_group = expert_parallel_group @@ -190,14 +243,14 @@ def __init__(self, self.is_gradient_accumulation_boundary = True # CPU-Offload requires contiguous gradients - self.contiguous_gradients = contiguous_gradients or cpu_offload + self.contiguous_gradients = contiguous_gradients or self.cpu_offload self.has_moe_layers = has_moe_layers if self.has_moe_layers: self._configure_moe_settings() self._global_grad_norm = 0. - if mpu is None: + if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'): self.model_parallel_group = None self.model_parallel_world_size = 1 self.model_parallel_rank = 0 @@ -212,24 +265,32 @@ def __init__(self, self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = 0 + self.micro_step_id = INITIAL_MICRO_STEP_ID self.ignore_unused_parameters = ignore_unused_parameters self.round_robin_gradients = round_robin_gradients - self.extra_large_param_to_reduce = None - self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients + self.extra_large_param_to_reduce: Dict[int, torch.Tensor] = {} - if self.fp16_master_weights_and_gradients: + def _enforce_cpu_offload(): assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \ - f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."\ - f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \ - f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam." + f"Master weights feature requires {self.zero_stage_string} Offload with DeepSpeedCPUAdam. " \ + f"Current ZeRO-Offload:{self.cpu_offload} optimizer type {type(self.optimizer)}." + + self.master_weights_and_grads_dtype = self._configure_master_weights( + fp16_master_weights_and_gradients=fp16_master_weights_and_gradients, + bf16_master_weights_and_gradients=bf16_master_weights_and_gradients, + bf16_optimizer_states=bf16_optimizer_states, + offload_enabled=self.cpu_offload, + fp16_offload_validator=_enforce_cpu_offload, + bf16_offload_validator=_enforce_cpu_offload) + + self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 - if self.reduce_scatter: + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" - assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - assert self.postscale_gradients, "pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" # param flattened by groups self.bit16_groups = [] @@ -244,6 +305,10 @@ def __init__(self, # that this process will update self.single_partition_of_fp32_groups = [] + # a 16-bit CPU param buffer for cpu offload + if self.cpu_offload: + self.param_buffer_of_bit16_for_cpu_offload_groups = [] + # param partition info # These are the parameters in each group that will not be updated by this process directly @@ -252,7 +317,7 @@ def __init__(self, # These are the parameters that will be updated by this process directly self.params_in_partition = [] - # Offset from the first parameter in the the self.params_in_partition + # Offset from the first parameter in the self.params_in_partition # the parameter boundaries may not align with partition boundaries # so we need to keep track of the offset self.first_offset = [] @@ -269,9 +334,20 @@ def __init__(self, self.all_reduce_print = False self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self.gradient_accumulation_dtype = gradient_accumulation_dtype + + if self.dtype != self.gradient_accumulation_dtype: + self.use_separate_grad_accum = True + else: + self.use_separate_grad_accum = False + if self.use_separate_grad_accum and not self.partition_gradients: + self.use_grad_accum_attribute = True + else: + self.use_grad_accum_attribute = False self.round_robin_bit16_groups = [] self.round_robin_bit16_indices = [] + self.round_robin_bit16_meta = [] # Use different parallel to do all_to_all_reduce related things # padding on each partition for alignment purposes @@ -282,17 +358,40 @@ def __init__(self, # push this group to list before modify # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group - trainable_parameters = [param for param in param_group['params'] if param.requires_grad] + trainable_parameters = [] + for param in param_group['params']: + if param.requires_grad: + param.grad_accum = None + param.param_idx_in_group = len(trainable_parameters) + trainable_parameters.append(param) self.bit16_groups.append(trainable_parameters) # not sure why apex was cloning the weights before flattening # removing cloning here - see_memory_usage(f"Before moving param group {i} to CPU") - # move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.bit16_groups[i]) + # Compute group size for memory check (need 2x model size on accelerator to flatten in place: params + flat copy) + orig_group_numel = sum(param.numel() for param in self.bit16_groups[i]) + alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]) + aligned_numel = int(math.ceil(orig_group_numel / alignment)) * alignment + param_dtype = self.bit16_groups[i][0].dtype + element_size = torch.tensor([], dtype=param_dtype).element_size() + flat_buffer_bytes = aligned_numel * element_size + empty_cache() - see_memory_usage(f"After moving param group {i} to CPU", force=False) + accelerator = get_accelerator() + available_memory = accelerator.available_memory() if accelerator.is_available() else 0 + # Flatten on accelerator device if we have enough memory for the flat buffer + flatten_on_accelerator = (accelerator.is_available() and (available_memory >= flat_buffer_bytes)) + + if not flatten_on_accelerator: + see_memory_usage(f"Before moving param group {i} to CPU") + # move all the parameters to cpu to free up accelerator memory for creating flat buffer + for param in self.bit16_groups[i]: + param.cpu_data = param.data.cpu() + param.data = torch.empty(1).to(param.device) + + empty_cache() + see_memory_usage(f"After moving param group {i} to CPU", force=False) # Reorder group parameters for load balancing of gradient partitioning during backward among ranks. # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks. @@ -308,21 +407,39 @@ def __init__(self, self.round_robin_bit16_groups.append(round_robin_tensors) self.round_robin_bit16_indices.append(round_robin_indices) - # create flat buffer in CPU and move to GPU - self.bit16_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.round_robin_bit16_groups[i], - self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to( - get_accelerator().current_device_name())) - see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) - - # Record padding required for alignment - if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: - padding = self.bit16_groups_flat[i].numel() - sum( - [t.numel() for t in self.round_robin_bit16_groups[i]]) + # Create meta tensors list, ordered according to round_robin_tensors + meta_tensors = [] + for param in round_robin_tensors: + if flatten_on_accelerator: + meta_tensors.append(torch.zeros_like(param.data, device="meta")) + else: + meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta")) + self.round_robin_bit16_meta.append(meta_tensors) + + if flatten_on_accelerator: + logger.info(f"Flattening param group {i} on {accelerator.device_name()} (sufficient memory)") + flattened_buffer = self.flatten_dense_tensors_aligned(self.round_robin_bit16_groups[i], + alignment, + use_cpu_data=False).detach() + self.bit16_groups_flat.append(flattened_buffer) + see_memory_usage(f"After flattening param group {i} on {accelerator.device_name()}", force=False) else: - padding = 0 - self.groups_padding.append(padding) + logger.info(f"Flattening param group {i} on CPU (insufficient memory)") + + flattened_buffer = self.flatten_dense_tensors_aligned(self.round_robin_bit16_groups[i], + alignment, + use_cpu_data=True) + + # free temp CPU params + for param in self.bit16_groups[i]: + del param.cpu_data + + # Move CPU flat tensor to the accelerator memory. + self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name())) + del flattened_buffer + + see_memory_usage(f"After flattening and moving param group {i} to {get_accelerator().device_name()}", + force=False) if dist.get_rank(group=self.real_dp_process_group[i]) == 0: see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False) @@ -335,6 +452,18 @@ def __init__(self, data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i) self.parallel_partitioned_bit16_groups.append(data_parallel_partitions) + # Record padding required for alignment + left_boundary = sum([t.numel() for t in data_parallel_partitions[:partition_id]]) + curr_partition_size = data_parallel_partitions[partition_id].numel() + + if orig_group_numel <= left_boundary: + padding = curr_partition_size + elif orig_group_numel < left_boundary + curr_partition_size: + padding = left_boundary + curr_partition_size - orig_group_numel + else: + padding = 0 + self.groups_padding.append(padding) + # verify that data partition start locations are 4-byte aligned for partitioned_data in data_parallel_partitions: assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0) @@ -342,12 +471,24 @@ def __init__(self, # A partition of the fp32 master weights that will be updated by this process. # Note that the params in single_partition_of_fp32_groups is cloned and detached # from the origin params of the model. - if not fp16_master_weights_and_gradients: - self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to( - self.device).clone().float().detach()) - else: - self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to( - self.device).clone().half().detach()) + weights_partition = self.parallel_partitioned_bit16_groups[i][partition_id].detach().clone().to( + device=self.device, dtype=self.master_weights_and_grads_dtype) + + if self.cpu_offload: + if self.cpu_offload_pin_memory: + weights_partition = get_accelerator().pin_memory(weights_partition) + temp_dtype = self.parallel_partitioned_bit16_groups[i][partition_id].dtype + temp_buffer_bit16 = torch.full(weights_partition.shape, + fill_value=0.0, + dtype=temp_dtype, + device=weights_partition.device) + if self.cpu_offload_pin_memory: + temp_pinned = get_accelerator().pin_memory(temp_buffer_bit16) + self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_pinned) + else: + self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_buffer_bit16) + + self.single_partition_of_fp32_groups.append(weights_partition) # Set local optimizer to have flat params of its own partition. # After this, the local optimizer will only contain its own partition of params. @@ -365,20 +506,12 @@ def __init__(self, self.params_not_in_partition.append(params_not_in_partition) self.first_offset.append(first_offset) - for rank in range(dist.get_world_size()): - if dist.get_rank() == rank: - print( - f"Rank: {rank} partition count {self.partition_count} and sizes{[(p.numel(), self.is_moe_param_group[i] if hasattr(self, 'is_moe_param_group') else False) for i,p in enumerate(self.single_partition_of_fp32_groups)]} " - ) - dist.barrier() - self.reduce_bucket_size = int(reduce_bucket_size) + self.use_multi_rank_bucket_allreduce = use_multi_rank_bucket_allreduce self.allgather_bucket_size = int(allgather_bucket_size) - self.reduction_event = get_accelerator().Event(enable_timing=False, blocking=False) - self.reduction_stream = get_accelerator().Stream() - self.cpu_computation_stream = get_accelerator().Stream() - self.copy_grad_stream = get_accelerator().Stream() + self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() + #self.copy_grad_stream = get_accelerator().Stream() self.callback_queued = False self.param_dict = {} @@ -386,13 +519,19 @@ def __init__(self, # map between param_id and bool to specify if a param is in this partition self.is_param_in_current_partition = {} - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 + self.torch_autocast_gradscaler = None + if is_autocast_initialized(): + comm_dtypes = get_all_comm_dtypes([p for params in self.bit16_groups for p in params]) + if get_autocast_dtype() == torch.float16: + self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name()) + else: + comm_dtypes = {self.communication_data_type} + + self.ipg_buckets: Dict[torch.dtype, IPGBucket] = {dtype: IPGBucket() for dtype in comm_dtypes} + self.params_already_reduced = [] self._release_ipg_buffers() - self.previous_reduced_grads = None - self.ipg_bucket_has_moe_params = False + self.previous_reduced_grads: Dict[int, List[torch.Tensor]] = defaultdict(list) # simplified param id self.param_id = {} @@ -423,8 +562,12 @@ def __init__(self, self.norm_for_param_grads = {} self.local_overflow = False self.grad_position = {} - self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory( - torch.zeros(largest_param_numel, device=self.device, dtype=self.dtype)) + self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel, + device=self.device, + dtype=self.dtype) + if self.cpu_offload_pin_memory: + self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory( + self.temp_grad_buffer_for_cpu_offload) self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel, device=get_accelerator().current_device_name(), dtype=self.dtype) @@ -454,7 +597,7 @@ def __init__(self, # will store the averaged gradients required by this partition self.averaged_gradients = {} - + self.all_grad_tensors = {} # For cpu_offload, will store the averaged gradients required by this partition self.offload_gradient_dict = {} @@ -467,10 +610,17 @@ def __init__(self, # resets the data structure value for the next backward propagation self.reset_partition_gradient_structures() - # creates backward hooks for gradient partitioning - if self.partition_gradients or self.overlap_comm: - self.create_reduce_and_remove_grad_hooks() + # creates backward hooks for the following special handling of gradients + # 1. upcasting for fp32 gradient accumulation + # 2. gradient partitioning + # 3. overlapping backward and reduction + self._grad_acc_hooks = [] + if (self.partition_gradients or self.overlap_comm or self.use_grad_accum_attribute + or self.contiguous_gradients): + self.create_gradient_handling_hooks() + + self.ready_for_gradients = False self.custom_loss_scaler = False self.external_loss_scale = None @@ -481,24 +631,54 @@ def __init__(self, dynamic_loss_args=dynamic_loss_args) self.dynamic_loss_scale = self.loss_scaler.dynamic - see_memory_usage("Before initializing optimizer states", force=True) + if self.dtype != torch.float16: + # Only fp16 should use dynamic loss scaling + assert self.loss_scaler.cur_scale == 1.0 + assert not self.dynamic_loss_scale + + see_memory_usage("Before initializing optimizer states", force=False) self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=True) + see_memory_usage("After initializing optimizer states", force=False) if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") + logger.info("optimizer state initialized") if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After initializing ZeRO optimizer", force=True) + see_memory_usage("After initializing ZeRO optimizer", force=False) self._link_all_hp_params() + self._hp_optimizer_states_linked = False + self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() + if self.cpu_offload: + self._create_optimizer_mapping() + + self.offloaded_states: Set[OffloadStateTypeEnum] = set() + + def destroy(self): + for i, _ in enumerate(self.optimizer.param_groups): + for p in self.bit16_groups[i]: + if getattr(p, '_hp_mapping', None): + p._hp_mapping = None + for hook in self._grad_acc_hooks: + hook.remove() + self.print_rank_0("Removed grad acc hooks") def _enable_universal_checkpoint(self): + self._universal_checkpoint_info = None for lp_param_group in self.bit16_groups: + if self._universal_checkpoint_info is None: + for param in lp_param_group: + autotp_uc_info = getattr(param, UNIVERSAL_CHECKPOINT_INFO, None) + if autotp_uc_info is not None: + self._universal_checkpoint_info = autotp_uc_info + break enable_universal_checkpoint(param_list=lp_param_group) + def _get_universal_checkpoint_info(self): + return getattr(self, '_universal_checkpoint_info', None) + def _create_param_mapping(self): param_mapping = [] for i, _ in enumerate(self.optimizer.param_groups): @@ -511,15 +691,21 @@ def _create_param_mapping(self): return param_mapping + def _create_optimizer_mapping(self): + for i, _ in enumerate(self.optimizer.param_groups): + for lp in self.bit16_groups[i]: + if lp._hp_mapping is not None: + lp._zero_optimizer = self + def _link_all_hp_params(self): - dp_world_size = dist.get_world_size(group=self.dp_process_group) if self.cpu_offload: self._get_offload_gradient_dict() for i, _ in enumerate(self.optimizer.param_groups): # Link bit16 and fp32 params in partition partition_id = dist.get_rank(group=self.real_dp_process_group[i]) - partition_size = self.bit16_groups_flat[i].numel() // dp_world_size + partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size( + group=self.real_dp_process_group[i]) flat_hp_partition = self.single_partition_of_fp32_groups[i] link_hp_params(lp_param_list=self.bit16_groups[i], flat_hp_partition=flat_hp_partition, @@ -529,9 +715,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def is_moe_group(self, group): return 'moe' in group and group['moe'] @@ -541,7 +733,7 @@ def _configure_moe_settings(self): assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion if not self.partition_gradients and not self.contiguous_gradients: - logger.warn( + logger.warning( "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.") assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" @@ -563,8 +755,7 @@ def _configure_moe_settings(self): assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE" def _update_model_bit16_weights(self, group_index): - updated_params = self.unflatten(self.bit16_groups_flat[group_index], - self.round_robin_bit16_groups[group_index]) + updated_params = self.unflatten(self.bit16_groups_flat[group_index], self.round_robin_bit16_meta[group_index]) for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params): p.data = q.data @@ -582,7 +773,7 @@ def _round_robin_reorder(self, tensor_list, num_partitions): for i, tensor in enumerate(tensor_list): j = i % num_partitions - if not j in partition_tensors: + if j not in partition_tensors: partition_tensors[j] = [] partition_tensors[j].append((i, tensor)) @@ -598,9 +789,12 @@ def _round_robin_reorder(self, tensor_list, num_partitions): def _release_ipg_buffers(self): if self.contiguous_gradients: - self.ipg_buffer = None + for bucket in self.ipg_buckets.values(): + bucket.buffer.clear() + self.grads_in_partition = None self.grads_in_partition_offset = 0 + self.ready_for_gradients = False def initialize_optimizer_states(self): @@ -609,9 +803,13 @@ def initialize_optimizer_states(self): dtype=self.single_partition_of_fp32_groups[i].dtype, device=self.device) self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory( - single_grad_partition) if self.cpu_offload else single_grad_partition + single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition - self.optimizer.step() + # Initialize the optimizer states with the flattened fp32 partition. + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. + if isinstance(self.optimizer, torch.optim.Adagrad): + self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: @@ -628,17 +826,18 @@ def reduce_gradients(self, pipeline_parallel=False): # with PP we must create ipg buffer, since backward is handled outside zero if pipeline_parallel and self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_0) - self.ipg_index = 0 + for dtype, bucket in self.ipg_buckets.items(): + bucket.buffer.append( + torch.empty(int(self.reduce_bucket_size), + dtype=dtype, + device=get_accelerator().current_device_name())) + bucket.index = 0 if not self.overlap_comm: for i, group in enumerate(self.bit16_groups): for param in group: - if param.grad is not None: + grad_reduc = self.get_gradient_for_reduction(param) + if grad_reduc is not None: self.reduce_ready_partitions_and_remove_grads(param, i) # reduce any pending grads in either hook/non-hook case self.overlapping_partition_gradients_reduce_epilogue() @@ -650,8 +849,9 @@ def reduce_gradients(self, pipeline_parallel=False): def get_first_param_index(self, group_id, param_group, partition_id): for index, param in enumerate(param_group): param_id = self.get_param_id(param) - if partition_id in self.param_to_partition_ids[group_id][param_id]: - return index + if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]: + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index return None def initialize_gradient_partitioning_data_structures(self): @@ -679,9 +879,9 @@ def initialize_gradient_partitioning_data_structures(self): i, param_group, partition_id) def independent_gradient_partition_epilogue(self): - self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.report_ipg_memory_usage("In ipg_epilogue before reduce_ipg_grads", 0) self.reduce_ipg_grads() - self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + self.report_ipg_memory_usage("In ipg_epilogue after reduce_ipg_grads", 0) # if dist.get_rank() == 0: # logger.info("Params already reduced %s", self.params_already_reduced) @@ -689,39 +889,64 @@ def independent_gradient_partition_epilogue(self): self.params_already_reduced[i] = False if self.overlap_comm: - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() # It is safe to clear previously reduced grads of other partitions self._clear_previous_reduced_grads() if self.cpu_offload is False: for i, _ in enumerate(self.bit16_groups): - - if not i in self.averaged_gradients or self.averaged_gradients[i] is None: + if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None: + self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i], + dtype=self.gradient_accumulation_dtype) + else: + avg_new = self.get_all_grad_tensors(self.params_in_partition[i], + dtype=self.gradient_accumulation_dtype) + for accumulated_grad, new_avg_grad in zip(self.all_grad_tensors[i], avg_new): + accumulated_grad.add_(new_avg_grad) + if self.is_gradient_accumulation_boundary: self.averaged_gradients[i] = self.get_flat_partition( self.params_in_partition[i], self.first_offset[i], self.partition_size[i], - dtype=self.dtype, + dtype=self.gradient_accumulation_dtype, device=get_accelerator().current_device_name(), + param_group_idx=i, return_tensor_list=True) - else: - avg_new = self.get_flat_partition(self.params_in_partition[i], - self.first_offset[i], - self.partition_size[i], - dtype=self.dtype, - device=get_accelerator().current_device_name(), - return_tensor_list=True) - - for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new): - accumulated_grad.add_(new_avg_grad) + # Clear all_grad_tensors after use. With reentrant checkpointing, + # the epilogue may run multiple times per backward pass. Each time, + # we read the cumulative grad_accum (which PyTorch naturally accumulates) + # and the final phase will have all gradients. + self.all_grad_tensors[i] = None self._release_ipg_buffers() - # No need to keep the gradients anymore. - # All gradients required by the step - # are in self.averaged_gradients - self.zero_grad(set_to_none=True) - see_memory_usage(f"End ipg_epilogue") + # Clear param.grad so safe_get_full_grad() goes through the proper _hp_mapping + # path (which does all_reduce for ZeRO-2). Keep grad_accum intact for reentrant + # checkpointing where gradients need to accumulate across multiple phases. + # grad_accum is cleared in clear_backward_seen_flag() at the start of next forward. + self._clear_param_grad_only() + self._epilogue_ran_this_backward = True + + see_memory_usage("End ipg_epilogue") + + def clear_backward_seen_flag(self): + """Clear the backward seen flag and do deferred cleanup. + + With reentrant gradient checkpointing, the epilogue may run multiple times + per backward pass (once per phase). We defer clearing grad_accum until here + (called at the start of the next forward) to ensure all phases have completed. + + Note: param.grad is cleared in the epilogue via _clear_param_grad_only() to + ensure safe_get_full_grad() works correctly. Only grad_accum is deferred. + """ + if self._epilogue_ran_this_backward: + # Clear grad_accum for next step. param.grad is already cleared in epilogue. + for group in self.bit16_groups: + for p in group: + p.grad_accum = None + + super().clear_backward_seen_flag() # resets all partition to no reduced # sets remaining grads to the total number of grads in each partition @@ -763,7 +988,7 @@ def increment_value(dictionary, key): param_size = param.numel() param_id = self.get_param_id(param) - if (current_index >= start_index and current_index < end_index): + if start_index <= current_index < end_index: set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id) increment_value(self.total_grads_in_partition[i], partition_id) @@ -772,7 +997,7 @@ def increment_value(dictionary, key): self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index self.grad_start_offset[i][partition_id][param_id] = 0 - elif start_index > current_index and start_index < (current_index + param_size): + elif current_index < start_index < (current_index + param_size): assert (first_offset == 0 ), "This can happen either zero or only once as this must be the first tensor in the partition" first_offset = start_index - current_index @@ -790,74 +1015,122 @@ def increment_value(dictionary, key): def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() - def create_reduce_and_remove_grad_hooks(self): - self.grad_accs = [] + def _fill_param_grad_accum_attribute(self, param): + if param.grad is not None: + if param.grad_accum is None: + param.grad_accum = param.grad.to(self.gradient_accumulation_dtype) + else: + param.grad_accum.add_(param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape)) + param.grad = None + + def fill_grad_accum_attribute(self): + for group in self.bit16_groups: + for param in group: + self._fill_param_grad_accum_attribute(param) + + def get_gradient_for_reduction(self, param): + if self.use_grad_accum_attribute: + return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None + else: + return param.grad + + def get_param_gradient_attribute(self, param): + return param.grad_accum if self.use_grad_accum_attribute else param.grad + + # Clear the tensor the reduction gradient attribute is pointing to + def clear_grad_attribute(self, param): + if self.use_grad_accum_attribute: + param.grad_accum = None + else: + param.grad = None + + def create_gradient_handling_hooks(self): + all_params_requiring_grad = [] + + for i, param_group in enumerate(self.bit16_groups): + for param in param_group: + if param.requires_grad: + all_params_requiring_grad.append(param) + for i, param_group in enumerate(self.bit16_groups): for param in param_group: if param.requires_grad: def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) + def grad_handling_hook(*notneeded): + # Evaluate refresh condition before reenter_backward_if_needed() + refresh_expected = self.should_refresh_expected_hook_count() + self.reenter_backward_if_needed() + self.process_gradients(param, i) + if refresh_expected: + current_expected = count_used_parameters_in_backward(all_params_requiring_grad) + else: + current_expected = self._max_expected_hooks_seen + self.update_hook_state_and_maybe_run_epilogue(current_expected) - grad_acc.register_hook(reduce_partition_and_remove_grads) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook)) wrapper(param, i) + self._remaining_grad_acc_hooks = 0 + def get_param_id(self, param): unique_id = id(param) return self.param_id[unique_id] - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}" - ) - # create a flat tensor aligned at the alignment boundary - def flatten_dense_tensors_aligned(self, tensor_list, alignment): + def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=False): + tensor_list = [param.cpu_data for param in tensor_list] if use_cpu_data else tensor_list return self.flatten(align_dense_tensors(tensor_list, alignment)) ############### Independent Partition Gradient ######################## def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): - if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size: + + grad_reduc = self.get_gradient_for_reduction(param) + comm_dtype = self.get_param_comm_dtype(param) + bucket = self.ipg_buckets[comm_dtype] + if bucket.elements + param.numel() > self.reduce_bucket_size: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) - self.reduce_ipg_grads() + self.reduce_ipg_grads(comm_dtype=comm_dtype) if self.contiguous_gradients and self.overlap_comm: - # Swap ipg_index between 0 and 1 - self.ipg_index = 1 - self.ipg_index + # Swap index between 0 and 1 + bucket.index = 1 - bucket.index self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel()) + # deal with a use-case of transient grads that will be generated in a loop for the same computation involving some model params - e.g. when performing a tiled memory calculation that shards the normal single sub-module call into a loop over a shards. + if not getattr(param, "ds_grad_is_ready", True): + return + param_id = self.get_param_id(param) assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ + f"The parameter {debug_param2name(param)} has already been reduced. \ Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - if param.numel() > self.reduce_bucket_size: - self.extra_large_param_to_reduce = param + Multiple gradient reductions are currently not supported" - elif self.contiguous_gradients: - # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening - new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel()) - new_grad_tensor.copy_(param.grad.view(-1)) - param.grad.data = new_grad_tensor.data.view_as(param.grad) + if self.contiguous_gradients: + if param.numel() > self.reduce_bucket_size: + self.extra_large_param_to_reduce[comm_dtype] = param + else: + # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening + new_grad_tensor = bucket.buffer[bucket.index].narrow(0, bucket.elements, param.numel()) + new_grad_tensor.copy_( + grad_reduc.view(-1) if not self.zenflow else grad_reduc.permute( + *reversed(range(grad_reduc.ndim))).contiguous().view(-1)) + grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) if ( + not self.zenflow or grad_reduc.dim() == 1) else new_grad_tensor.data.view_as( + grad_reduc.transpose(0, 1)) - self.elements_in_ipg_bucket += param.numel() + bucket.elements += param.numel() - assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" + assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" - self.grads_in_ipg_bucket.append(param.grad) - self.params_in_ipg_bucket.append((i, param, param_id)) + bucket.grads.append(grad_reduc) + bucket.params.append((i, param.param_idx_in_group, param_id)) #make sure the average tensor function knows how to average the gradients if is_moe_param(param): - self.ipg_bucket_has_moe_params = True + bucket.has_moe_params = True self.report_ipg_memory_usage("End ipg_remove_grads", 0) @@ -865,14 +1138,16 @@ def print_rank_0(self, message): if dist.get_rank() == 0: logger.info(message) - def gradient_reduction_w_predivide(self, tensor): + def gradient_reduction_w_predivide(self, tensor, communication_data_type: torch.dtype): + if tensor.size().numel() == 0: + return tensor dp_world_size = dist.get_world_size(group=self.dp_process_group) tensor_to_allreduce = tensor - if self.communication_data_type != tensor.dtype: - tensor_to_allreduce = tensor.to(self.communication_data_type) + if communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(communication_data_type) if self.postscale_gradients: if self.gradient_predivide_factor != 1.0: @@ -881,26 +1156,86 @@ def gradient_reduction_w_predivide(self, tensor): dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size) + tensor_to_allreduce.mul_(self.gradient_predivide_factor / + (dp_world_size / float(self.sequence_parallel_size))) else: - tensor_to_allreduce.div_(dp_world_size) + tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: tensor.copy_(tensor_to_allreduce) return tensor - def average_tensor(self, tensor): + def allreduce_and_copy_with_multiple_ranks(self, + small_bucket, + communication_data_type: torch.dtype, + log=None, + divide=True, + process_group=None, + bucket_ranks=None): + process_group = self.dp_process_group if process_group is None else process_group + allreduced = self.allreduce_bucket(small_bucket, + communication_data_type, + log=log, + divide=divide, + process_group=process_group) + if self.overlap_comm and not get_accelerator().resolves_data_dependency(): + allreduced.record_stream(self.reduction_stream) + for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks): + if dist.get_rank(group=process_group) == bucket_rank: + buf.copy_(synced) + if self.overlap_comm and not get_accelerator().resolves_data_dependency(): + buf.record_stream(self.reduction_stream) + + def allreduce_and_scatter(self, + bucket, + communication_data_type: torch.dtype, + numel_per_bucket=500000000, + log=None, + divide=True, + process_group=None): + small_bucket = [] + small_bucket_ranks = [] + numel = 0 + allreduce_sizes = [] + + for i, bucket_elem in enumerate(bucket): + rank, tensor = bucket_elem + small_bucket.append(tensor) + small_bucket_ranks.append(rank) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy_with_multiple_ranks(small_bucket, + communication_data_type, + log=None, + divide=divide, + process_group=process_group, + bucket_ranks=small_bucket_ranks) + small_bucket = [] + small_bucket_ranks = [] + numel = 0 + + if len(small_bucket) > 0: + self.allreduce_and_copy_with_multiple_ranks(small_bucket, + communication_data_type, + log=None, + divide=divide, + process_group=process_group, + bucket_ranks=small_bucket_ranks) + + def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype): if self.overlap_comm: stream = self.reduction_stream - stream.wait_stream(get_accelerator().current_stream()) + if not get_accelerator().resolves_data_dependency(): + stream.wait_stream(get_accelerator().current_stream()) + get_accelerator().current_stream().wait_stream(stream) else: stream = get_accelerator().current_stream() with get_accelerator().stream(stream): if not self.reduce_scatter: - self.gradient_reduction_w_predivide(tensor) + self.gradient_reduction_w_predivide(tensor, communication_data_type) return # Accumulate destination ranks and bucket offsets for each gradient slice. @@ -910,20 +1245,19 @@ def average_tensor(self, tensor): rank_and_offsets = [] real_dp_process_group = [] curr_size = 0 - prev_id = -1 + prev_id, prev_process_group = -1, None process_group = self.dp_process_group # count = 0 - for i, param, param_id in self.params_in_ipg_bucket: + bucket = self.ipg_buckets[communication_data_type] + for i, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[i][param_idx_in_group] process_group = self.dp_process_group - #Averages gradients at parameter level if ipg has a moe param - #Otherwise averaging is done at the entire buffer level at the end of the loop - # MoE param have different groups - if self.ipg_bucket_has_moe_params: + + if bucket.has_moe_params: process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( param) else self.dp_process_group - param.grad.data.div_(dist.get_world_size(group=process_group)) partition_ids = self.param_to_partition_ids[i][param_id] assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids @@ -953,38 +1287,44 @@ def average_tensor(self, tensor): numel = partition_ids_w_offsets[idx + 1][1] - offset # Merge bucket ranges if they belong to the same rank - if partition_id == prev_id: + if partition_id == prev_id and process_group == prev_process_group: prev_pid, prev_size, prev_numel = rank_and_offsets[-1] rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel) else: rank_and_offsets.append((partition_id, curr_size, numel)) real_dp_process_group.append(process_group) curr_size += numel - prev_id = partition_id - - if not self.ipg_bucket_has_moe_params: - tensor.div_(dist.get_world_size(group=self.dp_process_group)) + prev_id, prev_process_group = partition_id, process_group - tensor_to_reduce = tensor - if self.communication_data_type != tensor.dtype: - tensor_to_reduce = tensor.to(self.communication_data_type) + tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) - async_handles = [] + buckets = {} for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets): - grad_slice = tensor_to_reduce.narrow(0, int(bucket_offset), int(numel)) - # if dist.get_rank() == 0: - # print(f"Rank {dist.get_rank()} rank offset id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}") - # dist.barrier() - #dist.barrier() - dst_rank = dist.get_global_rank(real_dp_process_group[i], dst) - async_handle = dist.reduce(grad_slice, dst=dst_rank, group=real_dp_process_group[i], async_op=True) - async_handles.append(async_handle) - - for handle in async_handles: - handle.wait() - - if self.communication_data_type != tensor.dtype: - tensor.copy_(tensor_to_reduce) + grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) + bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else ( + dst, real_dp_process_group[i]) + if bucket_key not in buckets: + buckets[bucket_key] = [] + if self.use_multi_rank_bucket_allreduce: + buckets[bucket_key].append((dst, grad_slice)) + else: + buckets[bucket_key].append(grad_slice) + + for bucket_key in buckets: + if self.use_multi_rank_bucket_allreduce: + self.allreduce_and_scatter(buckets[bucket_key], + communication_data_type, + numel_per_bucket=self.reduce_bucket_size, + divide=False, + process_group=bucket_key) + else: + dst, process_group = bucket_key + self.allreduce_no_retain(buckets[bucket_key], + communication_data_type, + numel_per_bucket=self.reduce_bucket_size, + rank=dst, + divide=False, + process_group=process_group) ############################################################################## ############################# CPU Offload Methods############################# @@ -1014,10 +1354,14 @@ def get_grad_position(self, group_id, tensor_list, first_offset, partition_size) ] current_offset += num_elements - def update_overflow_tracker_for_param_grad(self, param): - if param.grad is not None and self._has_inf_or_nan(param.grad.data): + def update_offload_overflow_tracker(self, grad): + if grad is not None and self._has_inf_or_nan(grad.data): self.local_overflow = True + def update_offload_overflow_tracker_for_param_grad(self, param): + grad_accum = self.get_param_gradient_attribute(param) + self.update_offload_overflow_tracker(grad_accum) + def _get_offload_gradient_dict(self): for param_group_index, _ in enumerate(self.optimizer.param_groups): self.offload_gradient_dict[param_group_index] = [] @@ -1038,29 +1382,32 @@ def async_accumulate_grad_in_cpu_via_gpu(self, param): #buffer for storing gradients for this parameter in CPU def buffer_to_accumulate_to_in_cpu(): - if not self.fp16_master_weights_and_gradients: - return get_accelerator().pin_memory(torch.zeros(param.numel(), dtype=param.dtype, device=self.device)) + if not self.low_precision_master_weights_and_grads: + buffer = torch.zeros(param.numel(), dtype=param.dtype, device=self.device) + return get_accelerator().pin_memory(buffer) if self.cpu_offload_pin_memory else buffer else: return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) - #accumulate gradients into param.grad or parts of it that belongs to this partition + #accumulate gradients into param.grad_accum or parts of it that belongs to this partition def accumulate_gradients(): - if not self.fp16_master_weights_and_gradients: + grad_accum = self.get_param_gradient_attribute(param) + if not self.low_precision_master_weights_and_grads: dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) - param.grad.data.view(-1).add_(dest_buffer) + grad_accum.data.view(-1).add_(dest_buffer) else: dest_buffer.narrow(0, source_offset, num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) - param.grad.data.view(-1).narrow(0, source_offset, + grad_accum.data.view(-1).narrow(0, source_offset, num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements)) #move accumulated gradients back to CPU def copy_gradients_to_cpu(): - if not self.fp16_master_weights_and_gradients: - self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1), non_blocking=True) + grad_accum = self.get_param_gradient_attribute(param) + if not self.low_precision_master_weights_and_grads: + self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True) else: - self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1).narrow( + self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow( 0, source_offset, num_elements), non_blocking=True) @@ -1069,15 +1416,13 @@ def copy_gradients_to_cpu(): if self.micro_step_id > 0: accumulate_gradients() - - # at the boundary we will send 32bit directly - if not self.is_gradient_accumulation_boundary: - copy_gradients_to_cpu() + copy_gradients_to_cpu() def set_norm_for_param_grad(self, param): param_id = self.get_param_id(param) + grad_accum = self.get_param_gradient_attribute(param) accumulated_grad = self.accumulated_grads_in_cpu[ - param_id] if self.gradient_accumulation_steps > 1 else param.grad + param_id] if self.gradient_accumulation_steps > 1 else grad_accum [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] @@ -1088,7 +1433,11 @@ def set_norm_for_param_grad(self, param): def set_norm_for_param_grad_in_gpu(self, param): param_id = self.get_param_id(param) - accumulated_grad = param.grad + grad_accum = self.get_param_gradient_attribute(param) + if grad_accum is None: + accumulated_grad = param.grad + else: + accumulated_grad = grad_accum [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] @@ -1104,12 +1453,15 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) - src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements) - if not self.fp16_master_weights_and_gradients: - src_tensor = src_tensor.float() + grad_accum = self.get_param_gradient_attribute(param) + assert grad_accum is not None + + src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) + if src_tensor.dtype != self.master_weights_and_grads_dtype: + src_tensor = src_tensor.to(self.master_weights_and_grads_dtype) dest_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None #offload only + self.clear_grad_attribute(param) #offload only def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm = 0.0 @@ -1122,7 +1474,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) # as some model have trainable parameters but skipped in training, - # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, + # their backward hooks in self.create_gradient_handling_hooks() will not run, # so they have no norm_for_param_grads if param_id in self.norm_for_param_grads: param_norm = self.norm_for_param_grads[param_id] @@ -1141,29 +1493,32 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): """ # Sum across all model parallel GPUs. - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + total_dev_norm = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group) - self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_dev_norm[0].item()**(1. / norm_type) if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + total_norm = -1.0 - return total_norm + return torch.tensor(total_norm, device=self.device, dtype=torch.float) ############################################################################################ def copy_grads_in_partition(self, param): if self.cpu_offload: - - if self.gradient_accumulation_steps > 1: + # Accumulate when there were prior backwards in this step (restore from + # CPU buffer) or more will follow (save to CPU buffer). Skipping only + # the lone backward of a step preserves the existing fast path for + # ga_steps=1 + single backward. + if self.micro_step_id > 0 or not self.is_gradient_accumulation_boundary: self.async_accumulate_grad_in_cpu_via_gpu(param) if self.is_gradient_accumulation_boundary: self.set_norm_for_param_grad_in_gpu(param) - self.update_overflow_tracker_for_param_grad(param) + self.update_offload_overflow_tracker_for_param_grad(param) self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) @@ -1182,28 +1537,41 @@ def copy_grads_in_partition(self, param): device=get_accelerator().current_device_name()) see_memory_usage(f"after copying {total_size} gradients into partition") + grad_reduc = self.get_gradient_for_reduction(param) # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel()) - new_grad_tensor.copy_(param.grad.view(-1)) - param.grad.data = new_grad_tensor.data.view_as(param.grad) + new_grad_tensor.copy_(grad_reduc.view(-1)) + grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}") self.grads_in_partition_offset += param.numel() - def reduce_ipg_grads(self): - if self.contiguous_gradients: - if self.extra_large_param_to_reduce is not None: - assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen" - _, _, param_id = self.params_in_ipg_bucket[0] - assert self.get_param_id(self.extra_large_param_to_reduce - ) == param_id, "param in ipg bucket does not match extra-large param" - self.average_tensor(self.extra_large_param_to_reduce.grad.view(-1)) - self.extra_large_param_to_reduce = None + def reduce_ipg_grads(self, comm_dtype=None): + dtypes = sort_dtypes(self.ipg_buckets.keys()) + if comm_dtype is not None: + dtypes = [comm_dtype] + for comm_dtype in dtypes: + bucket = self.ipg_buckets[comm_dtype] + + if self.contiguous_gradients: + if comm_dtype in self.extra_large_param_to_reduce: + assert len(bucket.params) == 1, "more than 1 param in ipg bucket, this shouldn't happen" + _, _, param_id = bucket.params[0] + assert self.get_param_id(self.extra_large_param_to_reduce[comm_dtype] + ) == param_id, "param in ipg bucket does not match extra-large param" + extra_large_grad_reduc = self.get_gradient_for_reduction( + self.extra_large_param_to_reduce[comm_dtype]) + + extra_large_grad_reduc_for_average = extra_large_grad_reduc.view(-1) if not self.zenflow \ + else extra_large_grad_reduc.permute(*reversed(range(extra_large_grad_reduc.ndim))).contiguous().view(-1) + extra_large_grad_reduc.data = extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc) if (not self.zenflow or self.extra_large_param_to_reduce[comm_dtype].dim() == 1) \ + else extra_large_grad_reduc_for_average.data.view_as(extra_large_grad_reduc.transpose(0, 1)) + + self.average_tensor(extra_large_grad_reduc_for_average, comm_dtype) + del self.extra_large_param_to_reduce[comm_dtype] + else: + self.average_tensor(bucket.buffer[bucket.index].narrow(0, 0, bucket.elements), comm_dtype) else: - self.average_tensor(self.ipg_buffer[self.ipg_index]) - else: - self.buffered_reduce_fallback(None, - self.grads_in_ipg_bucket, - elements_per_buffer=self.elements_in_ipg_bucket) + self.buffered_reduce_fallback(None, bucket.grads, comm_dtype, elements_per_buffer=bucket.elements) if self.overlap_comm: stream = self.reduction_stream @@ -1216,39 +1584,43 @@ def reduce_ipg_grads(self): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - for _, param, param_id in self.params_in_ipg_bucket: - - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - self.params_already_reduced[param_id] = True - - if self.partition_gradients: - if not self.is_param_in_current_partition[param_id]: - if self.overlap_comm and self.contiguous_gradients is False: - # Clear grads of other partitions during the next reduction - # to avoid clearing them before the reduction is complete. - if self.previous_reduced_grads is None: - self.previous_reduced_grads = [] - self.previous_reduced_grads.append(param) - else: - param.grad = None #only if self.partition_gradients - elif self.contiguous_gradients: - self.copy_grads_in_partition(param) - else: # zero stage 1 - partition only optimizer state - if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: - self.copy_grads_in_partition(param) - - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.ipg_bucket_has_moe_params = False - self.elements_in_ipg_bucket = 0 + for comm_dtype in dtypes: + bucket = self.ipg_buckets[comm_dtype] + + for group_idx, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[group_idx][param_idx_in_group] + + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {debug_param2name(param)} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + self.params_already_reduced[param_id] = True + if self.partition_gradients: + if not self.is_param_in_current_partition[param_id]: + if self.overlap_comm and self.contiguous_gradients is False: + # Clear grads of other partitions during the next reduction + # to avoid clearing them before the reduction is complete. + self.previous_reduced_grads[comm_dtype].append(param) + else: + self.clear_grad_attribute(param) + elif self.contiguous_gradients: + self.copy_grads_in_partition(param) + else: # zero stage 1 - partition only optimizer state + if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: + self.copy_grads_in_partition(param) + bucket.clear() ##################################################################### + def process_gradients(self, param, i): + self.setup_buckets() + if self.use_grad_accum_attribute: + self._fill_param_grad_accum_attribute(param) + if self.partition_gradients or self.overlap_comm: + self.reduce_ready_partitions_and_remove_grads(param, i) + def reduce_ready_partitions_and_remove_grads(self, param, i): - if self.partition_gradients or self.is_gradient_accumulation_boundary: + if self.partition_gradients or self.is_gradient_accumulation_boundary or self.zenflow: self.reduce_independent_p_g_buckets_and_remove_grads(param, i) def zero_reduced_gradients(self, partition_id, i): @@ -1310,48 +1682,63 @@ def set_none_gradients_to_zero(self, i, partition_id): for param_id in self.is_grad_computed[i][partition_id]: param = self.param_dict[param_id] if param.grad is None: - param.grad = torch.zero_like(param) + param.grad = torch.zeros_like(param) ######################Reduction Related Methods############################## - def allreduce_bucket(self, bucket, rank=None, log=None): - rank = None + def allreduce_bucket(self, + bucket, + communication_data_type: torch.dtype, + rank=None, + log=None, + divide=True, + process_group=None): + tensor = self.flatten(bucket) + process_group = self.dp_process_group if process_group is None else process_group + tensor_to_allreduce = tensor - if pg_correctness_test: + if pg_correctness_test or self.sequence_parallel_size > 1: communication_data_type = torch.float32 - else: - communication_data_type = self.communication_data_type if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + if divide: + tensor_to_allreduce.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size)) if rank is None: # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + dist.all_reduce(tensor_to_allreduce, group=process_group) else: - global_rank = dist.get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + global_rank = dist.get_global_rank(process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=process_group) if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): + if rank is None or rank == dist.get_rank(group=process_group): tensor.copy_(tensor_to_allreduce) return tensor def _clear_previous_reduced_grads(self): - if self.previous_reduced_grads is not None: - for param in self.previous_reduced_grads: - param.grad = None # overlap enabled - self.previous_reduced_grads = None + for dtype in self.previous_reduced_grads: + for param in self.previous_reduced_grads[dtype]: + self.clear_grad_attribute(param) + self.previous_reduced_grads[dtype].clear() # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): + def allreduce_and_copy(self, + small_bucket, + communication_data_type: torch.dtype, + rank=None, + log=None, + divide=True, + process_group=None): + process_group = self.dp_process_group if process_group is None else process_group if self.overlap_comm: - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() # It is safe to clear the previously reduced grads of other partitions self._clear_previous_reduced_grads() stream = self.reduction_stream @@ -1359,31 +1746,71 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + allreduced = self.allreduce_bucket( + small_bucket, + communication_data_type, + rank=rank, + log=log, + divide=divide, + process_group=process_group, + ) + if self.overlap_comm and not get_accelerator().resolves_data_dependency(): + allreduced.record_stream(stream) if rank is None or rank == dist.get_rank(group=self.dp_process_group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) - - def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None): + if self.overlap_comm and not get_accelerator().resolves_data_dependency(): + buf.record_stream(stream) + + def allreduce_no_retain( + self, + bucket, + communication_data_type: torch.dtype, + numel_per_bucket=500000000, + rank=None, + log=None, + divide=True, + process_group=None, + ): small_bucket = [] numel = 0 for tensor in bucket: small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) + self.allreduce_and_copy(small_bucket, + communication_data_type, + rank=rank, + log=None, + divide=divide, + process_group=process_group) small_bucket = [] + numel = 0 if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) + self.allreduce_and_copy(small_bucket, + communication_data_type, + rank=rank, + log=log, + divide=divide, + process_group=process_group) # allows using reduction of gradients instead of using all_reduce - def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None): + def buffered_reduce_fallback(self, + rank, + grads, + communication_data_type: torch.dtype, + elements_per_buffer=500000000, + log=None): split_buckets = split_half_float_double(grads) for i, bucket in enumerate(split_buckets): - self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log) + self.allreduce_no_retain(bucket, + communication_data_type, + numel_per_bucket=elements_per_buffer, + rank=rank, + log=log) ############################################################################# ############################################################################# @@ -1425,10 +1852,10 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): tensor_size = tensor.numel() - if (current_index >= start_index and current_index < end_index): + if start_index <= current_index < end_index: params_in_partition.append(tensor) - elif start_index > current_index and start_index < (current_index + tensor_size): + elif current_index < start_index < (current_index + tensor_size): params_in_partition.append(tensor) assert (first_offset == 0 @@ -1442,21 +1869,35 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset - def zero_grad(self, set_to_none=False): + def zero_grad(self, set_to_none=True): """ Zero FP16 parameter grads. """ # FP32 grad should never exist. # For speed, set model fp16 grad to None by default + # zero all pointers to grad tensors for group in self.bit16_groups: for p in group: if set_to_none: p.grad = None # epilogue and in step + p.grad_accum = None else: if p.grad is not None: p.grad.detach_() p.grad.zero_() + def _clear_param_grad_only(self): + """Clear only param.grad but keep grad_accum intact. + + This is used at the end of the epilogue to ensure safe_get_full_grad() goes + through the proper _hp_mapping path (which does all_reduce for ZeRO-2), while + preserving grad_accum for reentrant checkpointing where gradients need to + accumulate across multiple backward phases. + """ + for group in self.bit16_groups: + for p in group: + p.grad = None + def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. """ @@ -1483,16 +1924,16 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): Total norm of the parameters (viewed as a single vector). """ norm_type = float(norm_type) + all_norms = [] if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) + for g in gradients: + all_norms.append(g.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=self.dp_process_group) # Take max across all GPUs. - self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() + self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.MAX) else: - total_norm = 0.0 # if dist.get_rank() == 0: # logger.info(f"Total Norm beginning {total_norm}") for g, p in zip(gradients, params): @@ -1500,33 +1941,80 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 - # Sum across all model parallel GPUs. - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + all_norms.append( + torch.linalg.vector_norm(g.data.double().detach(), + ord=norm_type).to(get_accelerator().current_device_name())) + if len(all_norms) > 0: + total_norm = torch.stack(all_norms).square().sum().float() + else: + total_norm = torch.tensor(0.0, dtype=torch.float32).to(self.device) + # Sum across all model parallel Device. + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group) - self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm.pow(1. / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + mask_nan_or_inf_with_val_inplace(total_norm, device=self.device) return total_norm + def get_all_grad_tensors(self, tensor_list, dtype): + all_grad_tensors = [] + for i, tensor in enumerate(tensor_list): + grad_accum = self.get_param_gradient_attribute(tensor) + if grad_accum is None: + grad_accum = torch.zeros_like(tensor, dtype=dtype) + all_grad_tensors.append(grad_accum) + return all_grad_tensors + # creates a flat fused tensor from the tensor list starting at the first_offset # in the first tensor of the list. If there are not enough elements in the tensor # list then the flat tensor will be padded with zeros - def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False): + def get_flat_partition(self, + tensor_list, + first_offset, + partition_size, + dtype, + device, + param_group_idx, + return_tensor_list=False): + if len(tensor_list) == 0: + # This condition can fire when we have small parameteters and many ranks. + zero_buffer = torch.zeros(int(partition_size), dtype=dtype, device=device) + if return_tensor_list: + return [zero_buffer] + return zero_buffer + flat_tensor_list = [] current_size = 0 + # find the flatten copy in the optimizer's state + flatten_copy = self.optimizer.param_groups[param_group_idx]['params'][0] + if (not self.optimizer.state[flatten_copy]) and getattr( + tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower(): + self.optimizer.state[flatten_copy] = {} + if "momentum_buffer" not in self.optimizer.state[flatten_copy] and getattr( + tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower(): + # need to check the total # of elements in the parameters in this group and this partition + total_size = sum([t.numel() for t in tensor_list]) + flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)] + self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list) + + buffer_idx = 0 for i, tensor in enumerate(tensor_list): - if tensor.grad is None: - tensor.grad = torch.zeros_like(tensor) - - tensor = tensor.grad + grad_accum = self.all_grad_tensors[param_group_idx][i] + if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower(): + assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}" + buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx, + tensor.numel()).view(tensor.size()) + ns_method = self.optimizer.param_groups[param_group_idx].get('ns_method', 'gram') + grad_accum = muon_update(grad_accum, + buffer, + self.optimizer.param_groups[param_group_idx]['momentum'], + ns_method=ns_method) + tensor = grad_accum num_elements = tensor.numel() + buffer_idx += num_elements tensor_offset = 0 # we need to offset to get to the right element @@ -1559,31 +2047,12 @@ def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, d def free_grad_in_param_list(self, param_list): for p in param_list: p.grad = None # in step + p.grad_accum = None def reset_cpu_buffers(self): self.norm_for_param_grads = {} self.local_overflow = False - def log_timers(self, timer_names): - if self.timers is None: - return - - self.timers.log(names=list(timer_names)) - - def start_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).start() - - def stop_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).stop() - def set_lr(self, lr): """Set the learning rate.""" for param_group in self.optimizer.param_groups: @@ -1603,18 +2072,17 @@ def scaled_global_norm(self, norm_type=2): assert norm_type == 2, "only L2 norm supported" norm_groups = [] for i, group in enumerate(self.bit16_groups): - partition_id = dist.get_rank(group=self.real_dp_process_group[i]) if self.cpu_offload: - norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])) - single_grad_partition = self.single_partition_of_fp32_groups[i].grad + norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]) + norm_groups.append(norm) else: norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])) if self.has_moe_layers: self._average_expert_grad_norms(norm_groups) - # note that the get_global_norm function only supports l2 norm - return get_global_norm(norm_list=norm_groups) + # calculating L2 norm + return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type) def get_bit16_param_group(self, group_no): bit16_partitions = self.parallel_partitioned_bit16_groups[group_no] @@ -1624,29 +2092,36 @@ def get_bit16_param_group(self, group_no): def _optimizer_step(self, group_no): original_param_groups = self.optimizer.param_groups self.optimizer.param_groups = [original_param_groups[group_no]] - # Disabling this as the C++ side copy & synchornize is not working correctly + # Disabling this as the C++ side copy & synchronize is not working correctly #from deepspeed.ops.adam import DeepSpeedCPUAdam #if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half: # self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)]) #else: # self.optimizer.step() - self.optimizer.step() + if self.torch_autocast_gradscaler: + self.torch_autocast_gradscaler.step(self.optimizer) + self.torch_autocast_gradscaler.update() + # TODO: Remove zenflow-specific call from vanilla ZeroOptimizer + elif self.zenflow: + self.zenflow_cpu_optimizer_step(group_no) + else: + self.optimizer.step() self.optimizer.param_groups = original_param_groups + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + def step(self, closure=None): """ Not supporting closure. """ - self.micro_step_id = -1 + self.micro_step_id = INITIAL_MICRO_STEP_ID - see_memory_usage(f"In step before checking overflow") + see_memory_usage("In step before checking overflow") # First compute norm for all group so we know if there is overflow - self.check_overflow() - OPTIMIZER_ALLGATHER = 'optimizer_allgather' - OPTIMIZER_GRADIENTS = 'optimizer_gradients' - OPTIMIZER_STEP = 'optimizer_step' - timer_names = [OPTIMIZER_ALLGATHER, OPTIMIZER_GRADIENTS, OPTIMIZER_STEP] + if self.check_grad_overflow: + self.check_overflow(partition_gradients=self.partition_gradients) prev_scale = self.loss_scale self._update_scale(self.overflow) @@ -1656,29 +2131,33 @@ def step(self, closure=None): if self.cpu_offload: self.reset_cpu_buffers() else: - self.averaged_gradients = {} + for k in self.averaged_gradients.keys(): + self.averaged_gradients[k] = None + self.all_grad_tensors[k] = None see_memory_usage('After overflow after clearing gradients') - self.start_timers(timer_names) - self.stop_timers(timer_names) + for timer in OPTIMIZER_TIMERS: + self.timers(timer).start() + self.timers(timer).stop() return - # Step 1:- Calculate gradient norm using fp-16 grads + # Step 1:- Calculate gradient norm using bit-16 grads see_memory_usage('Before norm calculation') scaled_global_grad_norm = self.scaled_global_norm() self._global_grad_norm = scaled_global_grad_norm / prev_scale - see_memory_usage('After norm before optimizer') + # Step 2:- run optimizer and upscaling simultaneously for i, group in enumerate(self.bit16_groups): - self.start_timers([OPTIMIZER_GRADIENTS]) + self.timers(OPTIMIZER_GRADIENTS_TIMER).start() partition_id = dist.get_rank(group=self.real_dp_process_group[i]) if self.cpu_offload: single_grad_partition = self.single_partition_of_fp32_groups[i].grad self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) - self.stop_timers([OPTIMIZER_GRADIENTS]) - self.start_timers([OPTIMIZER_STEP]) + + self.timers(OPTIMIZER_GRADIENTS_TIMER).stop() + self.timers(OPTIMIZER_STEP_TIMER).start() self._optimizer_step(i) # Disabled, this is not currently working @@ -1689,9 +2168,11 @@ def step(self, closure=None): # bit16_partitions[partition_id].data.copy_(fp32_partition.data) bit16_partitions = self.parallel_partitioned_bit16_groups[i] fp32_partition = self.single_partition_of_fp32_groups[i] - bit16_partitions[partition_id].data.copy_(fp32_partition.data) + bit16_partition_buffer = self.param_buffer_of_bit16_for_cpu_offload_groups[i] + bit16_partition_buffer.data.copy_(fp32_partition.data) + bit16_partitions[partition_id].data.copy_(bit16_partition_buffer.data, non_blocking=True) - self.stop_timers([OPTIMIZER_STEP]) + self.timers(OPTIMIZER_STEP_TIMER).stop() else: # free gradients for all the parameters that are not updated by this process(ZeRO stage2) self.free_grad_in_param_list(self.params_not_in_partition[i]) @@ -1714,12 +2195,13 @@ def step(self, closure=None): self.free_grad_in_param_list(self.params_in_partition[i]) self.averaged_gradients[i] = None - + self.all_grad_tensors[i] = None self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm) - self.stop_timers([OPTIMIZER_GRADIENTS]) + + self.timers(OPTIMIZER_GRADIENTS_TIMER).stop() # Step 3:- run the optimizer if no offloading - self.start_timers([OPTIMIZER_STEP]) + self.timers(OPTIMIZER_STEP_TIMER).start() self._optimizer_step(i) # Step 4:- get rid of the fp32 gradients. Not needed anymore self.single_partition_of_fp32_groups[i].grad = None @@ -1727,27 +2209,27 @@ def step(self, closure=None): bit16_partitions = self.parallel_partitioned_bit16_groups[i] fp32_partition = self.single_partition_of_fp32_groups[i] bit16_partitions[partition_id].data.copy_(fp32_partition.data) - self.stop_timers([OPTIMIZER_STEP]) + self.timers(OPTIMIZER_STEP_TIMER).stop() see_memory_usage('After optimizer before all-gather') if self.cpu_offload: self.reset_cpu_buffers() - self.start_timers([OPTIMIZER_ALLGATHER]) + self.timers(OPTIMIZER_ALLGATHER_TIMER).start() # Gather the updated weights from everyone. # Then all partitions of the model parameters are updated and ready for next round forward. - all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups, + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, dp_process_group=self.real_dp_process_group, start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size) - - self.stop_timers([OPTIMIZER_ALLGATHER]) + self.timers(OPTIMIZER_ALLGATHER_TIMER).stop() # TODO: we probably don't need this? just to be safe for i in range(len(self.bit16_groups)): self._update_model_bit16_weights(i) - self.log_timers(timer_names) + self.timers.log(OPTIMIZER_TIMERS) see_memory_usage('After zero_optimizer step') return @@ -1758,11 +2240,9 @@ def update_lp_params(self): zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) bit16_partitions[partition_id].data.copy_(fp32_partition.data) - # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) - # if i == 0: - # print_rank_0(f'{fp32_partition[:10]=}', force=True) - all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups, + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, dp_process_group=self.real_dp_process_group, start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size) @@ -1770,12 +2250,11 @@ def update_lp_params(self): def _average_expert_grad_norms(self, norm_groups): for i, norm in enumerate(norm_groups): if self.is_moe_param_group[i]: - scaled_norm = norm * 1.0 / float(dist.get_world_size(group=self.real_dp_process_group[i])) - scaled_norm_tensor = torch.tensor(scaled_norm, - device=get_accelerator().device_name(), - dtype=torch.float) + scaled_norm_tensor = norm * 1.0 / dist.get_world_size(group=self.real_dp_process_group[i]) + if self.device == 'cpu': + scaled_norm_tensor = scaled_norm_tensor.to(get_accelerator().current_device_name()) dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i]) - norm_groups[i] = scaled_norm_tensor.item() + norm_groups[i] = scaled_norm_tensor.to(self.device) def unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group @@ -1783,8 +2262,8 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm): if self.clip_grad > 0.: # norm is in fact norm*scale clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale for grad in grad_groups_flat: if isinstance(grad, list): @@ -1798,36 +2277,27 @@ def _check_overflow(self, partition_gradients=True): self.overflow = self.has_overflow(partition_gradients) # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params, is_grad_list=False): + def has_overflow_serial(self, params): + invalid_grad_count = torch.zeros([1], dtype=torch.float, device=get_accelerator().current_device_name()) for p in params: - if p.grad is not None and self._has_inf_or_nan(p.grad.data): - return True - - return False + if p.grad is not None: + invalid_grad_count += self._has_inf_or_nan(p.grad) + return invalid_grad_count.bool() def has_overflow_partitioned_grads_serial(self): + invalid_grad_count = torch.zeros([1], dtype=torch.float, device=get_accelerator().current_device_name()) for i in range(len(self.bit16_groups)): for j, grad in enumerate(self.averaged_gradients[i]): - if grad is not None and self._has_inf_or_nan(grad.data, j): - return True - return False + if grad is not None: + invalid_grad_count += self._has_inf_or_nan(grad) + return invalid_grad_count.bool() def has_overflow(self, partition_gradients=True): - if partition_gradients: - overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() - overflow_gpu = get_accelerator().ByteTensor([overflow]) - '''This will capture overflow across all data parallel and expert parallel process - Since expert parallel process are a subset of data parallel process''' - dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) - - else: - params = [] - for group in self.bit16_groups: - for param in group: - params.append(param) + overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() + overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to( + get_accelerator().current_device_name()) - overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = get_accelerator().ByteTensor([overflow]) + dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs @@ -1839,55 +2309,41 @@ def has_overflow(self, partition_gradients=True): # `x` is a torch.Tensor @staticmethod def _has_inf_or_nan(x, j=None): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - self.micro_step_id += 1 - - if self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_0) - - # Use double buffers to avoid data access conflict when overlap_comm is enabled. - if self.overlap_comm: - buf_1 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_1) - self.ipg_index = 0 - - if self.custom_loss_scaler: - scaled_loss = self.external_loss_scale * loss - scaled_loss.backward() - else: - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + float_x = x.float() + nan = float_x.isnan() + inf = float_x.isinf() + inf_or_nan = nan.logical_or(inf) + return inf_or_nan.float().max() + + def setup_buckets(self): + if not self.ready_for_gradients: + self.micro_step_id += 1 + + if self.contiguous_gradients: + for _, bucket in self.ipg_buckets.items(): + bucket.buffer.clear() + + # Buffer's dtype is the same as the dtype of optimizer, not dtype for autocast + buf_0 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + bucket.buffer.append(buf_0) + bucket.index = 0 + + # Use double buffers to avoid data access conflict when overlap_comm is enabled. + if self.overlap_comm: + for _, bucket in self.ipg_buckets.items(): + buf_1 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + bucket.buffer.append(buf_1) + + self.ready_for_gradients = True + + def backward_epilogue(self, *args, **kwargs): + # Only for Stage 1, Mode 2 + if self.use_grad_accum_attribute: + self.fill_grad_accum_attribute() def check_overflow(self, partition_gradients=True): self._check_overflow(partition_gradients) @@ -1941,7 +2397,7 @@ def _get_groups_without_padding(self, groups_with_padding): def _get_state_without_padding(self, state_with_padding, padding): lean_state = {} for key, value in state_with_padding.items(): - if torch.is_tensor(value): + if torch.is_tensor(value) and value.dim() > 0: lean_length = value.numel() - padding lean_state[key] = value[:lean_length] else: @@ -1972,13 +2428,19 @@ def state_dict(self): torch.save(checkpoint, "saved.pth") """ state_dict = {} - state_dict['loss_scaler'] = self.loss_scaler + state_dict[LOSS_SCALER] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict[CLIP_GRAD] = self.clip_grad if self.elastic_checkpoint: state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state() + + if "step" in self.optimizer.param_groups[0]: + # Assuming "step" is the only item that changes through training iterations + assert all(group["step"] == self.optimizer.param_groups[0]["step"] + for group in self.optimizer.param_groups), "All param groups must have the same step value" + state_dict[BASE_OPTIMIZER_STATE_STEP] = self.optimizer.param_groups[0]["step"] else: state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() @@ -1994,6 +2456,10 @@ def state_dict(self): state_dict[DS_VERSION] = version state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings + autotp_uc_info = self._get_universal_checkpoint_info() + if autotp_uc_info is not None: + state_dict[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info + return state_dict # Restore base optimizer fp32 weights from elastic checkpoint by: @@ -2032,7 +2498,7 @@ def refresh_fp32_params(self): # Extract optimizer state for current partition from merged states of all partitions def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id): partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) - alignment = dist.get_world_size(group=self.real_dp_process_group[group_id]) + alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[group_id]) if torch.is_tensor(all_partition_states[0]): flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment) dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id) @@ -2041,19 +2507,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group # Assume non-tensor states are not partitioned and equal across ranks, so return first one return all_partition_states[0] - def _restore_base_optimizer_state(self, base_optimizer_group_states): + def _restore_step_from_elastic_checkpoint(self, all_state_dict): + assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0] + assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + for sd in all_state_dict), "State dicts of all partitions must have the same step value" + return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + + def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings): if type(base_optimizer_group_states) == dict: base_optimizer_group_states = base_optimizer_group_states['state'] + + saved_keys = base_optimizer_group_states[0].keys() + for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - dst_tensor = self.optimizer.state[p][key] - src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) - self.optimizer.state[p][key].data.copy_(src_tensor.data) + padding = 0 if group_paddings is None else group_paddings[i] + for key in saved_keys: + saved = base_optimizer_group_states[i][key] + + if torch.is_tensor(saved): + if key in self.optimizer.state[p]: + dst_tensor = self.optimizer.state[p][key] + src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) + self.optimizer.state[p][key].data.copy_(src_tensor.data) + else: + self.optimizer.state[p][key] = _pad_tensor_by_size( + saved, padding, torch.float32, + torch.device('cpu') if self.cpu_offload else self.device) else: self.optimizer.state[p][key] = saved + for param_group in self.optimizer.param_groups: + param_group['step'] = base_optimizer_state_step + def get_ep_ranks(self, rank=0, group_name=None): from deepspeed.utils import groups expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name) @@ -2081,37 +2567,41 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict): partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i) base_optimizer_group_states.append(partition_states) - self._restore_base_optimizer_state(base_optimizer_group_states) + self._restore_base_optimizer_state(base_optimizer_group_states, + self._restore_step_from_elastic_checkpoint(all_state_dict), None) def load_state_dict(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False, - checkpoint_folder=None): + checkpoint_folder=None, + load_serial=None, + param_shapes=None): if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) else: self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights) def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): - self._load_hp_checkpoint_state(checkpoint_folder) + self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder) - @property - def param_groups(self): - """Forward the wrapped optimizer's parameters.""" - return self.optimizer.param_groups + def _load_global_state(self, sd): + self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) + self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) + self.overflow = sd.get('overflow', self.overflow) + self.clip_grad = sd.get(CLIP_GRAD, self.clip_grad) - def _load_hp_checkpoint_state(self, checkpoint_dir): - checkpoint_dir = os.path.join(checkpoint_dir, "zero") - tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - tp_world_size = self.mpu.get_slice_parallel_world_size() + ckpt_version = sd.get(DS_VERSION, False) + assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed" + ckpt_version = pkg_version.parse(ckpt_version) - for i, _ in enumerate(self.optimizer.param_groups): - for lp in self.bit16_groups[i]: - if lp._hp_mapping is not None: - #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") - lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, - tp_world_size) + # zero stage 1 mode + if not self.partition_gradients: + required_version = pkg_version.parse("0.3.17") + error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \ + "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \ + "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json." + assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}" def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False): r"""Loading ZeRO checkpoint @@ -2143,22 +2633,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l # I think it should actually be ok to reload the optimizer before the model. dp_rank = dist.get_rank(group=self.dp_process_group) current_rank_sd = state_dict_list[dp_rank] - self.loss_scaler = current_rank_sd.get('loss_scaler', self.loss_scaler) - self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', self.dynamic_loss_scale) - self.overflow = current_rank_sd.get('overflow', self.overflow) - self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) - - ckpt_version = current_rank_sd.get(DS_VERSION, False) - assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" - ckpt_version = pkg_version.parse(ckpt_version) - - # zero stage 1 mode - if not self.partition_gradients: - required_version = pkg_version.parse("0.3.17") - error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \ - "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \ - "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json." - assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}" + self._load_global_state(current_rank_sd) ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict) @@ -2182,7 +2657,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self._restore_elastic_base_optimizer_state(state_dict_list) else: # loading an elastic checkpoint into rigid exec - self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE]) + self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE], + current_rank_sd[BASE_OPTIMIZER_STATE_STEP], + current_rank_sd[GROUP_PADDINGS]) # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. @@ -2216,6 +2693,203 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l if load_optimizer_states: self._link_all_hp_params() + def _clear_hp_buffer_references(self): + """ + Clear all references that might prevent GPU memory release when offloading HP params. + This includes gradient views, HP mapping fragments, and optimizer state fragments. + """ + # Clear gradient references in offload_gradient_dict + if hasattr(self, 'offload_gradient_dict'): + for param_group_index in self.offload_gradient_dict: + if self.offload_gradient_dict[param_group_index] is not None: + self.offload_gradient_dict[param_group_index].clear() + + # Clear gradient buffers attached to HP params + for i, buf in enumerate(self.single_partition_of_fp32_groups): + if hasattr(buf, 'grad') and buf.grad is not None: + buf.grad = None + + # Clear HP mapping references in model parameters + for i, param_group in enumerate(self.bit16_groups): + for param in param_group: + if hasattr(param, '_hp_mapping') and param._hp_mapping is not None: + # Clear the fragment references that point to GPU buffers + if hasattr(param._hp_mapping, 'hp_fragment'): + param._hp_mapping.hp_fragment = None + if hasattr(param._hp_mapping, 'optim_fragment') and param._hp_mapping.optim_fragment is not None: + param._hp_mapping.optim_fragment.clear() + + # Force garbage collection to release references + gc.collect() + + def _clear_lp_params_references(self): + """ + Clear all references that might prevent GPU memory release when offloading LP params. + This includes HP mapping lp_fragment references and completely nullifying _hp_mapping. + """ + # Completely clear HP mapping to break all references to GPU tensors + for i, param_group in enumerate(self.bit16_groups): + for param in param_group: + if hasattr(param, '_hp_mapping') and param._hp_mapping is not None: + # Completely nullify _hp_mapping to break all references + param._hp_mapping = None + + # Force garbage collection to release references + gc.collect() + + def offload_states(self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False): + """ + Offload optimizer states from GPU to the specified device (typically CPU). + + Args: + include (Container[OffloadStateTypeEnum], optional): + Collection of state types to offload. If None, offloads all supported states. + Defaults to None. + device (OffloadDeviceEnum, optional): + Target device for offloading. Defaults to OffloadDeviceEnum.cpu. + pin_memory (bool, optional): + If True, pins data in memory before moving to CPU. + This can accelerate subsequent CPU-to-GPU transfers. Defaults to True. + non_blocking (bool, optional): + If True, attempts to perform offload operations asynchronously. Defaults to False. + """ + device = device.value + + def needs_offload(target): + return target not in self.offloaded_states and (include is None or target in include) + + # Offload FP32 Master Parameters (HP Params) + if needs_offload(OffloadStateTypeEnum.hp_params): + self._clear_hp_buffer_references() + if pin_memory: + if not hasattr(self, "hp_params_pin_buffers"): + self.hp_params_pin_buffers = [ + torch.empty_like(t, device=device).pin_memory() for t in self.single_partition_of_fp32_groups + ] + for src_tensor, dest_buf in zip(self.single_partition_of_fp32_groups, self.hp_params_pin_buffers): + dest_buf.copy_(src_tensor, non_blocking=non_blocking) + src_tensor.data = dest_buf + else: + for buf in self.single_partition_of_fp32_groups: + buf.data = buf.data.to(device, non_blocking=non_blocking) + + self.offloaded_states.add(OffloadStateTypeEnum.hp_params) + + # Offload FP16/BF16 Model Parameters (LP Params) + if needs_offload(OffloadStateTypeEnum.lp_params): + self._clear_lp_params_references() + for group in self.bit16_groups: + for param in group: + param.data = torch.empty(0, dtype=param.dtype, device=param.device) + for group_partitions in self.parallel_partitioned_bit16_groups: + group_partitions.clear() + + if pin_memory: + if not hasattr(self, "lp_params_pin_buffers"): + self.lp_params_pin_buffers = [ + torch.empty_like(t, device=device).pin_memory() for t in self.bit16_groups_flat + ] + for src_tensor, dest_buf in zip(self.bit16_groups_flat, self.lp_params_pin_buffers): + dest_buf.copy_(src_tensor, non_blocking=non_blocking) + src_tensor.data = dest_buf + else: + for buf in self.bit16_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + for i in range(len(self.bit16_groups)): + self._update_model_bit16_weights(i) + + self.offloaded_states.add(OffloadStateTypeEnum.lp_params) + + # Offload Partitioned Gradients (LP Grads) + if needs_offload(OffloadStateTypeEnum.lp_grads): + for group_idx in self.averaged_gradients: + grad_list = self.averaged_gradients.get(group_idx) + if grad_list is not None: + for grad_tensor in grad_list: + if grad_tensor is not None and grad_tensor.device.type != device: + # Key insight: We only move the underlying data storage (`.data`) to the target device. + # The Python tensor object and the dictionary structure (`self.averaged_gradients`) + # remain intact, preserving the references needed for reloading. + grad_tensor.data = grad_tensor.data.to(device, non_blocking=non_blocking) + + self.offloaded_states.add(OffloadStateTypeEnum.lp_grads) + + # Offload Optimizer States + if needs_offload(OffloadStateTypeEnum.optim_states): + offload_optimizer_states(self.optimizer, device, pin_memory=pin_memory, non_blocking=non_blocking) + self.offloaded_states.add(OffloadStateTypeEnum.optim_states) + + if not non_blocking: + if get_accelerator().is_available(): + get_accelerator().synchronize() + + gc.collect() + if get_accelerator().is_available(): + get_accelerator().empty_cache() + + def reload_states(self, non_blocking: bool = False): + """ + Reload previously offloaded optimizer states from CPU back to GPU. + + Args: + non_blocking (bool, optional): + If True, attempts to perform reload operations asynchronously. Defaults to False. + """ + device = get_accelerator().current_device_name() + + # Reload FP32 Master Parameters (HP Params) + if OffloadStateTypeEnum.hp_params in self.offloaded_states: + for buf in self.single_partition_of_fp32_groups: + buf.data = buf.data.to(device, non_blocking=non_blocking) + if hasattr(self, "hp_params_pin_buffers"): + del self.hp_params_pin_buffers + self._link_all_hp_params() + self.offloaded_states.remove(OffloadStateTypeEnum.hp_params) + + # Reload FP16/BF16 Model Parameters (LP Params) + if OffloadStateTypeEnum.lp_params in self.offloaded_states: + for buf in self.bit16_groups_flat: + buf.data = buf.data.to(device, non_blocking=non_blocking) + + # Reconstruct the parallel partitions now that the flat buffer is back on GPU. + self.parallel_partitioned_bit16_groups.clear() + for i, flat_group in enumerate(self.bit16_groups_flat): + data_parallel_partitions = self.get_data_parallel_partitions(flat_group, i) + self.parallel_partitioned_bit16_groups.append(data_parallel_partitions) + + for i in range(len(self.bit16_groups)): + self._update_model_bit16_weights(i) + + if hasattr(self, "lp_params_pin_buffers"): + del self.lp_params_pin_buffers + self._link_all_hp_params() + self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) + + # Reload Partitioned Gradients (LP Grads) + if OffloadStateTypeEnum.lp_grads in self.offloaded_states: + # Since we preserved the `self.averaged_gradients` structure during offload, + # we can now iterate through it again. The tensors within currently point to CPU data. + for group_idx in self.averaged_gradients: + grad_list = self.averaged_gradients.get(group_idx) + if grad_list is not None: + for grad_tensor in grad_list: + if grad_tensor is not None and grad_tensor.device.type != device: + grad_tensor.data = grad_tensor.data.to(device, non_blocking=non_blocking) + + self.offloaded_states.remove(OffloadStateTypeEnum.lp_grads) + + # Reload Optimizer States + if OffloadStateTypeEnum.optim_states in self.offloaded_states: + reload_optimizer_states(self.optimizer, device, non_blocking=non_blocking) + self.offloaded_states.remove(OffloadStateTypeEnum.optim_states) + + if non_blocking: + get_accelerator().synchronize() + def _handle_overflow(cpu_sum, x, i): import math @@ -2241,7 +2915,9 @@ def estimate_zero2_model_states_mem_needs(total_params, gpu_mem = 2 * total_params cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor else: - gpu_mem = 4 * total_params + int(16 * total_params / total_gpus) + # GPU's total_params multipliers: 2 = params_16bit, + # 18 = 2_grads_16bit + 4_grads_32bit + 4_params_32bit + 8_optimizer_states_32bit(momentum and variance) + gpu_mem = 2 * total_params + int(18 * total_params / total_gpus) cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor return int(cpu_mem), int(gpu_mem) diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 81a301c8d782..139419563352 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -4,15 +4,22 @@ # DeepSpeed Team import os -from typing import List +import gc +from typing import List, Tuple import torch from deepspeed import comm as dist from deepspeed.utils import logger from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad from deepspeed.ops.adam import FusedAdam +from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import get_only_unique_item + +# ensure we only warn once, otherwise every iteration will trigger a warning +warned = False def _initialize_parameter_parallel_groups(parameter_parallel_size=None): @@ -35,7 +42,17 @@ class ZeRORuntimeException(Exception): pass -ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam] +ZERO_SUPPORTED_OPTIMIZERS = [ + torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad, + DeepSpeedCPULion, FusedLion +] + +# Add MuonWithAuxAdam to supported list if muon is installed +try: + from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam + ZERO_SUPPORTED_OPTIMIZERS.append(MuonWithAuxAdam) +except ImportError: + pass # Add apex FusedAdam to supported list if apex is installed try: @@ -52,6 +69,23 @@ def is_zero_supported_optimizer(optimizer): return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS +@instrument_w_nvtx +def assert_lst_len_same_as_other_ranks(lst: List[int]) -> None: + rank0_len_tensor = torch.tensor( + len(lst) if dist.get_rank() == 0 else -1, + dtype=int, + device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])), + requires_grad=False, + ) + local_list_length = len(lst) + dist.broadcast(rank0_len_tensor, src=0, async_op=False) + rank0_list_length = rank0_len_tensor.cpu().item() + if rank0_list_length != local_list_length: + raise RuntimeError(f"Detected a disagreement on list length between rank0 and rank{dist.get_rank()}: " + f"\n rank0: {rank0_list_length} " + f"\n rank{dist.get_rank()}: {local_list_length}") + + def get_lst_from_rank0(lst: List[int]) -> None: """ NOTE: creates both communication and synchronization overhead so should be used @@ -60,13 +94,12 @@ def get_lst_from_rank0(lst: List[int]) -> None: lst_tensor = torch.tensor( lst if dist.get_rank() == 0 else [-1] * len(lst), dtype=int, - # device=get_accelerator().current_device_name(), device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])), requires_grad=False, ) dist.broadcast(lst_tensor, src=0, async_op=False) - return list(lst_tensor.cpu().numpy()) + return [t.item() for t in lst_tensor.cpu()] @instrument_w_nvtx @@ -78,7 +111,125 @@ def assert_ints_same_as_other_ranks(ints: List[int]) -> None: takes a list of ints from each rank and ensures that they are the same across ranks, throwing an exception if they are not. """ + assert_lst_len_same_as_other_ranks(ints) rank0_ints = get_lst_from_rank0(ints) if ints != rank0_ints: - raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: " - f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}") + raise RuntimeError(f"Detected a disagreement on list contents between rank0 and rank{dist.get_rank()}: " + f"\n list length: {len(ints)}" + f"\n rank0: {rank0_ints} " + f"\n rank{dist.get_rank()}: {ints}") + + +def is_builtin_type(obj): + # https://stackoverflow.com/a/17795199 + return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins" + + +def isinstance_namedtuple(obj: object) -> bool: + """ + Is this an instance of namedtuple/NamedTuple? + From: https://stackoverflow.com/a/62692640 + + Args: + obj (object): An object. + + Returns: + bool: True if namedtuple/NamedTuple else False. + """ + return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields') + + +def is_zero_param(parameter): + if not torch.is_tensor(parameter): + return False + return hasattr(parameter, 'ds_id') + + +def apply_to_tensors_only(function, value, warning_msg_fn=None): + """ + Apply `function` to every Tensor in `value`. + + Args: + functional: The function class to apply. + value (Any): Target object to apply `function` to. + + Returns: + Any: Output of `function`. + """ + if isinstance(value, (tuple, list)): + touched_outputs = [] + for elem in value: + touched_output = apply_to_tensors_only(function, elem) + touched_outputs.append(touched_output) + + if isinstance_namedtuple(value): + # namedtuples require a slightly different syntax. + return value.__class__(*touched_outputs) + + return value.__class__(touched_outputs) + elif isinstance(value, dict): + # apply inplace to avoid recreating dict inherited objects + for key in value.keys(): + value[key] = apply_to_tensors_only(function, value[key]) + return value + + elif isinstance(value, torch.Tensor): + # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter + touched_output = function(value) + + # restore zero param attributes if those get stripped by `backward_function` + if not is_zero_param(touched_output) and is_zero_param(value): + touched_output.ds_param_alias = value + + return touched_output + else: + if not is_builtin_type(value): + global warned + if warning_msg_fn and not warned and dist.get_rank() == 0: + logger.warning(warning_msg_fn(value)) + warned = True + return value + + +def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]: + tensor_infos: List[Tuple[torch.Tensor, int, int]] = [] + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + offset += tensor_numel + + return tensor_infos + + +def defragment(tensors: List[torch.Tensor]) -> torch.Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[torch.Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor, offset, tensor_numel in tensor_infos: + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + gc.collect() + get_accelerator().empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer diff --git a/deepspeed/sequence/__init__.py b/deepspeed/sequence/__init__.py new file mode 100644 index 000000000000..b76f944eff79 --- /dev/null +++ b/deepspeed/sequence/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.sequence.autosp_detector import detect_model_sp_info, SPModelInfo +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import (ModalityFusionSPAdapter, LlavaFusionAdapter, InternVLFusionAdapter, + Qwen2VLFusionAdapter) +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp diff --git a/deepspeed/sequence/auto_sp.py b/deepspeed/sequence/auto_sp.py new file mode 100644 index 000000000000..2bb17413512b --- /dev/null +++ b/deepspeed/sequence/auto_sp.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +AutoSP: one-call sequence parallelism for multimodal models. + +Usage:: + + from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp + from deepspeed.utils import groups + + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, ...) + sp_group = groups._get_sequence_parallel_group() + model = auto_wrap_model_for_sp(model, process_group=sp_group) + +``auto_wrap_model_for_sp`` scans the model and injects: + +* :class:`~deepspeed.sequence.autosp_vit.UlyssesSPViTAttention` + for ViT encoder attention layers. +* a warning for LLM decoder attention layers: HuggingFace-style + ``hidden_states`` attention is incompatible with + :class:`~deepspeed.sequence.layer.DistributedAttention`'s Q/K/V interface; + configure LLM sequence parallelism manually. + +The vision-language projection layer (Phase 2) is detected and a warning is +emitted; wrap it manually with +:class:`~deepspeed.sequence.autosp_fusion.ModalityFusionSPAdapter` until +Phase 2 automation is implemented. +""" + +import logging + +import torch.nn as nn + +from deepspeed.sequence.autosp_detector import detect_model_sp_info, _VIT_HAS_CLS_TOKEN +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention + +logger = logging.getLogger(__name__) + + +def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: + """Inject sequence-parallel wrappers into *model* in-place. + + Scans the model's named modules and replaces recognised attention layers + with their SP-aware equivalents: + + * ViT attention → :class:`UlyssesSPViTAttention` + * LLM attention → warning only (HuggingFace ``hidden_states`` interface + is incompatible with :class:`DistributedAttention`'s Q/K/V interface) + + The function modifies *model* in-place **and** returns it for convenience. + + Parameters + ---------- + model: + The multimodal model to wrap. Must be on the correct device before + calling this function. + process_group: + The sequence-parallel process group (from + ``groups._get_sequence_parallel_group()``). + + Returns + ------- + The same *model* object with attention layers replaced. + + Raises + ------ + ValueError + If no recognisable attention modules are found. Register the model's + attention class names in ``autosp_detector._VIT_ATTN_CLASSNAMES`` or + ``_LLM_ATTN_CLASSNAMES`` to fix this. + """ + info = detect_model_sp_info(model) + + if not info.vit_attn_modules and not info.llm_attn_modules: + raise ValueError("auto_wrap_model_for_sp: no recognisable attention modules found. " + "Add the model's attention class name(s) to " + "_VIT_ATTN_CLASSNAMES or _LLM_ATTN_CLASSNAMES in " + "deepspeed/sequence/autosp_detector.py and retry.") + + # ------------------------------------------------------------------ + # Wrap ViT encoder attention layers + # ------------------------------------------------------------------ + for name, module in info.vit_attn_modules: + cls_name = type(module).__name__ + # Look up whether this ViT architecture uses a CLS token; default True + # (safe fallback) for unknown classes not yet in the registry. + has_cls = _VIT_HAS_CLS_TOKEN.get(cls_name, True) + wrapped = UlyssesSPViTAttention(module, process_group, has_cls_token=has_cls) + _set_module_by_name(model, name, wrapped) + logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention (has_cls_token=%s)", name, has_cls) + + logger.info("AutoSP: wrapped %d ViT attention layer(s).", len(info.vit_attn_modules)) + + # ------------------------------------------------------------------ + # LLM decoder attention layers — warn, do not auto-wrap + # ------------------------------------------------------------------ + # DistributedAttention expects a Megatron-style (query, key, value) + # interface, but every class in _LLM_ATTN_CLASSNAMES uses the + # HuggingFace hidden_states interface. Wrapping them silently would + # produce incorrect behaviour at the first forward pass. Emit a + # per-layer warning so the user can configure SP manually. + for name, module in info.llm_attn_modules: + logger.warning( + "AutoSP: LLM attention '%s' (class %s) uses a HuggingFace hidden_states " + "interface that is incompatible with DistributedAttention's Q/K/V interface. " + "Skipping auto-wrap. Configure sequence parallelism for this layer manually.", name, + type(module).__name__) + + if info.llm_attn_modules: + logger.info("AutoSP: found %d LLM attention layer(s); skipped wrapping (see warnings above).", + len(info.llm_attn_modules)) + + # ------------------------------------------------------------------ + # Warn about the vision projection layer (Phase 2) + # ------------------------------------------------------------------ + if info.vision_projection_module is not None: + proj_name, _ = info.vision_projection_module + logger.warning( + "AutoSP detected vision projection layer '%s'. " + "ModalityFusionSPAdapter (Phase 2) is not yet automated. " + "Wrap this layer manually with ModalityFusionSPAdapter if you " + "need correct cross-modal sequence gather/scatter.", proj_name) + + return model + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _set_module_by_name(model: nn.Module, dotted_name: str, new_module: nn.Module) -> None: + """Replace the submodule at *dotted_name* with *new_module* in-place.""" + parts = dotted_name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], new_module) diff --git a/deepspeed/sequence/autosp_detector.py b/deepspeed/sequence/autosp_detector.py new file mode 100644 index 000000000000..be9aab5b320d --- /dev/null +++ b/deepspeed/sequence/autosp_detector.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Automatically detect ViT encoder and LLM decoder attention modules in +multimodal models to guide AutoSP injection. + +Extend _VIT_ATTN_CLASSNAMES / _LLM_ATTN_CLASSNAMES when adding support for +new model architectures. +""" + +import torch.nn as nn +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +# --------------------------------------------------------------------------- +# Architecture registry +# --------------------------------------------------------------------------- + +# Known ViT attention class names (HuggingFace transformers naming) +_VIT_ATTN_CLASSNAMES = { + "ViTAttention", + "CLIPAttention", + "SiglipAttention", + "InternVisionAttention", + "Qwen2VLVisionAttention", + "Idefics2VisionAttention", + "PaliGemmaVisionAttention", +} + +# Whether each known ViT class uses a prepended CLS token. +# CLS is replicated on every rank and is NOT sharded across the sequence. +# Defaults to True for unknown classes (safe fallback). +_VIT_HAS_CLS_TOKEN = { + "ViTAttention": True, + "CLIPAttention": True, + "SiglipAttention": False, + "InternVisionAttention": False, + "Qwen2VLVisionAttention": False, + "Idefics2VisionAttention": False, + "PaliGemmaVisionAttention": False, +} + +# Known LLM decoder attention class names +_LLM_ATTN_CLASSNAMES = { + "LlamaAttention", + "MistralAttention", + "Qwen2Attention", + "InternLM2Attention", + "GemmaAttention", + "Phi3Attention", + "GPTNeoXAttention", + "FalconAttention", + "MptAttention", +} + +# Common attribute names that hold the vision-language projection layer +_VISION_PROJ_KEYWORDS = ( + "visual_projection", + "mm_projector", + "vision_proj", + "multi_modal_projector", + "img_projection", +) + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class SPModelInfo: + """Holds the detection results for a multimodal model.""" + + # (dotted_name, module) pairs for ViT attention layers + vit_attn_modules: List[Tuple[str, nn.Module]] = field(default_factory=list) + # (dotted_name, module) pairs for LLM decoder attention layers + llm_attn_modules: List[Tuple[str, nn.Module]] = field(default_factory=list) + # (dotted_name, module) for the outermost vision-language projection layer + vision_projection_module: Optional[Tuple[str, nn.Module]] = None + + +# --------------------------------------------------------------------------- +# Detection logic +# --------------------------------------------------------------------------- + + +def detect_model_sp_info(model: nn.Module) -> SPModelInfo: + """Recursively scan *model* and return an :class:`SPModelInfo`. + + The function identifies: + * ViT encoder attention layers → wrapped with :class:`UlyssesSPViTAttention` + * LLM decoder attention layers → wrapped with :class:`DistributedAttention` + * The vision-language projection layer → wrapped with + :class:`ModalityFusionSPAdapter` (Phase 2) + + To add support for a new architecture, simply register its attention class + names in ``_VIT_ATTN_CLASSNAMES`` or ``_LLM_ATTN_CLASSNAMES``. + """ + info = SPModelInfo() + for name, module in model.named_modules(): + cls_name = type(module).__name__ + if cls_name in _VIT_ATTN_CLASSNAMES: + info.vit_attn_modules.append((name, module)) + elif cls_name in _LLM_ATTN_CLASSNAMES: + info.llm_attn_modules.append((name, module)) + + # Record only the first (outermost) match to avoid double-wrapping + # nested projection modules. + if info.vision_projection_module is None: + if any(kw in name for kw in _VISION_PROJ_KEYWORDS): + info.vision_projection_module = (name, module) + + return info diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py new file mode 100644 index 000000000000..3608cfe9a64f --- /dev/null +++ b/deepspeed/sequence/autosp_fusion.py @@ -0,0 +1,366 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +ModalityFusionSPAdapter — Phase 2 + +Handles the sequence scatter/gather at the vision-language boundary so that +the LLM decoder's :class:`~deepspeed.sequence.layer.DistributedAttention` +receives a uniformly sharded fused (visual + text) sequence. + +Workflow +-------- +:: + + [visual tokens, sharded] ──all-gather──► [visual tokens, full] + │ + splice into text + │ + [fused embeds, full] ──scatter──► [fused embeds, sharded per rank] + │ + LLM decoder (SP-aware) + +Usage +----- +After calling :func:`~deepspeed.sequence.auto_sp.auto_wrap_model_for_sp` to +wrap the ViT attention layers, attach the appropriate fusion adapter to the +vision-language projection layer **before** the first forward pass. Choose +the adapter that matches your model architecture:: + + from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp + from deepspeed.sequence.autosp_fusion import ( + LlavaFusionAdapter, + InternVLFusionAdapter, + Qwen2VLFusionAdapter, + ) + from deepspeed.utils import groups + + # 1. Wrap ViT and LLM attention layers automatically. + sp_group = groups._get_sequence_parallel_group() + auto_wrap_model_for_sp(model, process_group=sp_group) + + # 2. Attach the fusion adapter for the vision-language projection layer. + # LLaVA — replaces image-placeholder tokens with visual tokens: + model.mm_projector = LlavaFusionAdapter( + model.mm_projector, sp_group, image_token_id=IMAGE_TOKEN_ID + ) + + # InternVL — replaces IMG_CONTEXT tokens 1-to-1 with visual tokens: + model.mm_projector = InternVLFusionAdapter( + model.mm_projector, sp_group, image_token_id=IMG_CONTEXT_TOKEN_ID + ) + + # Qwen2-VL — replaces tokens between vision_start/end pairs 1-to-1: + model.visual.merger = Qwen2VLFusionAdapter( + model.visual.merger, sp_group, + vision_start_token_id=VISION_START_ID, + vision_end_token_id=VISION_END_ID, + ) + + # 3. Use the model as normal; the adapter handles all SP gather/scatter. + outputs = model(input_ids=input_ids, pixel_values=pixel_values, ...) + +Status: Phase 2. ``_splice_visual_into_text`` is intentionally left as a +``NotImplementedError``; override it per model architecture (see docstring). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist + +# Default image placeholder token ID used by LLaVA-style models. +_DEFAULT_IMAGE_TOKEN_ID = -200 + + +class ModalityFusionSPAdapter(nn.Module): + """Wraps the vision projection layer and handles cross-modal sequence fusion. + + After projecting visual features, this adapter: + + 1. Gathers the sharded visual token slices from all SP ranks into a single + full visual token tensor. + 2. Splices the visual tokens into the text embedding sequence at the + positions marked by ``image_token_id`` placeholders. + 3. Pads and re-shards the fused sequence so that the subsequent LLM + decoder layers receive uniformly distributed sequence slices. + + Parameters + ---------- + projection: + The vision projection module (e.g. ``mm_projector``). + process_group: + The sequence-parallel process group. + image_token_id: + The token ID used as an image placeholder in the input IDs tensor. + Defaults to ``-200`` (LLaVA convention). + + Notes + ----- + Subclass this and override :meth:`_splice_visual_into_text` to adapt to a + specific multimodal architecture (LLaVA, InternVL, Qwen-VL, …). + """ + + def __init__(self, projection: nn.Module, process_group, image_token_id: int = _DEFAULT_IMAGE_TOKEN_ID) -> None: + super().__init__() + self.projection = projection + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.image_token_id = image_token_id + + def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Project visual features and return a sharded fused embedding. + + Parameters + ---------- + visual_features: + Raw visual features from the ViT encoder. + Shape: ``[bs, local_visual_tokens, vit_hidden]``. + text_embeds: + Full text token embeddings (not sharded yet). + Shape: ``[bs, text_seq_len, lm_hidden]``. + input_ids: + Token IDs used to locate image placeholder positions. + Shape: ``[bs, text_seq_len]``. + + Returns + ------- + Sharded fused embedding for this rank. + Shape: ``[bs, local_fused_len, lm_hidden]``. + """ + # 1. Project visual features to the LLM hidden dimension + visual_embeds = self.projection(visual_features) # [bs, local_v, lm_hidden] + + # 2. All-gather visual slices from all SP ranks + parts = [torch.zeros_like(visual_embeds) for _ in range(self.world_size)] + dist.all_gather(parts, visual_embeds.contiguous(), group=self.process_group) + full_visual = torch.cat(parts, dim=1) # [bs, total_visual_tokens, lm_hidden] + + # 3. Splice visual tokens into text embedding sequence + fused = self._splice_visual_into_text(text_embeds, full_visual, input_ids) # [bs, fused_len, lm_hidden] + + # 4. Pad fused length to be divisible by world_size, then scatter + total_len = fused.shape[1] + pad = (self.world_size - total_len % self.world_size) % self.world_size + if pad > 0: + fused = F.pad(fused, (0, 0, 0, pad)) + + rank = dist.get_rank(self.process_group) + local_len = fused.shape[1] // self.world_size + return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Replace image placeholder positions in *text_embeds* with *visual_embeds*. + + This is intentionally architecture-specific. The default raises + ``NotImplementedError``; override this method for each supported model. + + Reference implementations: + * LLaVA: ``LlavaMetaForCausalLM.prepare_inputs_embeds`` + * InternVL: ``InternVLChatModel.extract_feature`` + * Qwen-VL: ``Qwen2VLForConditionalGeneration.get_rope_index`` + """ + raise NotImplementedError(f"{type(self).__name__}._splice_visual_into_text is not implemented. " + "Subclass ModalityFusionSPAdapter and override this method to match " + "your model's prepare_inputs_embeds logic.") + + +class LlavaFusionAdapter(ModalityFusionSPAdapter): + """LLaVA-style splice: replace each image placeholder token with visual tokens. + + Follows the logic of ``LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal``: + for each sample, locate ``image_token_id`` placeholders in ``input_ids``, + remove them, and insert the corresponding visual token chunk in their place. + + Visual tokens for a sample are split evenly across the number of image + placeholders found. This matches the common single-image case (one + placeholder per sample) and simple multi-image cases where every image + contributes the same number of tokens. + + Parameters are inherited from :class:`ModalityFusionSPAdapter`. + """ + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + bs, text_len, hidden = text_embeds.shape + device = text_embeds.device + + fused_samples = [] + for i in range(bs): + img_pos = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0] + num_images = img_pos.numel() + + if num_images == 0: + # No image in this sample — keep text embeddings unchanged. + fused_samples.append(text_embeds[i]) + continue + + # Split all visual tokens evenly across the image placeholders. + visual_chunks = torch.chunk(visual_embeds[i], num_images, dim=0) + + segments = [] + prev = 0 + for j, pos in enumerate(img_pos.tolist()): + # Text segment before this placeholder. + if pos > prev: + segments.append(text_embeds[i, prev:pos]) + # Visual tokens replacing this placeholder. + segments.append(visual_chunks[j]) + # Skip the placeholder token itself. + prev = pos + 1 + + # Remaining text after the last placeholder. + if prev < text_len: + segments.append(text_embeds[i, prev:]) + + fused_samples.append(torch.cat(segments, dim=0)) + + # Pad all samples to the same length so they stack into a tensor. + max_len = max(s.shape[0] for s in fused_samples) + out = torch.zeros(bs, max_len, hidden, dtype=text_embeds.dtype, device=device) + for i, s in enumerate(fused_samples): + out[i, :s.shape[0]] = s + return out + + +class InternVLFusionAdapter(ModalityFusionSPAdapter): + """InternVL-style splice: replace IMG_CONTEXT token runs with visual tokens. + + InternVL encodes each image as `` ×N `` + inside the token sequence. Each ``IMG_CONTEXT`` token (``image_token_id``) + is a 1-to-1 placeholder for one ViT visual token. This adapter locates + every contiguous run of ``image_token_id`` tokens and replaces them with + the corresponding slice of *visual_embeds*, while preserving the + ``IMG_START`` / ``IMG_END`` boundary embeddings unchanged. + + Because the replacement is 1-to-1, the output sequence length equals the + input sequence length (no length change). + + Parameters are inherited from :class:`ModalityFusionSPAdapter`. + Set ``image_token_id`` to the ``IMG_CONTEXT`` token id used by the model + (e.g. the id of ````). + """ + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + # Start from a clone of text embeddings; we only overwrite context positions. + out = text_embeds.clone() + bs = text_embeds.shape[0] + + for i in range(bs): + ctx_pos = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0] + if ctx_pos.numel() == 0: + continue + # ctx_pos lists every IMG_CONTEXT index in order. visual_embeds[i] + # has exactly ctx_pos.numel() tokens (one per context position). + out[i, ctx_pos] = visual_embeds[i, :ctx_pos.numel()] + + return out + + +class Qwen2VLFusionAdapter(nn.Module): + """Qwen2-VL-style splice: visual tokens enclosed by vision_start/end tokens. + + Qwen2-VL wraps each image's visual tokens with a pair of special boundary + tokens in ``input_ids``: ``vision_start_token_id`` and + ``vision_end_token_id``. The placeholder tokens between each + (start, end) pair are replaced 1-to-1 by the projected visual token + embeddings. The boundary token embeddings are kept unchanged. + + Because the replacement is 1-to-1, the output sequence length equals the + input sequence length. + + Parameters + ---------- + projection: + The vision projection module (e.g. ``visual.merger``). + process_group: + The sequence-parallel process group. + vision_start_token_id: + Token id of ``<|vision_start|>``. + vision_end_token_id: + Token id of ``<|vision_end|>``. + """ + + def __init__(self, projection: nn.Module, process_group, vision_start_token_id: int, + vision_end_token_id: int) -> None: + super().__init__() + self.projection = projection + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Project visual features and return a sharded fused embedding. + + Parameters + ---------- + visual_features: + Raw visual features from the ViT encoder. + Shape: ``[bs, local_visual_tokens, vit_hidden]``. + text_embeds: + Full text token embeddings (not sharded yet). + Shape: ``[bs, text_seq_len, lm_hidden]``. + input_ids: + Token IDs used to locate vision_start/end boundaries. + Shape: ``[bs, text_seq_len]``. + + Returns + ------- + Sharded fused embedding for this rank. + Shape: ``[bs, local_fused_len, lm_hidden]``. + """ + # 1. Project visual features to the LLM hidden dimension. + visual_embeds = self.projection(visual_features) # [bs, local_v, lm_hidden] + + # 2. All-gather visual slices from all SP ranks. + parts = [torch.zeros_like(visual_embeds) for _ in range(self.world_size)] + dist.all_gather(parts, visual_embeds.contiguous(), group=self.process_group) + full_visual = torch.cat(parts, dim=1) # [bs, total_visual_tokens, lm_hidden] + + # 3. Replace placeholder positions in text with visual tokens (length-preserving). + fused = self._splice_visual_into_text(text_embeds, full_visual, input_ids) + + # 4. Pad fused length to be divisible by world_size, then scatter. + total_len = fused.shape[1] + pad = (self.world_size - total_len % self.world_size) % self.world_size + if pad > 0: + fused = F.pad(fused, (0, 0, 0, pad)) + + rank = dist.get_rank(self.process_group) + local_len = fused.shape[1] // self.world_size + return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Replace inner placeholder tokens between vision_start/end pairs with visual embeddings.""" + out = text_embeds.clone() + bs = text_embeds.shape[0] + + for i in range(bs): + start_pos = (input_ids[i] == self.vision_start_token_id).nonzero(as_tuple=True)[0] + end_pos = (input_ids[i] == self.vision_end_token_id).nonzero(as_tuple=True)[0] + + if start_pos.numel() == 0: + continue + + # Accumulate inner placeholder positions across all start/end pairs. + # Inner positions are (start+1) .. (end-1) inclusive, i.e. excluding + # the boundary tokens themselves. + inner_positions = [] + for s, e in zip(start_pos.tolist(), end_pos.tolist()): + inner_positions.extend(range(s + 1, e)) + + if not inner_positions: + continue + + inner_pos = torch.tensor(inner_positions, dtype=torch.long, device=text_embeds.device) + out[i, inner_pos] = visual_embeds[i, :len(inner_positions)] + + return out diff --git a/deepspeed/sequence/autosp_vit.py b/deepspeed/sequence/autosp_vit.py new file mode 100644 index 000000000000..09bebbbac8f3 --- /dev/null +++ b/deepspeed/sequence/autosp_vit.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Ulysses-style sequence-parallel wrapper for ViT encoder attention layers. + +Design notes +------------ +ViT self-attention is non-causal: every patch token attends to every other +patch token. This means a straightforward per-rank local attention (as used +for causal LLMs) would be *incorrect* — each rank must have access to the +full key/value context. + +We therefore use a **gather-compute-scatter** pattern: + +1. Input arrives already sharded along the sequence dimension (each rank owns + ``local_patches = num_patches // world_size`` consecutive patches). +2. Before attention we **all-gather** patch tokens so that every rank runs the + full ViT attention over the complete patch sequence. This keeps the + computation equivalent to single-device execution. +3. The output is **scattered** back so that each rank returns only its local + slice, matching the sharded input contract expected by downstream layers. + +Memory benefit: activations *outside* the attention block (e.g. feed-forward +layers, layer norms) are stored only locally, reducing per-rank memory +proportional to ``world_size``. + +The ``cls`` token (if present) is replicated on every rank and is not split +across the sequence dimension. Each rank appends its local patches to the +same ``cls`` token before calling the wrapped attention. + +Padding: when ``num_patches % world_size != 0``, shorter shards are +zero-padded to a uniform size for ``all_gather``. The padding is stripped +*before* the attention call by trimming each rank's contribution to its true +length, so the wrapped attention always sees exactly ``num_patches`` real +tokens — identical to single-device execution and free of softmax pollution +from dummy tokens. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist + + +class UlyssesSPViTAttention(nn.Module): + """Sequence-parallel wrapper for an opaque ViT attention module. + + Parameters + ---------- + attn: + The original ViT attention layer (any ``nn.Module`` that maps + ``hidden_states`` → ``hidden_states`` or a tuple whose first element + is the attention output tensor). + process_group: + The sequence-parallel process group. + has_cls_token: + Set to ``True`` (default) when the first token in the sequence is a + ``[CLS]`` token that should be replicated on every rank rather than + sharded. + """ + + def __init__(self, attn: nn.Module, process_group, has_cls_token: bool = True) -> None: + super().__init__() + self.attn = attn + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.has_cls_token = has_cls_token + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + + def forward(self, hidden_states: torch.Tensor, **kwargs): + """ + Parameters + ---------- + hidden_states: + Shape ``[bs, local_seq_len, hidden_dim]`` where + ``local_seq_len = (1 + local_patches)`` if ``has_cls_token`` else + ``local_patches``. Each rank holds a contiguous slice of patches. + **kwargs: + Passed through to the wrapped attention (e.g. ``attention_mask``, + ``head_mask``, ``output_attentions``). + + Returns + ------- + Same shape as input (or a tuple whose first element matches the input + shape, preserving whatever the wrapped module returns). + """ + bs, local_seq_len, hidden_dim = hidden_states.shape + + if self.has_cls_token: + # CLS token is replicated on every rank — not part of the sharded seq + cls_token = hidden_states[:, :1, :] + local_patches = hidden_states[:, 1:, :] + else: + local_patches = hidden_states + + local_patch_len = local_patches.shape[1] + + # ------------------------------------------------------------------- + # 1. All-gather patches from all ranks to reconstruct the full sequence + # ------------------------------------------------------------------- + # When num_patches % world_size != 0, ranks hold different shard sizes. + # We all-gather every rank's local_patch_len so we can: + # (a) zero-pad shorter slices to uniform size for all_gather, and + # (b) strip the padding per rank *before* calling attention, so that + # the wrapped module never sees dummy tokens (which would corrupt + # the softmax normalisation). + len_bufs = [torch.zeros(1, dtype=torch.long, device=local_patches.device) for _ in range(self.world_size)] + dist.all_gather(len_bufs, + torch.tensor([local_patch_len], dtype=torch.long, device=local_patches.device), + group=self.process_group) + all_lens = [int(t.item()) for t in len_bufs] + max_local_len = max(all_lens) + + pad_len = max_local_len - local_patch_len + if pad_len > 0: + # Append zero rows so this rank's buffer matches the largest shard. + local_patches_padded = F.pad(local_patches, (0, 0, 0, pad_len)) + else: + local_patches_padded = local_patches + + gathered = [ + torch.zeros(bs, max_local_len, hidden_dim, dtype=local_patches.dtype, device=local_patches.device) + for _ in range(self.world_size) + ] + dist.all_gather(gathered, local_patches_padded.contiguous(), group=self.process_group) + + # Strip per-rank padding before concatenation so attention only sees + # the true num_patches tokens, identical to single-device execution. + real_parts = [gathered[r][:, :all_lens[r], :] for r in range(self.world_size)] + full_patches = torch.cat(real_parts, dim=1) # [bs, total_real_patches, hidden_dim] + + # ------------------------------------------------------------------- + # 2. Build the full input (prepend CLS if needed) and call attention + # ------------------------------------------------------------------- + if self.has_cls_token: + full_input = torch.cat([cls_token, full_patches], dim=1) + else: + full_input = full_patches + + attn_out = self.attn(full_input, **kwargs) + + # Unwrap tuple: some ViT implementations return (attn_output, attn_weights) + if isinstance(attn_out, (tuple, list)): + full_out, *extra = attn_out + else: + full_out = attn_out + extra = [] + + # ------------------------------------------------------------------- + # 3. Scatter output: each rank keeps its local slice of the real patches. + # Because padding was stripped before attention, scatter offsets are + # the cumulative sums of all_lens, not rank * max_local_len. + # ------------------------------------------------------------------- + if self.has_cls_token: + cls_out = full_out[:, :1, :] + patch_out = full_out[:, 1:, :] + else: + patch_out = full_out + + rank = dist.get_rank(self.process_group) + start = sum(all_lens[:rank]) + local_out = patch_out[:, start:start + local_patch_len, :].contiguous() + + if self.has_cls_token: + local_out = torch.cat([cls_out, local_out], dim=1) + + if extra: + return (local_out, *extra) + return local_out diff --git a/deepspeed/sequence/cross_entropy.py b/deepspeed/sequence/cross_entropy.py new file mode 100644 index 000000000000..baa7bc1ea7a8 --- /dev/null +++ b/deepspeed/sequence/cross_entropy.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +import deepspeed.comm as dist + + +class _VocabSequenceParallelCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, vocab_seq_parallel_logits, target, sp_group): + # vocab_seq_parallel_logits: [S/P, B, V] + # target: [S/P, B] + # return: [S, B] + + # Need softmax for backward + softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1) + ctx.vocab_size = vocab_seq_parallel_logits.size(2) + loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none') + + sp_world_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + ctx.sp_world_size = sp_world_size + ctx.sp_rank = sp_rank + ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size + batch_size = vocab_seq_parallel_logits.size(1) + + loss_all = torch.empty(ctx.seqlen, + batch_size, + dtype=vocab_seq_parallel_logits.dtype, + device=vocab_seq_parallel_logits.device) + dist.all_gather_into_tensor(loss_all, loss, group=sp_group) + + ctx.save_for_backward(softmax, target) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + softmax, target = ctx.saved_tensors + + step_seqlen = ctx.seqlen // ctx.sp_world_size + sp_rank = ctx.sp_rank + grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :] + + grad_input = softmax + grad_2d = grad_input.view(-1, ctx.vocab_size) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + grad_2d[arange_1d, target.view(-1)] -= 1 + grad_input.mul_(grad_output_part.unsqueeze(dim=-1)) + + return grad_input, None, None, None + + +def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group): + return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py new file mode 100644 index 000000000000..7122c3e356aa --- /dev/null +++ b/deepspeed/sequence/fpdt_layer.py @@ -0,0 +1,1226 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Optional, Any, Tuple +from torch import Tensor +from packaging import version +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import jit_script_compat + +try: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + flash_attn_version = version.parse(flash_attn.__version__) +except ImportError: + _flash_attn_forward = None + _flash_attn_backward = None + +from einops import rearrange +from .layer import single_all_to_all, apply_rotary_pos_emb + + +def _rotate_half_backward(x): + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((x2, -x1), dim=-1) + + +def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): + rot_dim = freqs_cos.shape[-1] + grad, grad_pass = grad_output[..., :rot_dim], grad_output[..., rot_dim:] + grad_t = (grad * freqs_cos) + (_rotate_half_backward(grad * freqs_sin)) + grad = grad_t if grad_pass.shape[-1] == 0 else torch.cat((grad_t, grad_pass), dim=-1) + return grad + + +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + new_lse = lse + torch.log1p(torch.exp(block_lse - lse)) + + out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + + lse = new_lse + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.permute(0, 2, 1).contiguous().unsqueeze(dim=-1).contiguous() + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +class FPDT_InputConstruct(torch.nn.Module): + + def __init__(self, tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank) -> None: + + super(FPDT_InputConstruct, self).__init__() + self.tokens = tokens + self.labels = labels + self.loss_mask = loss_mask + self.attention_mask = attention_mask + self.position_ids = position_ids + global_seq_len = tokens.shape[1] + batch_size = tokens.shape[0] + assert global_seq_len % sp_size == 0 + assert global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0 + num_chunk_per_gpu = global_seq_len // args.ds_sequence_parallel_fpdt_chunk_size + local_seq_len = global_seq_len // sp_size + assert local_seq_len % num_chunk_per_gpu == 0 + + self.num_chunk_per_gpu = num_chunk_per_gpu + self.chunk_size = local_seq_len // num_chunk_per_gpu + self.sp_size = sp_size + self.sp_rank = sp_rank + self.global_seq_len = global_seq_len + self.local_seq_len = local_seq_len + self.batch_size = batch_size + self.device = tokens.device + + def generate(self): + device = self.device + totalChunks = self.global_seq_len // self.chunk_size + token_chunk_idx = torch.arange(self.global_seq_len, device=device, dtype=torch.int) // self.chunk_size + chunk_to_gpu = torch.arange(totalChunks, device=device, dtype=torch.int) + chunk_to_gpu = chunk_to_gpu.reshape(self.num_chunk_per_gpu, -1).t().contiguous() + + gather_chunk = chunk_to_gpu.flatten().unsqueeze(1).contiguous() + mask = gather_chunk == token_chunk_idx + + indices = mask.nonzero(as_tuple=False) + gather_indices = indices[:, 0] + token_chunk_indices = indices[:, 1] + indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])]) + load_balanced_loss_mask = self.loss_mask[:, indices] if self.loss_mask is not None else self.loss_mask + + indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu * self.sp_rank:self.num_chunk_per_gpu * + (self.sp_rank + 1)].flatten().contiguous() + load_balanced_tokens = self.tokens[:, indices] + load_balanced_labels = self.labels[:, indices] if self.labels is not None else self.labels + + load_balanced_attention_mask = self.attention_mask + load_balanced_position_ids = self.position_ids[:, + indices] if self.position_ids is not None else self.position_ids + + return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids + + +class _FPDTGPUAttentionImpl_(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + spg, + scatter_idx, + gather_idx, + hidden_size, + projection_size, + hidden_size_per_attention_head, + kv_projection_size, + qkv_linear_weight, + qkv_linear_bias, + dropout, + num_chunks=8, + cpu_offloading=True): + + do_save = layernorm_output.requires_grad + + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + else: + ctx.pos_emb_cos = None + ctx.pos_emb_sin = None + + with torch.no_grad(): + per_gpu_seq_len = layernorm_output.shape[0] + chunk_size = per_gpu_seq_len // num_chunks + assert chunk_size * num_chunks == per_gpu_seq_len + assert attention_mask is None + ctx.num_chunks = num_chunks + ctx.cpu_offloading = cpu_offloading + ctx.spg = spg + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + device = get_accelerator().current_device_name() + ctx.device = device + ctx.dtype = layernorm_output.dtype + ctx.projection_size = projection_size + ctx.kv_projection_size = kv_projection_size + + global_q = [] + global_k = [] + global_v = [] + + ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) + + ctx.dropout_p = dropout + ctx.window_size = (-1, -1) + ctx.alibi_slopes = None + + batch_size = layernorm_output.shape[1] + + global_o = [None for _ in range(num_chunks)] + global_lse = [None for _ in range(num_chunks)] + + for i in range(num_chunks): + + st = chunk_size * i + ed = st + chunk_size + + qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) + global_q_chunk_len = q_chunk.shape[1] + if rotary_pos_emb is not None: + q_chunk = apply_rotary_pos_emb(q_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + global_q.append(q_chunk) + + k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) + if rotary_pos_emb is not None: + k_chunk = apply_rotary_pos_emb(k_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + global_k.append(k_chunk) + + v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) + global_v.append(v_chunk) + + for k_i in range(len(global_k)): + causal_chunk = i == k_i + if flash_attn_version >= version.parse("2.6.0"): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + else: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + + global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse) + + global_o[i] = global_o[i].to(q_chunk.dtype) + + output = [None for i in range(num_chunks)] + + for i in range(num_chunks): + global_lse[i] = global_lse[i][:, :, :, 0].permute(0, 2, 1).contiguous() + output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + output = torch.cat(output, dim=1) + + head_dim = output.shape[-1] + + if do_save: + ctx.save_for_backward(layernorm_output) + ctx.global_q = global_q + ctx.global_k = global_k + ctx.global_v = global_v + ctx.attn_output = global_o + ctx.attn_lse = global_lse + ctx.head_dim = head_dim + ctx.batch_size = batch_size + + ctx.qkv_linear_weight = qkv_linear_weight + ctx.qkv_linear_bias = qkv_linear_bias + + return output + + @staticmethod + def backward(ctx, grad_output): + + num_chunks = ctx.num_chunks + device = ctx.device + dtype = ctx.dtype + spg = ctx.spg + scatter_idx = ctx.scatter_idx + gather_idx = ctx.gather_idx + softmax_scale = ctx.softmax_scale + dropout_p = ctx.dropout_p + window_size = ctx.window_size + alibi_slopes = ctx.alibi_slopes + + projection_size = ctx.projection_size + kv_projection_size = ctx.kv_projection_size + + layernorm_output = ctx.saved_tensors[0] + + global_q = ctx.global_q + global_k = ctx.global_k + global_v = ctx.global_v + attn_output = ctx.attn_output + lse = ctx.attn_lse + + qkv_linear_weight = ctx.qkv_linear_weight + qkv_linear_bias = ctx.qkv_linear_bias + + input_chunk_size = layernorm_output.shape[0] // num_chunks + grad_layernorm_output = [ + torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]), + device=device, + dtype=dtype) for _ in range(num_chunks) + ] + + grad_global_attn_output = [] + chunk_size = grad_output.shape[1] // num_chunks + + for i in range(num_chunks): + st = chunk_size * i + ed = st + chunk_size + grad_global_attn_output.append( + single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg)) + + del grad_output + + dq = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + dk = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + dv = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) + + for i in range(num_chunks): + k_chunk = global_k[i] + v_chunk = global_v[i] + + for q_i in range(num_chunks): + no_computation = q_i < i + if no_computation: + continue + + causal_chunk = q_i == i + + q_chunk = global_q[q_i] + attn_output_chunk = attn_output[q_i] + lse_chunk = lse[q_i] + d_out = grad_global_attn_output[q_i] + + dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device) + dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) + dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) + + if flash_attn_version >= version.parse("2.6.0"): + _flash_attn_backward(d_out, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + else: + _flash_attn_backward(d_out, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + + dq[q_i].add_(dq_this.to(torch.float)) + dk[i].add_(dk_this.to(torch.float)) + dv[i].add_(dv_this.to(torch.float)) + + dk_seq_len = dk[i].shape[1] + + if ctx.pos_emb_cos is not None: + dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + else: + dk[i] = dk[i].to(dtype) + dv[i] = dv[i].to(dtype) + dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg) + dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg) + + input_st = i * input_chunk_size + input_ed = input_st + input_chunk_size + + input_chunk = layernorm_output[input_st:input_ed].reshape(-1, layernorm_output.shape[-1]) + + dk[i] = dk[i].flatten(2).permute(1, 0, 2) + dv[i] = dv[i].flatten(2).permute(1, 0, 2) + l, b = dk[i].shape[0], dk[i].shape[1] + grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( + torch.matmul(dk[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( + torch.matmul(dv[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk[i].sum(0).sum(0)) + grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_( + torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size + kv_projection_size])) + grad_layernorm_output[i].add_(torch.matmul(dv[i], + qkv_linear_weight[projection_size + kv_projection_size:])) + + dk[i] = None + dv[i] = None + + for i in range(num_chunks): + dq_seq_len = dq[i].shape[1] + if ctx.pos_emb_cos is not None: + dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), + ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) + else: + dq[i] = dq[i].to(dtype) + dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + + input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) + layernorm_output = layernorm_output[input_chunk_size:] + + dq[i] = dq[i].flatten(2).permute(1, 0, 2) + l, b = dq[i].shape[0], dq[i].shape[1] + grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_bias[:projection_size].add_(dq[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dq[i], qkv_linear_weight[:projection_size])) + + dq[i] = None + + return torch.cat( + grad_layernorm_output, + dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to( + dtype), grad_qkv_linear_bias.to(dtype), None, None, None + + +class SequenceChunk: + + def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): + + self.chunk_shape = chunk.shape + self.chunk_dtype = chunk.dtype + self.device = chunk.device if device is None else device + + cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True) + + if get_accelerator().on_accelerator(chunk): + cpu_chunk.copy_(chunk, non_blocking=True) + else: + cpu_chunk = chunk + + self.cpu_chunk = cpu_chunk + + self.gpu_chunk = chunk if is_in_use else None + + def load_to_gpu(self): + assert self.gpu_chunk is None + if self.gpu_chunk is not None: + pass + else: + gpu_chunk = torch.empty(self.chunk_shape, device=self.device, dtype=self.chunk_dtype) + gpu_chunk.copy_(self.cpu_chunk, non_blocking=True) + self.gpu_chunk = gpu_chunk + + def get_gpu_chunk(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + return self.gpu_chunk + + def check_gpu_chunk(self, ): + assert (self.gpu_chunk is not None) and ( + self.gpu_chunk.device == self.device + ), f"gpu_chunk {self.gpu_chunk is not None} shound be on {self.device}, but it is now on {self.gpu_chunk.device}" + return True + + def offload(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + del self.gpu_chunk + self.gpu_chunk = None + + def overwrite_to_cpu(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + self.cpu_chunk.copy_(self.gpu_chunk, non_blocking=True) + + +class _FPDTGPUOffloadingAttentionImpl_(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + spg, + scatter_idx, + gather_idx, + hidden_size, + projection_size, + hidden_size_per_attention_head, + kv_projection_size, + qkv_linear_weight, + qkv_linear_bias, + dropout, + num_chunks=8, + cpu_offloading=True): + + do_save = layernorm_output.requires_grad + + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + else: + ctx.pos_emb_cos = None + ctx.pos_emb_sin = None + with torch.no_grad(): + per_gpu_seq_len = layernorm_output.shape[0] + chunk_size = per_gpu_seq_len // num_chunks + assert chunk_size * num_chunks == per_gpu_seq_len + assert attention_mask is None + ctx.num_chunks = num_chunks + ctx.cpu_offloading = cpu_offloading + ctx.spg = spg + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + ctx.chunk_size = chunk_size + device = get_accelerator().current_device_name() + ctx.device = device + ctx.dtype = layernorm_output.dtype + ctx.projection_size = projection_size + ctx.kv_projection_size = kv_projection_size + + global_q = [] + global_k = [] + global_v = [] + + ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) + + ctx.dropout_p = dropout + ctx.window_size = (-1, -1) + ctx.alibi_slopes = None + + batch_size = layernorm_output.shape[1] + + global_o = [] + global_lse = [] + + layernorm_output_cpu = [] + final_output = [] + + offload_stream = get_accelerator().Stream() + general_offload_stream = get_accelerator().Stream() + compute_stream = get_accelerator().default_stream() + + q_compute_chunk_idx = 0 + kv_compute_chunk_idx = 0 + for i in range(num_chunks): + + qkv_chunk = torch.matmul(layernorm_output[:chunk_size], + qkv_linear_weight.t()) + qkv_linear_bias # torch.Size([18126, 1, 12288]) + + with get_accelerator().stream(general_offload_stream): + layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size])) + + layernorm_output = layernorm_output[chunk_size:] + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) + global_q_chunk_len = q_chunk.shape[1] + + k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) + + v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) + + dist.barrier() + + if ctx.pos_emb_cos is not None: + pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + + q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + with get_accelerator().stream(offload_stream): + global_q.append(SequenceChunk(q_chunk, is_in_use=True)) + global_k.append(SequenceChunk(k_chunk, is_in_use=True)) + global_v.append(SequenceChunk(v_chunk, is_in_use=True)) + + del qkv_chunk + + cur_attn_output = None + cur_attn_lse = None + for k_i in range(len(global_k)): + causal_chunk = i == k_i + with get_accelerator().stream(compute_stream): + if flash_attn_version >= version.parse("2.6.0"): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + else: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, + block_lse) + + can_offload_kv = True + if k_i != (len(global_k) - 1) or i != (num_chunks - 1): + if k_i != (len(global_k) - 1): + next_kv_compute_chunk_idx = k_i + 1 + else: + next_kv_compute_chunk_idx = 0 + + if next_kv_compute_chunk_idx == kv_compute_chunk_idx: + can_offload_kv = False + else: + if next_kv_compute_chunk_idx != (len(global_k) - 1): + with get_accelerator().stream(offload_stream): + global_k[next_kv_compute_chunk_idx].load_to_gpu() + global_v[next_kv_compute_chunk_idx].load_to_gpu() + + if i == num_chunks - 1 and k_i == num_chunks - 1: + with get_accelerator().stream(offload_stream): + global_q[0].load_to_gpu() + global_k[0].load_to_gpu() + global_v[0].load_to_gpu() + global_o[0].load_to_gpu() + global_lse[0].load_to_gpu() + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + if can_offload_kv: + global_k[kv_compute_chunk_idx].offload() + global_v[kv_compute_chunk_idx].offload() + kv_compute_chunk_idx = next_kv_compute_chunk_idx + + global_q[q_compute_chunk_idx].offload() + q_compute_chunk_idx += 1 + + all2all_output = single_all_to_all( + cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + final_output.append(all2all_output) + with get_accelerator().stream(general_offload_stream): + global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) + global_lse.append(SequenceChunk(cur_attn_lse[:, :, :, 0].permute(0, 2, 1).contiguous())) + + compute_stream.wait_stream(general_offload_stream) + compute_stream.synchronize() + + final_output = torch.cat(final_output, dim=1) + + head_dim = final_output.shape[-1] + + if do_save: + ctx.layernorm_output = layernorm_output_cpu + ctx.global_q = global_q + ctx.global_k = global_k + ctx.global_v = global_v + ctx.attn_output = global_o + ctx.attn_lse = global_lse + ctx.head_dim = head_dim + ctx.batch_size = batch_size + + ctx.qkv_linear_weight = qkv_linear_weight + ctx.qkv_linear_bias = qkv_linear_bias + + return final_output + + @staticmethod + def backward(ctx, grad_output): + num_chunks = ctx.num_chunks + device = grad_output.device + dtype = ctx.dtype + spg = ctx.spg + scatter_idx = ctx.scatter_idx + gather_idx = ctx.gather_idx + softmax_scale = ctx.softmax_scale + dropout_p = ctx.dropout_p + window_size = ctx.window_size + alibi_slopes = ctx.alibi_slopes + + projection_size = ctx.projection_size + kv_projection_size = ctx.kv_projection_size + + layernorm_output = ctx.layernorm_output + + global_q = ctx.global_q + global_k = ctx.global_k + global_v = ctx.global_v + attn_output = ctx.attn_output + lse = ctx.attn_lse + + qkv_linear_weight = ctx.qkv_linear_weight + qkv_linear_bias = ctx.qkv_linear_bias + + offload_stream = get_accelerator().Stream() + general_offload_stream = get_accelerator().Stream() + compute_stream = get_accelerator().default_stream() + + chunk_size = grad_output.shape[1] // num_chunks + assert chunk_size == layernorm_output[0].cpu_chunk.shape[0] + + grad_layernorm_output = [ + torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks) + ] + + grad_global_attn_output = [None for _ in range(num_chunks)] + + q_compute_chunk_idx = 0 + kv_compute_chunk_idx = 0 + last_q_accum_idx = 0 + + with get_accelerator().stream(general_offload_stream): + layernorm_output[0].load_to_gpu() + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, + gather_idx, 0, spg) + get_accelerator().synchronize() + grad_output = grad_output[:, chunk_size:] + + with get_accelerator().stream(offload_stream): + grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) + dq = [ + SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True) + ] + [ + SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True), + device) for _ in range(num_chunks - 1) + ] + dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device) + dv_accum = torch.zeros(global_v[0].chunk_shape, dtype=torch.float, device=device) + + for i in range(num_chunks): + for q_i in range(num_chunks): + no_computation = q_i < i + if no_computation: + continue + + causal_chunk = q_i == i + + dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device) + dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device) + dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) + + with get_accelerator().stream(compute_stream): + if flash_attn_version >= version.parse("2.6.0"): + _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + else: + _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + + if i != (len(global_k) - 1): + if q_i != (len(global_q) - 1): + next_q_compute_chunk_idx = q_i + 1 + else: + next_q_compute_chunk_idx = i + 1 + + can_offload_q = True + + if next_q_compute_chunk_idx == q_compute_chunk_idx: + can_offload_q = False + else: + with get_accelerator().stream(offload_stream): + if i > 0 or q_i > 0: + if can_offload_q and last_q_accum_idx != i: # the first q chunk calculate in the loop will be sent out, therefore we do not offload it + dq[last_q_accum_idx].offload() + dq[next_q_compute_chunk_idx].load_to_gpu() + global_q[next_q_compute_chunk_idx].load_to_gpu() + attn_output[next_q_compute_chunk_idx].load_to_gpu() + lse[next_q_compute_chunk_idx].load_to_gpu() + if grad_global_attn_output[next_q_compute_chunk_idx] is not None: + grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu() + + if grad_global_attn_output[next_q_compute_chunk_idx] is None: + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), + scatter_idx, gather_idx, 0, spg) + dist.barrier() + grad_output = grad_output[:, chunk_size:] + grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk( + grad_global_attn_output_chunk, is_in_use=True) + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + with get_accelerator().stream(compute_stream): + dq[q_compute_chunk_idx].check_gpu_chunk() + dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) + dk_accum.add_(dk_this) + dv_accum.add_(dv_this) + + offload_stream.wait_stream(compute_stream) + with get_accelerator().stream(offload_stream): + dq[q_compute_chunk_idx].overwrite_to_cpu() + + if can_offload_q: + global_q[q_compute_chunk_idx].offload() + attn_output[q_compute_chunk_idx].offload() + lse[q_compute_chunk_idx].offload() + grad_global_attn_output[q_compute_chunk_idx].offload() + + last_q_accum_idx = q_compute_chunk_idx + q_compute_chunk_idx = next_q_compute_chunk_idx + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + dk_seq_len = dk_accum.shape[1] + + if ctx.pos_emb_cos is not None: + dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + else: + dq_accum = dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype) + dk_accum = dk_accum.to(dtype) + dv_accum = dv_accum.to(dtype) + + dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + + general_offload_stream.synchronize() + compute_stream.wait_stream(general_offload_stream) + dist.barrier() + + with get_accelerator().stream(compute_stream): + input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) + + dq_accum = dq_accum.flatten(2).permute(1, 0, 2) + dk_accum = dk_accum.flatten(2).permute(1, 0, 2) + dv_accum = dv_accum.flatten(2).permute(1, 0, 2) + + l, b = dk_accum.shape[0], dk_accum.shape[1] + + grad_qkv_linear_weight[:projection_size].add_( + torch.matmul(dq_accum.reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( + torch.matmul(dk_accum.reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( + torch.matmul(dv_accum.reshape(l * b, -1).t(), input_chunk)) + + grad_qkv_linear_bias[:projection_size].add_(dq_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv_accum.sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dq_accum, qkv_linear_weight[:projection_size])) + grad_layernorm_output[i].add_( + torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size + kv_projection_size])) + grad_layernorm_output[i].add_( + torch.matmul(dv_accum, qkv_linear_weight[projection_size + kv_projection_size:])) + + del dq_accum, dk_accum, dv_accum + dk_accum = torch.zeros(global_k[i].chunk_shape, dtype=torch.float, device=device) + dv_accum = torch.zeros(global_v[i].chunk_shape, dtype=torch.float, device=device) + dq[kv_compute_chunk_idx].offload() + dq[kv_compute_chunk_idx] = None + + if i != (len(global_k) - 1): + next_kv_compute_chunk_idx = kv_compute_chunk_idx + 1 + with get_accelerator().stream(offload_stream): + global_k[next_kv_compute_chunk_idx].load_to_gpu() + global_v[next_kv_compute_chunk_idx].load_to_gpu() + + with get_accelerator().stream(general_offload_stream): + layernorm_output[next_kv_compute_chunk_idx].load_to_gpu() + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + layernorm_output[kv_compute_chunk_idx].offload() + global_k[kv_compute_chunk_idx].offload() + global_v[kv_compute_chunk_idx].offload() + kv_compute_chunk_idx = next_kv_compute_chunk_idx + + return torch.cat( + grad_layernorm_output, + dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to( + dtype), grad_qkv_linear_bias.to(dtype), None, None, None + + +class FPDT_Attention(torch.nn.Module): + + def __init__(self, + config, + first_weight, + first_bias, + second_weight, + second_bias, + sequence_process_group, + gather_idx: int = 0, + scatter_idx: int = 2, + return_bias=True, + chunk_size=65536, + enable_offloading=True) -> None: + + super(FPDT_Attention, self).__init__() + if _flash_attn_forward is None or _flash_attn_backward is None: + raise ImportError( + "DeepSpeed FPDT requires flash-attn 2.6.3. Please install it with `pip install flash-attn --no-build-isolation`." + ) + + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.config = config + + self.projection_size = config.kv_channels * config.num_attention_heads + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.kv_projection_size = config.kv_channels * config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.qkv_linear_weight = first_weight + self.qkv_linear_bias = first_bias + self.qkv_dense_weight = second_weight + self.qkv_dense_bias = second_bias + + self.reture_bias = return_bias + self.dropout = config.attention_dropout + + self.chunk_size = chunk_size + self.double_buffer = enable_offloading + + def forward(self, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + cpu_offloading=True) -> Tensor: + self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size + + if not cpu_offloading or self.num_chunks_attn == 1: + output = _FPDTGPUAttentionImpl_.apply(layernorm_output, attention_mask, inference_params, rotary_pos_emb, + self.spg, self.scatter_idx, self.gather_idx, self.hidden_size, + self.projection_size, self.hidden_size_per_attention_head, + self.kv_projection_size, self.qkv_linear_weight, + self.qkv_linear_bias, self.dropout, self.num_chunks_attn, + cpu_offloading) + else: + output = _FPDTGPUOffloadingAttentionImpl_.apply( + layernorm_output, attention_mask, inference_params, rotary_pos_emb, self.spg, self.scatter_idx, + self.gather_idx, self.hidden_size, self.projection_size, self.hidden_size_per_attention_head, + self.kv_projection_size, self.qkv_linear_weight, self.qkv_linear_bias, self.dropout, + self.num_chunks_attn, cpu_offloading) + + output = output.flatten(2).permute(1, 0, 2).contiguous() + + output = torch.matmul(output, self.qkv_dense_weight.t()) + if not self.reture_bias: + output += self.qkv_dense_bias + return output, self.qkv_dense_bias if self.reture_bias else None + + +@jit_script_compat +def bias_gelu(x): + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +@jit_script_compat +def bias_gelu_back(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g + + +class FPDT_FFN(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, x, w1, b1, w2, b2, add_bias, chunk_size): + do_save = x.requires_grad + ctx.add_bias = add_bias + device = x.device + + with torch.no_grad(): + num_chunk = x.shape[0] // chunk_size + ctx.num_chunk = num_chunk + result = torch.empty(x.shape, device=device, dtype=x.dtype) + assert chunk_size * num_chunk == x.shape[0] + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + x_ = torch.matmul(x[st:ed], w1.t()) + b1 + x_ = bias_gelu(x_) + if add_bias: + result[st:ed] = torch.matmul(x_, w2.t()) + b2 + else: + result[st:ed] = torch.matmul(x_, w2.t()) + + del x_ + + if do_save: + ctx.device = device + ctx.dtype = x.dtype + ctx.save_for_backward(x, w1, b1, w2, b2) + ctx.grad_x_shape = x.shape + return result.to(x.dtype), b2 if not add_bias else None + + @staticmethod + def backward(ctx, grad_output, grad_bias): + x, w1, b1, w2, b2 = ctx.saved_tensors + device = ctx.device + dtype = ctx.dtype + add_bias = ctx.add_bias + + num_chunk = ctx.num_chunk + chunk_size = x.shape[0] // num_chunk + assert chunk_size * num_chunk == grad_output.shape[0] + + grad_w2 = torch.zeros(w2.shape, device=device, dtype=torch.float) + grad_b2 = torch.zeros(b2.shape, device=device, dtype=torch.float) + grad_w1 = torch.zeros(w1.shape, device=device, dtype=torch.float) + grad_b1 = torch.zeros(b1.shape, device=device, dtype=torch.float) + + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + x_chunk = x[st:ed] + + before_act = (torch.matmul(x_chunk, w1.t()) + b1) + before_act_2 = before_act**2 + tanh_out = torch.tanh(0.79788456 * before_act * (1 + 0.044715 * before_act_2)) + ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) * + (0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out) + grad_w2.add_( + torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(), + (before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2]))) + del before_act, before_act_2, tanh_out + + grad_inter = torch.matmul(grad_output[st:ed], w2) * ff + del ff + + grad_w1.add_(torch.matmul( + grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2]))) + grad_b1.add_(grad_inter.sum(0).sum(0)) + + x[st:ed].copy_(torch.matmul(grad_inter, w1)) + + del grad_inter + + if add_bias: + grad_b2.add_(grad_output[st:ed].sum(0).sum(0)) + + return x, grad_w1.to(dtype), grad_b1.to(dtype), grad_w2.to(dtype), grad_b2.to(dtype), None, None + + +class FPDT_LogitsLoss(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num_chunk): + labels = labels.t() + chunk_size = lm_output.shape[0] // num_chunk + assert chunk_size * num_chunk == lm_output.shape[0] + batch_size, local_seq_len = lm_output.shape[1], lm_output.shape[0] + loss = torch.empty((batch_size, local_seq_len), dtype=torch.float, device=lm_output.device) + + ctx.num_chunk = num_chunk + ctx.chunk_size = chunk_size + ctx.device = lm_output.device + ctx.dtype = lm_output.dtype + + ctx.rank = rank + ctx.local_seq_len = local_seq_len + with torch.no_grad(): + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + logits_chunk = torch.matmul(lm_output[st:ed], logit_weights.t()).float() + + vocab_size = logits_chunk.size(2) + # nll + softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) + loss_chunk = torch.nn.functional.nll_loss(softmax.log().reshape(-1, vocab_size).contiguous(), + labels[st:ed, :].reshape(-1).contiguous(), + reduction='none') + loss[:, st:ed] = loss_chunk.reshape(chunk_size, batch_size).t() + + del logits_chunk + ctx.save_for_backward(lm_output.to('cpu'), labels) + ctx.logit_weights = logit_weights + + seqlen = local_seq_len * spg_size + batch_size = loss.size(0) + loss = loss.t().contiguous() + loss_all = torch.empty(seqlen, batch_size, dtype=loss.dtype, device=loss.device).contiguous() + + dist.allgather_fn(loss_all, loss, group=spg) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + lm_output, labels = ctx.saved_tensors + logit_weights = ctx.logit_weights + device = ctx.device + dtype = ctx.dtype + num_chunk = ctx.num_chunk + chunk_size = ctx.chunk_size + + rank = ctx.rank + local_seq_len = ctx.local_seq_len + + grad_output = grad_output[rank * local_seq_len:(rank + 1) * local_seq_len] + grad_lm_output = [None for _ in range(num_chunk)] + grad_logit_weights = torch.zeros(logit_weights.shape, device=grad_output.device, dtype=torch.float) + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + lm_output_chunk = lm_output[st:ed].to(device) + logits_chunk = torch.matmul(lm_output_chunk, logit_weights.t()).float() + + # nll + softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) + vocab_size = logits_chunk.size(2) + + grad_input = softmax + grad_2d = grad_input.reshape(-1, vocab_size).contiguous() + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=device) + + grad_2d[arange_1d, labels[st:ed, :].reshape(-1).contiguous()] -= 1 + grad_input.mul_(grad_output[:chunk_size, :].unsqueeze(dim=-1)) + grad_input = grad_input.to(dtype) + + grad_output = grad_output[chunk_size:].contiguous() + + grad_lm_output_chunk = torch.matmul(grad_input, logit_weights) + grad_lm_output[i] = grad_lm_output_chunk + + grad_logit_weights.add_( + torch.matmul( + grad_input.reshape(-1, grad_input.shape[2]).t(), + lm_output_chunk.reshape(-1, lm_output_chunk.shape[2]))) + + return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py new file mode 100644 index 000000000000..ecbe0d94120e --- /dev/null +++ b/deepspeed/sequence/layer.py @@ -0,0 +1,440 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import torch + +from typing import Any, Tuple +from torch import Tensor +from torch.nn import Module + +from einops import rearrange + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads +from deepspeed.utils import groups + + +def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input): + """ + This function generates the parameters required for `permute` and `reshape` operations, + which are used to process data before and after `all2all` communication. + """ + if batch_dim_idx == 0: + if scatter_idx < 2: + bs, global_seq_len, num_local_head, head_dim = input.shape + pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim] + pre_all2all_permute_idx = (1, 0, 2, 3, 4) + + post_all2all_permute_idx = (1, 2, 0, 3, 4) + post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim] + else: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim] + pre_all2all_permute_idx = (2, 0, 1, 3, 4) + + post_all2all_permute_idx = (1, 0, 2, 3, 4) + post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim] + else: + if scatter_idx < 2: + global_seq_len, bs, num_local_head, head_dim = input.shape + pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim] + pre_all2all_permute_idx = None + + post_all2all_permute_idx = (1, 2, 0, 3, 4) + post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim] + else: + local_seq_len, bs, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim] + pre_all2all_permute_idx = (2, 0, 1, 3, 4) + post_all2all_permute_idx = None + post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim] + + return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape + + +def post_all2all(permute_idx, res_shape): + """ + Post-processing function for `all2all` communication. + """ + + def post_func(input): + if permute_idx is not None: + input = input.permute(permute_idx).contiguous() + output = input.reshape(res_shape).contiguous() + + return output + + return post_func + + +def pre_all2all_fun(permute_idx, inp_shape, input): + """ + Pre-processing function for `all2all` communication. + """ + input_t = input.reshape(inp_shape).contiguous() + if permute_idx is not None: + input_t = input_t.permute(permute_idx).contiguous() + return input_t + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs_cos.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) + + res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) + return res + + +def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group): + seq_world_size = dist.get_world_size(group) + inp_shape = list(input.shape) + assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1" + + if not (scatter_idx < 2): + input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size) + input = input.transpose(0, scatter_idx).contiguous() + local_heads = input_splits[groups._get_sequence_parallel_rank()] + output_splits = [local_heads] * seq_world_size + + output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:]) + output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype) + dist.all_to_all_single(output,input,output_split_sizes=output_splits,\ + input_split_sizes=input_splits,group=group) + ###[seq_ws*local_heads, ...] to [seq_ws, local_heads, ...] + output = output.view(seq_world_size, local_heads, *output.shape[1:]) + ###[seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...] + + ### batch_dim_idx=0 [seq_ws,local_heads,seq_len,b,...] to [b, seq_ws, seq_len, local_heads ...] + ### batch_dim_idx=1 [seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...] + if batch_dim_idx == 0: + order = [3, 0, 2, 1] + list(range(4, len(output.shape))) + output = output.permute(order).contiguous() + ###[b, seq_ws*local_seq_len, local_heads,...] + output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size, + *output.shape[3:]).contiguous() + elif batch_dim_idx == 1: + output = output.transpose(1, 3).contiguous() + ###[seq_ws*local_seq_len, b, local_heads,...] + output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous() + else: + # The compatibility handling of 4D and 3D tensors, standardizing to 3D. + input = input.reshape(input.shape[0], input.shape[1], -1) + + if batch_dim_idx == 0: #b,s,h + input = input.permute(1, 2, 0).contiguous() #s,h,b + elif batch_dim_idx == 1: #s,b,h + input = input.transpose(1, 2).contiguous() #s,h,b + seq_len, h, batch_size = input.shape + num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size) + local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()] + h_dim = h // local_heads + local_seq_len = seq_len // seq_world_size + + input = input.view(seq_len * h, batch_size) + local_seq_len_with_heads = int(input.shape[0] / seq_world_size) # dim size of local_seq_len*local_heads*hdim + input_splits = [local_seq_len_with_heads] * seq_world_size + coeff = local_seq_len_with_heads // local_heads #per head: dim size of local_seq_len*hdim + + #uneven seq_world_size coeff, total_heads/local_heads. + heads_scale_coeff = get_num_kv_heads() / local_heads + + output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list] + output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads) + total_h = int(inp_shape[gather_idx] * heads_scale_coeff) + output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype) + dist.all_to_all_single(output,input,output_split_sizes=output_splits, \ + input_split_sizes=input_splits,group=group) + ################## + #suppose 7 heads divide into 4 ranks [2,2,2,1] + #chunk_num_heads_small=floor(7/4)=1 + #chunk_num_heads_large=ceil(7/4)=2 + #num_chunk_heads_large=len([2,2,2])=3, all2all_buffer_counts + #num_chunk_heads_small=len([1])=1, all2all_buffer_counts + #total_num_large_heads=sum([2,2,2])=7 + #total_num_small_heads=sum([1])=1 + + chunk_num_heads_small = get_num_kv_heads() // seq_world_size # even heads compatible + chunk_num_heads_large = chunk_num_heads_small + 1 + num_chunk_heads_large = get_num_kv_heads() % seq_world_size + num_chunk_heads_small = seq_world_size - num_chunk_heads_large + total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large + total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small + + heads_large_combine_size = coeff * total_num_large_heads + heads_small_combine_size = coeff * total_num_small_heads + heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size], + dim=0) + heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim, + batch_size) + heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim, + batch_size) + if batch_dim_idx == 0: + #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[batch,local_seq_len,all2all_buffer_counts*n_heads,dim] + order = [4, 1, 0, 2, 3] + heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len, + total_num_large_heads, h_dim) + heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len, + total_num_small_heads, h_dim) + elif batch_dim_idx == 1: + #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[local_seq_len,batch,all2all_buffer_counts*n_heads,dim] + order = [1, 4, 0, 2, 3] + heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size, + total_num_large_heads, h_dim) + heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size, + total_num_small_heads, h_dim) + + output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous() + + inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + output_shape= inp_shape[: gather_idx] + \ + [total_h,] + \ + inp_shape[gather_idx + 1:] + + output = output.view(output_shape) + + return output + + +def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): + seq_world_size = dist.get_world_size(group) + # we only need num_heads once + num_heads = input.shape[2] + + if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2): + # Assuming here that the number of heads for q is consistent with kv + # If not, additional logic is required for cases like GQA + if get_num_kv_heads() is None: + assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})" + # set heads at first call by num_total_heads. + # then use ``get_num_kv_heads() is not None`` to re-entry uneven path. + set_num_kv_heads(num_heads) + assert async_op == False, "uneven head sp does not support async op" + return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group) + + pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params( + scatter_idx, batch_dim_idx, seq_world_size, input) + + input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input) + + post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape) + output = torch.empty_like(input_t) + work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) + + if async_op: + if type in ('dq', 'dk'): + handle[type + '_work'] = work + handle[type + '_grad'] = output + handle[type + '_post_all2all_func'] = post_all2all_fun + return output.view(post_all2all_res_shape) + + res = post_all2all_fun(output) + return res + + +class _DimZeroAllToAll(torch.autograd.Function): + """Differentiable All2All across dimension 0.""" + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: + world_size = dist.get_world_size(group) + assert input.shape[0] == world_size, f"Dim 0 {input.shape[0]} is not world size" + + ctx.group = group + + output = torch.empty_like(input).contiguous() + # torch.distributed.nn.functional.all_to_all_single(output, input.contiguous(), group=group) + dist.all_to_all_single(output, input.contiguous(), group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: + return (None, _DimZeroAllToAll.apply(ctx.group, *grad_output)) + + +class _SeqAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + batch_dim_idx: int, + stream=None, + handle=None, + type=None, + is_fwd=True) -> Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.stream = stream + ctx.handle = handle + ctx.type = type + ctx.batch_dim_idx = batch_dim_idx + if ctx.handle is None: + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) + + else: + # overlap communication path + if not is_fwd and type == 'o': + assert ctx.stream != None + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) + get_accelerator().current_stream().wait_stream(ctx.stream) + # The computation of d o_weight can overlap with the communication of d o_input + + elif not is_fwd and type in ('q', 'k'): + # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv + type = 'd' + type + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type) + + elif is_fwd and type in ('q', 'k'): + # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v + type = 'fwd_' + type + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type) + + else: + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) + + return res + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + + return (None, + _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx, + ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None) + + +class DistributedAttention(torch.nn.Module): + """Initialization. + + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + """ + + def __init__( + self, + local_attention: Module, + sequence_process_group: dist.ProcessGroup, + scatter_idx: int = 2, + gather_idx: int = 0, + sp_stream=None, + ) -> None: + + super(DistributedAttention, self).__init__() + self.local_attn = local_attention + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.sp_overlap_comm = False + self.overlap_handles = None + self.sp_stream = sp_stream + if sp_stream is not None: + self.overlap_handles = {} + self.sp_overlap_comm = True + self.default_stream = get_accelerator().default_stream() + + def layer_sync(self, layer): + if self.sp_overlap_comm and hasattr(layer, 'done_event'): + self.default_stream.wait_event(layer.done_event) + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + batch_dim_idx: int, + rotary_pos_emb=None, + *args: Any, + **kwargs) -> Tensor: + """ forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + batch_dim_idx (int): indicating which dim is batch + args: other args + + Returns: + * output (Tensor): context output + """ + + # TODO Merge three alltoall calls into one + # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! + #in shape : e.g., [s/p:h:] + + def bwd_hook(layer_type): + + def pre_hook_fun(grad): + type = 'd' + layer_type + self.overlap_handles[type + '_work'].wait() + self.sp_stream.wait_stream(self.default_stream) + all2all_output = self.overlap_handles[type + '_grad'] + grad = list(grad) + grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output) + grad = tuple(grad) + + return pre_hook_fun + + self.layer_sync(query) + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None, + self.overlap_handles, 'q') + self.layer_sync(key) + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, + self.overlap_handles, 'k') + if self.sp_overlap_comm: + self.default_stream.wait_stream(self.sp_stream) + + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None, + self.overlap_handles, 'v') + + if self.sp_overlap_comm: + # Register a hook to synchronize dq and dk after the all-to-all + # operation when the gradient data is used. + # Place this logic after the q, k, v all-to-all operation to + # improve interpreter speed to + # call and launch of the forward all-to-all communication. + grad_fn_q = query.grad_fn.next_functions[0][0] + grad_fn_q.register_prehook(bwd_hook(layer_type='q')) + grad_fn_k = key.grad_fn.next_functions[0][0] + grad_fn_k.register_prehook(bwd_hook(layer_type='k')) + + #out shape : e.g., [s:h/p:] + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin) + key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin) + + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) + + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx, + self.sp_stream, self.overlap_handles, 'o') + + #out e.g., [s/p::h] + return output diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py new file mode 100644 index 000000000000..771a9e7bb6b7 --- /dev/null +++ b/deepspeed/sequence/test_autosp.py @@ -0,0 +1,724 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Unit tests for AutoSP multimodal sequence parallelism: + - autosp_detector: model scanning + - UlyssesSPViTAttention: ViT SP wrapper + - auto_wrap_model_for_sp: end-to-end wrapping + - ModalityFusionSPAdapter: cross-modal gather/scatter + - LlavaFusionAdapter: LLaVA-style visual token splice + - InternVLFusionAdapter: InternVL-style IMG_CONTEXT token splice + - Qwen2VLFusionAdapter: Qwen2-VL vision_start/end bounded splice +""" + +import pytest +import torch +import torch.nn as nn + +from deepspeed.sequence.autosp_detector import (SPModelInfo, _LLM_ATTN_CLASSNAMES, _VIT_ATTN_CLASSNAMES, + detect_model_sp_info) +from deepspeed.sequence.autosp_fusion import (InternVLFusionAdapter, LlavaFusionAdapter, ModalityFusionSPAdapter, + Qwen2VLFusionAdapter) +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp +from deepspeed.sequence.layer import DistributedAttention + +# --------------------------------------------------------------------------- +# Minimal fake modules that mimic the interface of real attention layers +# without requiring a GPU or a real transformer model. +# --------------------------------------------------------------------------- + + +class _FakeViTAttn(nn.Module): + """Identity ViT attention — returns hidden_states unchanged.""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class _FakeViTAttnTuple(nn.Module): + """ViT attention that returns a (output, weights) tuple.""" + + def forward(self, hidden_states, **kwargs): + weights = torch.zeros(hidden_states.shape[0], 1, hidden_states.shape[1], hidden_states.shape[1]) + return hidden_states, weights + + +class _FakeLLMAttn(nn.Module): + """Identity LLM attention.""" + + def forward(self, query, key, value, *args, **kwargs): + return query + + +# Register fake class names so the detector recognises them +_VIT_ATTN_CLASSNAMES.add("_FakeViTAttn") +_VIT_ATTN_CLASSNAMES.add("_FakeViTAttnTuple") +_LLM_ATTN_CLASSNAMES.add("_FakeLLMAttn") + + +class _FakeMultimodalModel(nn.Module): + """Minimal multimodal model with one ViT and one LLM attention layer.""" + + def __init__(self): + super().__init__() + self.vision_encoder = nn.ModuleList([_FakeViTAttn()]) + self.mm_projector = nn.Linear(64, 64) + self.llm = nn.ModuleList([_FakeLLMAttn()]) + + +class _FakeViTOnlyModel(nn.Module): + + def __init__(self, num_layers=3): + super().__init__() + self.layers = nn.ModuleList([_FakeViTAttn() for _ in range(num_layers)]) + + +class _FakeLLMOnlyModel(nn.Module): + """Minimal LLM-only model with multiple decoder attention layers.""" + + def __init__(self, num_layers=2): + super().__init__() + self.layers = nn.ModuleList([_FakeLLMAttn() for _ in range(num_layers)]) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_process_group(world_size: int, rank: int): + """Return a mock object that satisfies dist.get_world_size / get_rank.""" + import unittest.mock as mock + import deepspeed.comm as dist + + pg = mock.MagicMock() + dist.get_world_size = mock.MagicMock(return_value=world_size) + dist.get_rank = mock.MagicMock(return_value=rank) + + def _fake_all_gather(tensor_list, tensor, group=None): + for t in tensor_list: + t.copy_(tensor) + + dist.all_gather = _fake_all_gather + return pg + + +# --------------------------------------------------------------------------- +# autosp_detector tests +# --------------------------------------------------------------------------- + + +class TestAutospDetector: + + def test_detects_vit_and_llm(self): + model = _FakeMultimodalModel() + info = detect_model_sp_info(model) + assert len(info.vit_attn_modules) == 1 + assert len(info.llm_attn_modules) == 1 + + def test_detects_vision_projection(self): + model = _FakeMultimodalModel() + info = detect_model_sp_info(model) + assert info.vision_projection_module is not None + name, module = info.vision_projection_module + assert "mm_projector" in name + + def test_detects_multiple_vit_layers(self): + model = _FakeViTOnlyModel(num_layers=4) + info = detect_model_sp_info(model) + assert len(info.vit_attn_modules) == 4 + assert len(info.llm_attn_modules) == 0 + assert info.vision_projection_module is None + + def test_empty_model_returns_empty_info(self): + model = nn.Sequential(nn.Linear(8, 8)) + info = detect_model_sp_info(model) + assert isinstance(info, SPModelInfo) + assert len(info.vit_attn_modules) == 0 + assert len(info.llm_attn_modules) == 0 + + def test_only_first_projection_is_recorded(self): + """Multiple projection-like names → only the outermost is recorded.""" + + class _M(nn.Module): + + def __init__(self): + super().__init__() + self.mm_projector = nn.Sequential(nn.Linear(8, 8)) + self.mm_projector.visual_projection = nn.Linear(8, 8) + + model = _M() + info = detect_model_sp_info(model) + assert info.vision_projection_module is not None + # Should be the outermost "mm_projector", not the nested one + name, _ = info.vision_projection_module + assert name == "mm_projector" + + +# --------------------------------------------------------------------------- +# UlyssesSPViTAttention tests (CPU, rank-0 simulation via mocks) +# --------------------------------------------------------------------------- + + +class TestUlyssesSPViTAttention: + + @pytest.mark.parametrize("has_cls_token", [True, False]) + @pytest.mark.parametrize("num_patches,world_size", [ + (16, 4), + (16, 2), + (9, 3), + ]) + def test_output_shape_matches_input(self, has_cls_token, num_patches, world_size): + """Output shape must equal input shape for any padding scenario.""" + pg = _make_mock_process_group(world_size=world_size, rank=0) + attn = _FakeViTAttn() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=has_cls_token) + + local_patches = num_patches // world_size + seq_len = (1 + local_patches) if has_cls_token else local_patches + x = torch.randn(2, seq_len, 32) + + out = wrapper(x) + assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" + + def test_tuple_output_unwrapped_correctly(self): + """Wrappers that return (output, weights) tuples are handled.""" + pg = _make_mock_process_group(world_size=2, rank=0) + attn = _FakeViTAttnTuple() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=False) + + x = torch.randn(1, 8, 16) # 8 patches, 2 ranks → 4 local each + result = wrapper(x) + # Should return a tuple: (attention_output, attention_weights) + assert isinstance(result, tuple) + assert result[0].shape == x.shape + + def test_identity_attn_preserves_values(self): + """When attn is identity, output values should match input values.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + attn = _FakeViTAttn() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=True) + + # Each rank holds cls + 4 local patches + x = torch.arange(2 * 5 * 4, dtype=torch.float).reshape(2, 5, 4) + out = wrapper(x) + # CLS token should be identical + assert torch.allclose(out[:, :1, :], x[:, :1, :]) + # Local patch slice should match input patches for identity attn + assert torch.allclose(out[:, 1:, :], x[:, 1:, :]) + + +# --------------------------------------------------------------------------- +# auto_wrap_model_for_sp tests +# --------------------------------------------------------------------------- + + +class TestAutoWrapModelForSP: + + def test_vit_layers_replaced(self): + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeViTOnlyModel(num_layers=2) + auto_wrap_model_for_sp(model, pg) + for layer in model.layers: + assert isinstance(layer, UlyssesSPViTAttention) + + def test_raises_on_unknown_model(self): + pg = _make_mock_process_group(world_size=2, rank=0) + model = nn.Sequential(nn.Linear(8, 8)) + with pytest.raises(ValueError, match="no recognisable attention"): + auto_wrap_model_for_sp(model, pg) + + def test_set_module_by_name_shallow(self): + model = _FakeViTOnlyModel(num_layers=1) + new_mod = nn.Linear(4, 4) + _set_module_by_name(model, "layers.0", new_mod) + assert model.layers[0] is new_mod + + def test_set_module_by_name_deep(self): + model = _FakeMultimodalModel() + new_mod = nn.Identity() + _set_module_by_name(model, "vision_encoder.0", new_mod) + assert model.vision_encoder[0] is new_mod + + def test_llm_layers_replaced_with_distributed_attention(self): + """LLM attention layers must be wrapped with DistributedAttention.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeLLMOnlyModel(num_layers=3) + auto_wrap_model_for_sp(model, pg) + for layer in model.layers: + assert isinstance(layer, DistributedAttention) + + def test_multimodal_model_wraps_both_branches(self): + """Both ViT and LLM attention layers must be replaced in a combined model.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeMultimodalModel() + returned = auto_wrap_model_for_sp(model, pg) + # auto_wrap_model_for_sp must return the same object (in-place) + assert returned is model + assert isinstance(model.vision_encoder[0], UlyssesSPViTAttention) + assert isinstance(model.llm[0], DistributedAttention) + + def test_original_module_preserved_inside_wrapper(self): + """The wrapped module should still be accessible inside the wrapper.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeViTOnlyModel(num_layers=1) + original_attn = model.layers[0] + auto_wrap_model_for_sp(model, pg) + assert model.layers[0].attn is original_attn + + +# --------------------------------------------------------------------------- +# ModalityFusionSPAdapter tests +# --------------------------------------------------------------------------- + + +class _ConcatFusionAdapter(ModalityFusionSPAdapter): + """Concrete subclass that appends visual tokens after text tokens.""" + + def _splice_visual_into_text(self, text_embeds, visual_embeds, input_ids): + return torch.cat([text_embeds, visual_embeds], dim=1) + + +class TestModalityFusionSPAdapter: + + def test_base_class_raises_not_implemented(self): + """The base _splice_visual_into_text must raise NotImplementedError.""" + pg = _make_mock_process_group(world_size=2, rank=0) + adapter = ModalityFusionSPAdapter(nn.Identity(), pg) + with pytest.raises(NotImplementedError): + adapter._splice_visual_into_text(None, None, None) + + @pytest.mark.parametrize("world_size,local_v,text_len,hidden", [ + (2, 4, 6, 8), + (4, 3, 5, 16), + (1, 8, 8, 4), + ]) + def test_output_shape(self, world_size, local_v, text_len, hidden): + """Output local_len must equal ceil(fused_len / world_size).""" + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs = 2 + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + out = adapter(visual, text, ids) + + # all_gather mock copies local_v to each of world_size slots + fused_len = text_len + local_v * world_size + pad = (world_size - fused_len % world_size) % world_size + expected_local = (fused_len + pad) // world_size + assert out.shape == (bs, expected_local, hidden), f"Expected ({bs},{expected_local},{hidden}), got {out.shape}" + + def test_padding_produces_valid_output_when_not_divisible(self): + """When fused_len % world_size != 0, padding must not raise and output is well-formed.""" + world_size = 4 + # text_len=5, local_v=3 → fused_len = 5 + 3*4 = 17, needs padding of 3 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs, local_v, text_len, hidden = 1, 3, 5, 4 + out = adapter( + torch.randn(bs, local_v, hidden), + torch.randn(bs, text_len, hidden), + torch.zeros(bs, text_len, dtype=torch.long), + ) + # padded_len = 20, local_len = 5 + assert out.shape == (bs, 5, hidden) + + def test_no_padding_when_divisible(self): + """When fused_len is already divisible, no extra tokens should be added.""" + world_size = 4 + # text_len=4, local_v=4 → fused_len = 4 + 4*4 = 20, divisible by 4 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs, local_v, text_len, hidden = 1, 4, 4, 8 + out = adapter( + torch.randn(bs, local_v, hidden), + torch.randn(bs, text_len, hidden), + torch.zeros(bs, text_len, dtype=torch.long), + ) + assert out.shape == (bs, 5, hidden) # 20 // 4 = 5 + + def test_different_ranks_return_different_slices(self): + """Rank 0 and rank 1 must return different slices of the fused sequence.""" + world_size = 2 + bs, local_v, text_len, hidden = 1, 4, 4, 8 + # Use distinct text vs visual values so slices clearly differ + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, local_v, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + outputs = {} + for rank in range(world_size): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + outputs[rank] = adapter(visual.clone(), text.clone(), ids.clone()) + + # fused = [0,0,0,0, 1,1,1,1, 1,1,1,1] (text zeros then visual ones x2) + # rank 0: indices 0-5, rank 1: indices 6-11 + assert not torch.allclose(outputs[0], outputs[1]) + + def test_projection_is_applied(self): + """Projection layer must transform visual features before gather.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + + # Use a projection that doubles all values + class _DoubleProjection(nn.Module): + + def forward(self, x): + return x * 2.0 + + adapter = _ConcatFusionAdapter(_DoubleProjection(), pg) + bs, local_v, text_len, hidden = 1, 4, 4, 8 + visual = torch.ones(bs, local_v, hidden) + text = torch.zeros(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + out = adapter(visual, text, ids) + # The visual part of the output should have value 2.0 (doubled), not 1.0 + # rank 0 gets the first local_len tokens; fused = [text(0)*4, visual(2)*8] + # Since text_len=4 and local_len=6, rank0 slice starts with text zeros + # and ends with some visual twos. + assert out.max().item() == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- +# LlavaFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_IMAGE_ID = -200 # matches ModalityFusionSPAdapter default + + +def _make_llava_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return LlavaFusionAdapter(nn.Identity(), pg, image_token_id=_IMAGE_ID) + + +class TestLlavaFusionAdapter: + + def test_single_image_fused_shape(self): + """One image placeholder per sample → fused length = text_len - 1 + num_visual.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 2, 6, 4, 8 + # Place a single image placeholder at position 2. + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 2] = _IMAGE_ID + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, num_vis, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + # placeholder is removed and replaced by num_vis tokens + assert fused.shape == (bs, text_len - 1 + num_vis, hidden) + + def test_text_values_preserved_around_image(self): + """Text tokens before and after the placeholder must be numerically intact.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 1, 5, 3, 4 + # Placeholder at index 2: text = [A, B, , C, D] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _IMAGE_ID + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, num_vis, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # fused = [A, B, vis0, vis1, vis2, C, D] + assert torch.allclose(fused[0, :2], text[0, :2]) # A, B preserved + assert torch.allclose(fused[0, 5:], text[0, 3:]) # C, D preserved + assert torch.allclose(fused[0, 2:5], visual[0]) # visual inserted + + def test_no_image_token_returns_text_unchanged(self): + """When input_ids has no placeholder, output equals text_embeds exactly.""" + adapter = _make_llava_adapter() + bs, text_len, hidden = 2, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) # no -200 + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused, text) + + def test_multi_image_splice(self): + """Two placeholders per sample → visual tokens split evenly between them.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 1, 7, 6, 4 + # Placeholders at index 1 and 4: [t0, , t2, t3, , t5, t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _IMAGE_ID + ids[0, 4] = _IMAGE_ID + text = torch.zeros(bs, text_len, hidden) + # First 3 visual tokens = 1.0, last 3 = 2.0 (so we can tell them apart) + visual = torch.cat([torch.ones(bs, 3, hidden), torch.full((bs, 3, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Expected fused length: 7 - 2 placeholders + 6 visual = 11 + assert fused.shape == (bs, 11, hidden) + # First chunk (indices 1-3) should be 1.0 + assert torch.allclose(fused[0, 1:4], torch.ones(3, hidden)) + # Second chunk (indices 6-8) should be 2.0 + assert torch.allclose(fused[0, 6:9], torch.full((3, hidden), 2.0)) + + def test_batch_padding_when_lengths_differ(self): + """Samples with different numbers of image tokens are padded to max length.""" + adapter = _make_llava_adapter() + hidden = 4 + # Sample 0: 1 placeholder in a 4-token sequence + 2 visual → fused len = 5 + # Sample 1: no placeholder in a 4-token sequence → fused len = 4 + ids = torch.zeros(2, 4, dtype=torch.long) + ids[0, 1] = _IMAGE_ID + text = torch.ones(2, 4, hidden) + visual = torch.ones(2, 2, hidden) * 3.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Max fused length is 5; sample 1 padded with zeros at the end. + assert fused.shape == (2, 5, hidden) + assert torch.all(fused[1, 4:] == 0) # padding tokens are zero + + def test_forward_end_to_end_shape(self): + """Full forward pass through LlavaFusionAdapter returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = LlavaFusionAdapter(nn.Identity(), pg, image_token_id=_IMAGE_ID) + + bs, local_v, text_len, hidden = 1, 4, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _IMAGE_ID # one placeholder + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len = text_len - 1 + local_v * world_size = 5 + 8 = 13 + # padded to 14 (next multiple of 2), local = 7 + assert out.shape == (bs, 7, hidden) + + +# --------------------------------------------------------------------------- +# InternVLFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_CONTEXT_ID = 92546 # arbitrary IMG_CONTEXT token id for tests +_START_ID = 92545 +_END_ID = 92547 + + +def _make_internvl_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return InternVLFusionAdapter(nn.Identity(), pg, image_token_id=_CONTEXT_ID) + + +class TestInternVLFusionAdapter: + + def test_context_tokens_replaced_with_visual(self): + """IMG_CONTEXT positions must carry visual embeddings after splice.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 7, 4 + # Layout: [t0, START, ctx, ctx, ctx, END, t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _CONTEXT_ID + ids[0, 3] = _CONTEXT_ID + ids[0, 4] = _CONTEXT_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, 3, hidden) * 7.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 2:5], visual[0]) + + def test_sequence_length_preserved(self): + """Output length must equal input length (1-to-1 replacement).""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 2, 10, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 3:7] = _CONTEXT_ID # 4 context tokens per sample + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + + def test_boundary_tokens_preserved(self): + """IMG_START and IMG_END embeddings must be unchanged after splice.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 5, 4 + # [START, ctx, ctx, END, text] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _CONTEXT_ID + ids[0, 2] = _CONTEXT_ID + + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Position 0 (START) and 3 (END) must be unchanged. + assert torch.allclose(fused[0, 0], text[0, 0]) + assert torch.allclose(fused[0, 3], text[0, 3]) + + def test_no_context_tokens_returns_text_unchanged(self): + """When there are no IMG_CONTEXT tokens the output must equal text_embeds.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 2, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused, text) + + def test_multi_image_replacement(self): + """Two separate runs of context tokens correspond to two images.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 10, 4 + # Image 1: positions 1-2, Image 2: positions 6-7 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _CONTEXT_ID + ids[0, 2] = _CONTEXT_ID + ids[0, 6] = _CONTEXT_ID + ids[0, 7] = _CONTEXT_ID + + text = torch.zeros(bs, text_len, hidden) + # First 2 visual tokens = 1.0, next 2 = 2.0 + visual = torch.cat([torch.ones(bs, 2, hidden), torch.full((bs, 2, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused[0, 1:3], torch.ones(2, hidden)) + assert torch.allclose(fused[0, 6:8], torch.full((2, hidden), 2.0)) + + def test_forward_end_to_end_shape(self): + """Full forward pass returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = InternVLFusionAdapter(nn.Identity(), pg, image_token_id=_CONTEXT_ID) + + bs, local_v, text_len, hidden = 1, 3, 8, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2:5] = _CONTEXT_ID # 3 context tokens; local_v * world_size = 6 total + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len == text_len == 8 (length-preserving); padded to 8 (divisible by 2); local = 4 + assert out.shape == (bs, 4, hidden) + + +# --------------------------------------------------------------------------- +# Qwen2VLFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_VIS_START_ID = 151652 +_VIS_END_ID = 151653 + + +def _make_qwen2vl_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return Qwen2VLFusionAdapter(nn.Identity(), + pg, + vision_start_token_id=_VIS_START_ID, + vision_end_token_id=_VIS_END_ID) + + +class TestQwen2VLFusionAdapter: + + def test_inner_tokens_replaced_with_visual(self): + """Tokens between vision_start and vision_end must become visual embeddings.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 7, 4 + # [t0, t1, , pad, pad, , t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _VIS_START_ID + ids[0, 5] = _VIS_END_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 5.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 3:5], visual[0]) + + def test_sequence_length_preserved(self): + """Output length must equal input length (1-to-1 replacement).""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 2, 12, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 2] = _VIS_START_ID + ids[:, 8] = _VIS_END_ID # 5 inner placeholder tokens + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 5, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + + def test_boundary_tokens_preserved(self): + """vision_start and vision_end embeddings must be unchanged after splice.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 6, 4 + # [t0, , pad, pad, , t5] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _VIS_START_ID + ids[0, 4] = _VIS_END_ID + + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 1], text[0, 1]) # vision_start preserved + assert torch.allclose(fused[0, 4], text[0, 4]) # vision_end preserved + + def test_no_vision_tokens_returns_text_unchanged(self): + """When there are no vision_start/end tokens the output must equal text_embeds.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 2, 8, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused, text) + + def test_multi_image_replacement(self): + """Two vision blocks are handled independently.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 14, 4 + # Block 1: positions 1 (start) .. 4 (end), 2 inner tokens at 2-3 + # Block 2: positions 8 (start) .. 12 (end), 3 inner tokens at 9-11 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _VIS_START_ID + ids[0, 4] = _VIS_END_ID + ids[0, 8] = _VIS_START_ID + ids[0, 12] = _VIS_END_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.cat([torch.ones(bs, 2, hidden), torch.full((bs, 3, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused[0, 2:4], torch.ones(2, hidden)) + assert torch.allclose(fused[0, 9:12], torch.full((3, hidden), 2.0)) + + def test_forward_end_to_end_shape(self): + """Full forward pass returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = Qwen2VLFusionAdapter(nn.Identity(), + pg, + vision_start_token_id=_VIS_START_ID, + vision_end_token_id=_VIS_END_ID) + + bs, local_v, text_len, hidden = 1, 3, 10, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + # 6 inner placeholder tokens (local_v * world_size = 6) + ids[0, 1] = _VIS_START_ID + ids[0, 8] = _VIS_END_ID + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len == text_len == 10 (length-preserving); padded to 10; local = 5 + assert out.shape == (bs, 5, hidden) diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 6af894bf8e62..8ec69935a9c2 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -3,14 +3,21 @@ # DeepSpeed Team -from .logging import logger, log_dist +from .logging import logger, log_dist, log_dist_once, set_log_level_from_string from .comms_logging import get_caller_func #from .distributed import init_distributed from .init_on_device import OnDevice from .groups import * from .nvtx import instrument_w_nvtx # TODO: Move tensor fragment and mixed precision to zero utils -from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad +from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state -from .mixed_precision_linkage import link_hp_params +from .tensor_fragment import set_full_hp_param, set_full_hp_grad +from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad +from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state +from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state +from .tensor_fragment import safe_update_full_grad_vectorized +from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module, set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix +from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.runtime.dataloader import RepeatingLoader +from .numa import get_numactl_cmd diff --git a/deepspeed/utils/bwc.py b/deepspeed/utils/bwc.py new file mode 100644 index 000000000000..69fcc251a684 --- /dev/null +++ b/deepspeed/utils/bwc.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + + +def bwc_tensor_model_parallel_rank(mpu=None): + """Backwards-compatible way of querying the tensor model parallel rank from + an ``mpu`` object. + + *Tensor* model parallelism means that tensors are physically split across + processes. This contrasts with *pipeline* model parallelism, in which the + layers are partitioned but tensors left intact. + + The API for tensor model parallelism has changed across versions and this + helper provides a best-effort implementation across versions of ``mpu`` + objects. The preferred mechanism is + ``mpu.get_tensor_model_parallel_rank()``. + + This should "just work" with both Megatron-LM and DeepSpeed's pipeline + parallelism. + + Args: + mpu (model parallel unit, optional): The tensor model parallel rank. + If ``mpu=None``, returns 0. Defaults to ``None``. + + Returns: + int: the rank + """ + if mpu is None: + # No model parallelism in easy :) + return 0 + + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_rank() + elif hasattr(mpu, 'get_slice_parallel_rank'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_rank() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_rank() + + +def bwc_tensor_model_parallel_world_size(mpu=None): + """Backwards-compatible way of querying the tensor model parallel world size. + Similar to bwc_tensor_model_parallel_rank. + """ + if mpu is None: + return 1 + + if hasattr(mpu, 'get_tensor_model_parallel_world_size'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_world_size() + elif hasattr(mpu, 'get_slice_parallel_world_size'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_world_size() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_world_size() + + +def bwc_tensor_model_parallel_group(mpu=None): + """Backwards-compatible way of querying the tensor model parallel group. + Similar to bwc_tensor_model_parallel_rank. + """ + if mpu is None: + return None + + if hasattr(mpu, 'get_tensor_model_parallel_group'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_group() + elif hasattr(mpu, 'get_slice_parallel_group'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_group() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_group() + + +def bwc_pipeline_parallel_world_size(mpu=None): + """Backwards-compatible way of querying the pipeline parallel world size.""" + world_size = 1 + if mpu is not None: + if hasattr(mpu, 'get_pipeline_model_parallel_world_size'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + world_size = mpu.get_pipeline_model_parallel_world_size() + elif hasattr(mpu, 'get_pipe_parallel_world_size'): + # DeepSpeed Topology + world_size = mpu.get_pipe_parallel_world_size() + return world_size + + +def bwc_pipeline_parallel_group(mpu=None): + """Backwards-compatible way of querying the pipeline parallel group.""" + if mpu is None: + return None + if hasattr(mpu, 'get_pipeline_model_parallel_group'): + # Megatron + return mpu.get_pipeline_model_parallel_group() + elif hasattr(mpu, 'get_pipe_parallel_group'): + # DeepSpeed Topology + return mpu.get_pipe_parallel_group() + assert False, 'mpu does not support pipeline parallel group' diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py index 2400fa55b20e..4532ac34e7be 100644 --- a/deepspeed/utils/comms_logging.py +++ b/deepspeed/utils/comms_logging.py @@ -12,6 +12,12 @@ def get_caller_func(frame=3): return sys._getframe(frame).f_code.co_name +def print_rank_0(message): + import deepspeed.comm as dist + if dist.get_rank() == 0: + print(message) + + # Helper function to pretty-print message sizes def convert_size(size_bytes): if size_bytes == 0: @@ -38,7 +44,7 @@ def calc_bw_log(comm_op, size, duration): size *= n tput = (size / duration) busbw = (size / duration) * ((n - 1) / n) - elif comm_op == "all_reduce": + elif comm_op == "all_reduce" or comm_op == "all_reduce_coalesced" or comm_op == "inference_all_reduce": tput = (size * 2 / duration) busbw = (size / duration) * (2 * (n - 1) / n) elif comm_op == "send" or comm_op == "recv" or comm_op == "isend" or comm_op == "irecv" or comm_op == "broadcast" or comm_op == "reduce" or comm_op == "gather" or comm_op == "scatter" or comm_op == "barrier": @@ -96,7 +102,6 @@ def stop_profiling_op(self, op_name_list): # Add log entry def append(self, raw_name, record_name, latency, msg_size): - import deepspeed.comm as dist algbw, busbw = calc_bw_log(raw_name, msg_size, latency) if record_name in self.comms_dict.keys(): # If this comm_op has already been logged with this message size, just add to existing record @@ -114,22 +119,184 @@ def append(self, raw_name, record_name, latency, msg_size): # If verbose, print every comm op # TODO: Add to tensorboard if self.verbose: - n = dist.get_world_size() - log_str = f"rank={dist.get_rank()} | comm op: " + record_name + " | time (ms): {:.2f}".format(latency) - log_str += " | msg size: " + convert_size(msg_size) - log_str += " | algbw (Gbps): {:.2f} ".format(algbw) - log_str += " | busbw (Gbps): {:.2f} ".format(busbw) + log_str = f"comm op: {record_name} | time (ms): {latency:.2f} | msg size: {convert_size(msg_size)} | algbw (Gbps): {algbw:.2f} | busbw (Gbps): {busbw:.2f}" log_dist(log_str, [0]) + def get_raw_data(self): + """ + Get the raw communication data dictionary. + + Returns: + dict: Raw communication data in format {record_name: {msg_size: [count, [latencies], [algbws], [busbws]]}} + """ + return self.comms_dict.copy() + + def has_data(self): + """ + Check if any communication data has been logged. + + Returns: + bool: True if communication data exists, False otherwise + """ + return len(self.comms_dict) > 0 + + def reset_data(self): + """ + Clear all logged communication data. + """ + self.comms_dict.clear() + + def get_operation_names(self): + """ + Get list of all logged communication operation names. + + Returns: + list: List of operation names that have been logged + """ + return list(self.comms_dict.keys()) + + def get_total_operations(self): + """ + Get total number of communication operations logged across all types. + + Returns: + int: Total count of operations + """ + total = 0 + for record_name in self.comms_dict: + for msg_size in self.comms_dict[record_name]: + total += self.comms_dict[record_name][msg_size][0] # count is at index 0 + return total + + def get_operation_summary(self, operation_name): + """ + Get summary statistics for a specific operation type. + + Args: + operation_name (str): Name of the communication operation + + Returns: + dict: Summary statistics for the operation, or None if operation not found + """ + if operation_name not in self.comms_dict: + return None + + from deepspeed.utils.timer import trim_mean + + # Create a snapshot to avoid concurrent modification issues + op_data = self.comms_dict[operation_name].copy() + summary = {} + + for msg_size, vals in op_data.items(): + count = vals[0] + total_lat = sum(vals[1]) + avg_lat = trim_mean(vals[1], 0.1) + avg_algbw = trim_mean(vals[2], 0.1) + avg_busbw = trim_mean(vals[3], 0.1) + + summary[msg_size] = { + "count": count, + "total_latency_ms": total_lat, + "avg_latency_ms": avg_lat, + "tput_avg_gbps": avg_algbw, + "busbw_avg_gbps": avg_busbw, + "msg_size_bytes": msg_size, + "msg_size_str": convert_size(msg_size) + } + + return summary + # Print summary at end of iteration, epoch, or training - def log_all(self): + def log_all(self, print_log=True, show_straggler=False, return_dict=False): + """ + Print and/or return communication operation statistics. + + Args: + print_log (bool, optional): Whether to print the summary to console. Defaults to True. + show_straggler (bool, optional): Whether to include straggler effect analysis. Defaults to False. + return_dict (bool, optional): Whether to return statistics as a dictionary. Defaults to False. + + Returns: + dict or None: If return_dict=True, returns a comprehensive dictionary with the following structure: + { + "summary": { + "operation_name": { + message_size_bytes: { + "count": int, # Number of operations with this message size + "total_latency_ms": float, # Sum of all latencies for this message size + "avg_latency_ms": float, # Average latency (outliers trimmed) + "tput_avg_gbps": float, # Average algorithmic bandwidth in Gbps + "busbw_avg_gbps": float, # Average bus bandwidth in Gbps + "msg_size_bytes": int, # Message size in bytes + "msg_size_str": str # Human-readable message size (e.g., "678.86 MB") + } + } + }, + "straggler_analysis": { # Only present if show_straggler=True + "operation_name": { + message_size_bytes: { + "count": int, # Number of operations + "total_comm_lat_ms": float, # Total communication latency (min across ranks) + "total_straggler_ms": float, # Total straggler effect + "avg_comm_lat_ms": float, # Average communication latency + "avg_straggler_ms": float, # Average straggler effect + "msg_size_bytes": int, # Message size in bytes + "msg_size_str": str # Human-readable message size + } + } + } if show_straggler else None, + "metadata": { + "world_size": int, # Number of processes in distributed setup + "rank": int, # Current process rank + "timestamp": str # ISO format timestamp when log_all was called + } + } + + Returns None if return_dict=False. + + Note: + - Statistics use trimmed mean (10% trimmed from both ends) to remove outliers + - Straggler analysis requires distributed communication and may impact performance + - All bandwidth values are in Gbps (Gigabits per second) + - Latency values are in milliseconds + """ + import torch from deepspeed.utils.timer import trim_mean - print( - f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" - ) - for record_name in self.comms_dict.keys(): - print(record_name) - for msg_size, vals in sorted(self.comms_dict[record_name].items()): + import deepspeed.comm as dist + from deepspeed.comm.reduce_op import ReduceOp + from deepspeed.accelerator import get_accelerator + from datetime import datetime + + # Create a snapshot of the dictionary to avoid concurrent modification issues + # This prevents "dictionary changed size during iteration" errors when + # communication operations are happening in other threads + comms_dict_snapshot = self.comms_dict.copy() + + # Initialize return dictionary structure + result_dict = { + "summary": {}, + "straggler_analysis": None, + "metadata": { + "world_size": dist.get_world_size() if dist.is_initialized() else 1, + "rank": dist.get_rank() if dist.is_initialized() else 0, + "timestamp": datetime.now().isoformat() + } + } if return_dict else None + + if print_log: + print( + f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" + ) + + for record_name in comms_dict_snapshot.keys(): + if print_log: + print(record_name) + + # Initialize operation entry in result dict + if return_dict: + result_dict["summary"][record_name] = {} + + for msg_size, vals in sorted(comms_dict_snapshot[record_name].items()): # vals[0] is the count for each msg size count = vals[0] # vals[1] is a list of latency records for each msg size @@ -139,6 +306,73 @@ def log_all(self): avg_lat = trim_mean(vals[1], 0.1) avg_algbw = trim_mean(vals[2], 0.1) avg_busbw = trim_mean(vals[3], 0.1) + + # Store data in result dictionary + if return_dict: + result_dict["summary"][record_name][msg_size] = { + "count": count, + "total_latency_ms": total_lat, + "avg_latency_ms": avg_lat, + "tput_avg_gbps": avg_algbw, + "busbw_avg_gbps": avg_busbw, + "msg_size_bytes": msg_size, + "msg_size_str": convert_size(msg_size) + } + + if print_log: + print( + f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" + ) + + if show_straggler: + if return_dict: + result_dict["straggler_analysis"] = {} + + if print_log: + print("_______________________________") + print("Breakdown with straggler effect") + print("-------------------------------") print( - f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" + f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}" ) + + device = get_accelerator().current_device_name() + for record_name in comms_dict_snapshot.keys(): + if print_log: + print(record_name) + + # Initialize operation entry in straggler dict + if return_dict: + result_dict["straggler_analysis"][record_name] = {} + + for msg_size, vals in sorted(comms_dict_snapshot[record_name].items()): + # vals[0] is the count for each msg size + count = vals[0] + # vals[1] is a list of latency records for each msg size + lats = torch.tensor(vals[1], device=device) + min_lats = torch.tensor(vals[1], device=device) + dist.all_reduce(min_lats, op=ReduceOp.MIN) + total_lat = min_lats.sum().item() + total_straggler = (lats - min_lats).sum().item() + avg_lat = trim_mean(min_lats.tolist(), 0.1) + avg_straggler = trim_mean((lats - min_lats).tolist(), 0.1) + + # Store straggler data in result dictionary + if return_dict: + result_dict["straggler_analysis"][record_name][msg_size] = { + "count": count, + "total_comm_lat_ms": total_lat, + "total_straggler_ms": total_straggler, + "avg_comm_lat_ms": avg_lat, + "avg_straggler_ms": avg_straggler, + "msg_size_bytes": msg_size, + "msg_size_str": convert_size(msg_size) + } + + if print_log: + print( + f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}" + ) + + # Return the dictionary if requested + return result_dict if return_dict else None diff --git a/deepspeed/utils/config.py b/deepspeed/utils/config.py new file mode 100644 index 000000000000..15f37ca7d874 --- /dev/null +++ b/deepspeed/utils/config.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + +######################################### +# Timers +######################################### +# Timers. By default, timers are enabled. +# Users can configure in ds_config.json as below example: +TIMERS_FORMAT = ''' +Timers should be enabled as: +"timers": { + "throughput": { + "enabled": true, + "synchronized": true + } +} +''' + +TIMERS = "timers" +TIMERS_THROUGHPUT = "throughput" + + +def get_timers_config(param_dict): + if param_dict and TIMERS in param_dict and TIMERS_THROUGHPUT in param_dict[TIMERS]: + timers_config_dict = param_dict[TIMERS][TIMERS_THROUGHPUT] + else: + timers_config_dict = {} + return DeepSpeedThroughputTimerConfig(**timers_config_dict) + + +class DeepSpeedThroughputTimerConfig(DeepSpeedConfigModel): + """ Configure throughput timers """ + + enabled: bool = True + """ Turn on/off throughput timers """ + + synchronized: bool = True + """ Whether to synchronize a device when measuring the time. + Synchronizing a device is required to produce the most accurate timer measurements. + However, this comes at the expense of performance degradation. The CPU timer provides + sufficient accuracy in many cases. + """ diff --git a/deepspeed/utils/debug.py b/deepspeed/utils/debug.py index b693915e531b..f644562deee9 100644 --- a/deepspeed/utils/debug.py +++ b/deepspeed/utils/debug.py @@ -3,6 +3,8 @@ # DeepSpeed Team +import deepspeed.comm as dist + # For lazy import with printflock() fcntl = None @@ -11,6 +13,13 @@ param_names = {} +def debug_clear_module_and_param_names(): + global module_names + global param_names + module_names = {} + param_names = {} + + def debug_extract_module_and_param_names(model): # extract the fully qualified names as soon as the model is acquired global module_names @@ -28,7 +37,7 @@ def debug_module2name(module): def debug_module2name_id(module): - return f"name={debug_module2name(module)} id={module.id}" + return f"name={debug_module2name(module)}" def debug_module2name_class(module): @@ -42,24 +51,38 @@ def debug_param2name(param): return "unknown" +def ds_id(param): + if hasattr(param, "ds_id"): + return param.ds_id + else: + return "none" + + +def ds_shape(param): + if hasattr(param, "ds_shape"): + return param.ds_shape + else: + return param.shape + + def debug_param2name_id(param): - return f"name={debug_param2name(param)} id={param.ds_id}" + return f"name={debug_param2name(param)} id={ds_id(param)}" def debug_param2name_id_shape(param): - return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape}" + return f"name={debug_param2name(param)} id={ds_id(param)} shape={ds_shape(param)}" def debug_param2name_id_shape_device(param): - return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} device={param.device}" + return f"name={debug_param2name(param)} id={ds_id(param)} shape={ds_shape(param)} device={param.device}" def debug_param2name_id_numel(param): - return f"name={debug_param2name(param)} id={param.ds_id} numel={param.numel()}" + return f"name={debug_param2name(param)} id={ds_id(param)} numel={param.numel()}" def debug_param2name_id_shape_status(param): - return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} status={param.ds_status}" + return f"name={debug_param2name(param)} id={ds_id(param)} shape={ds_shape(param)} status={param.ds_status}" def printflock(*msgs): @@ -83,7 +106,7 @@ def print_rank_0(message, debug=False, force=False): """ global fcntl - if fcntl == None: + if fcntl is None: import fcntl with open(__file__, "r") as fh: @@ -144,3 +167,21 @@ def _print_bwd_tensors(grad_fn): if hasattr(tensor, 'grad_fn'): _print_bwd_tensors(tensor.grad_fn) + + +def print_rank(*msg, force=False): + """print something on all global ranks with [rank] prefix. + """ + if not force: + return + global_rank = dist.get_rank() + print(f"[{global_rank}]", *msg) + + +def print_rank0(*msg, force=False): + """print something only on rank 0""" + if not force: + return + global_rank = dist.get_rank() + if global_rank == 0: + print(f"[{global_rank}]", *msg) diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index dc58bb3c780f..a6f0a7228977 100644 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -26,20 +26,31 @@ """ from deepspeed import comm as dist - from deepspeed.utils import log_dist +from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size from deepspeed.utils.exceptions import DeprecatedException +from deepspeed.accelerator import get_accelerator # Expert parallel group that the current rank belongs to. _EXPERT_PARALLEL_GROUP = {} +# Mapping of expert parallel group to ranks +_EXPERT_PARALLEL_GROUP_RANKS = {} # Expert data parallel group that the current rank belongs to. _EXPERT_DATA_PARALLEL_GROUP = {} +# Mapping of expert data parallel group to ranks +_EXPERT_DATA_PARALLEL_GROUP_RANKS = {} # dist world group needs to be cloned for some cases _WORLD_GROUP = None +# ZeRO parameter partitioning group that the current rank belongs to. +_ZERO_PARAM_INTRA_PARALLEL_GROUP = None # global object to maintain mpu object if passed by a Megatron client mpu = None # global object that stores tensor parallel world size for experts expert_tensor_parallel_world_size = 1 +# All to All quantized graident communication groups +_ALL_TO_ALL_GROUP = {} + +mesh_device = None # Deprecated groups initialize function. @@ -55,6 +66,127 @@ def _ensure_divisibility(numerator, denominator): assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator) +# ======== Start: Tensor Parallel Group Attributes ======== + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None + +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None + + +def _init_tp_mesh_device(tensor_model_parallel_size=1, data_parallel_size=None): + """Initialize model data parallel groups.""" + + global _DATA_PARALLEL_GROUP + global _MODEL_PARALLEL_GROUP + global _TENSOR_MODEL_PARALLEL_GROUP + + if _TENSOR_MODEL_PARALLEL_GROUP is not None: + return + + if data_parallel_size is None: + data_parallel_size = dist.get_world_size() // tensor_model_parallel_size + + mesh_device = dist.initialize_mesh_device((data_parallel_size, tensor_model_parallel_size), + ("data_parallel", "tensor_parallel")) + _TENSOR_MODEL_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="tensor_parallel") + _DATA_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="data_parallel") + + # They are always equal only in 2D (DP + TP) parallelism. + # _MODEL_PARALLEL_GROUP is assigned the same value as _TENSOR_MODEL_PARALLEL_GROUP + # to allow for future potential changes. + _MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP + + return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ + 'intra_layer_model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + + assert _MODEL_PARALLEL_GROUP is not None, \ + 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_model_parallel_world_size(): + return get_tensor_model_parallel_world_size() + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return dist.get_rank(group=get_tensor_model_parallel_group()) + + +def get_model_parallel_rank(): + return get_tensor_model_parallel_rank() + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return dist.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return dist.get_rank(group=get_data_parallel_group()) + + +# ======== End: Tensor Parallel Group Attributes ======== + + # Not currently used. Helper function to create a model (tensor) parallel group. def _create_model_parallel(model_parallel_size_): """ @@ -105,7 +237,7 @@ def _create_model_parallel(model_parallel_size_): return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP -def _create_expert_and_data_parallel(expert_parallel_size_): +def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expert_parallel_=False): """ Create expert and data parallel groups. @@ -117,43 +249,74 @@ def _create_expert_and_data_parallel(expert_parallel_size_): expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE + use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology """ assert dist.is_initialized() log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0]) world_size = dist.get_world_size() + pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) rank = dist.get_rank() - _ensure_divisibility(world_size, expert_parallel_size_) + pp_stride = world_size // pp_world_size + _ensure_divisibility(pp_stride, expert_parallel_size_) group_name = f"ep_size_{expert_parallel_size_}" # Build the expert data parallel groups. global _EXPERT_DATA_PARALLEL_GROUP + global _EXPERT_DATA_PARALLEL_GROUP_RANKS + + ep_stride = pp_stride // expert_parallel_size_ # Only create group if it does not already exist if group_name not in _EXPERT_DATA_PARALLEL_GROUP: - for i in range(expert_parallel_size_): - ranks = range(i, world_size, expert_parallel_size_) - group = dist.new_group(ranks) - log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if i == (rank % expert_parallel_size_): - _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + for pp_stage_start in range(0, world_size, pp_stride): + for i in range(expert_parallel_size_): + if use_data_before_expert_parallel_: + ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride) + else: + ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_) + group = dist.new_group(ranks) + log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', + [0]) + if rank in ranks: + _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + _EXPERT_DATA_PARALLEL_GROUP_RANKS[group_name] = ranks # Build the expert parallel groups. global _EXPERT_PARALLEL_GROUP + global _EXPERT_PARALLEL_GROUP_RANKS # Only create group if it does not already exist if group_name not in _EXPERT_PARALLEL_GROUP: - for i in range(world_size // expert_parallel_size_): - ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) - group = dist.new_group(ranks) - log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if i == (rank // expert_parallel_size_): - _EXPERT_PARALLEL_GROUP[group_name] = group - - -def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel_size_): + if use_data_before_expert_parallel_: + for pp_stage_start in range(0, world_size, pp_stride): + for i in range(ep_stride): + ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride) + group = dist.new_group(ranks) + log_dist( + f'creating expert parallel process group named {group_name} ' + f'with ranks: {list(ranks)}', [0]) + if rank in ranks: + _EXPERT_PARALLEL_GROUP[group_name] = group + _EXPERT_PARALLEL_GROUP_RANKS[group_name] = ranks + else: + for i in range(world_size // expert_parallel_size_): + ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) + group = dist.new_group(ranks) + log_dist(f'creating expert parallel process group named {group_name} ' + f'with ranks: {list(ranks)}', [0]) + if rank in ranks: + _EXPERT_PARALLEL_GROUP[group_name] = group + _EXPERT_PARALLEL_GROUP_RANKS[group_name] = ranks + + +def _get_expert_parallel_ranks(world_size, + tensor_parallel_size_, + expert_parallel_size_, + pipeline_parallel_size_=1, + use_data_before_expert_parallel_=False): """Generate expert parallel and expert data parallel group ranks list. Example - E + M + D parallel @@ -167,21 +330,40 @@ def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel Args: world_size (int): Distributed world size. - model_parallel_size_ (int): Model parallel group size. + tensor_parallel_size_ (int): Tensor parallel group size. expert_parallel_size_ (int): Expert parallel group size. - + pipeline_parallel_size_ (int): Pipeline parallel group size + use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology Returns: Expert parallel group ranks and Expert data parallel group ranks list. """ - _ensure_divisibility(world_size, model_parallel_size_) - dp_world_size = world_size // model_parallel_size_ + _ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_) + dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_) _ensure_divisibility(dp_world_size, expert_parallel_size_) # Generate data parallel groups data_parallel_groups = [] - dp_group_size = model_parallel_size_ - for i in range(dp_group_size): - data_parallel_groups.append(list(range(i, world_size, dp_group_size))) + dp_group_size = tensor_parallel_size_ + pp_stride = world_size // pipeline_parallel_size_ + + if use_data_before_expert_parallel_: + dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_ + for pp_stage_start in range(0, world_size, pp_stride): + pp_stage_next = pp_stage_start + pp_stride + for i in range(dp_group_size): + data_parallel_groups.append(list()) + for ds in range(dp_stride): + # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30] + # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31] + data_parallel_groups[-1].extend( + list( + range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next, + dp_stride * tensor_parallel_size_))) + else: + for pp_stage_start in range(0, world_size, pp_stride): + pp_stage_next = pp_stage_start + pp_stride + for i in range(dp_group_size): + data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size))) expert_parallel_groups = [] expert_data_parallel_groups = [] @@ -199,7 +381,7 @@ def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel return expert_parallel_groups, expert_data_parallel_groups -def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu): +def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False): """ Create expert and data parallel groups based on MPU (model parallel) group. @@ -215,28 +397,26 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu): expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] """ assert dist.is_initialized(), "dist is not initialized" - model_parallel_size_ = mpu.get_model_parallel_world_size() + tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu) global expert_tensor_parallel_world_size - expert_tensor_parallel_world_size = model_parallel_size_ + expert_tensor_parallel_world_size = tensor_parallel_size_ world_size = dist.get_world_size() rank = dist.get_rank() - dp_world_size = mpu.get_data_parallel_world_size() - dp_rank = mpu.get_data_parallel_rank() + dp_world_size = _get_data_parallel_world_size() + pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) - _ensure_divisibility(world_size, model_parallel_size_) + _ensure_divisibility(world_size, tensor_parallel_size_) _ensure_divisibility(dp_world_size, expert_parallel_size_) log_dist( - f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}", - [0]) + f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, " + f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, " + f"world size {world_size}, dp world size {dp_world_size}", [0]) global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP - - # Get world size and rank. Ensure some consistencies. - _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group() - _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group() + global _EXPERT_PARALLEL_GROUP_RANKS, _EXPERT_DATA_PARALLEL_GROUP_RANKS group_name = f"ep_size_{expert_parallel_size_}" @@ -244,16 +424,18 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu): # Need to check conditions outside the group creation loop because of the way torch.dist group creation works if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP: expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks( - world_size, model_parallel_size_, expert_parallel_size_) + world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_) for ranks in expert_parallel_groups: group = dist.new_group(ranks) if rank in list(ranks): _EXPERT_PARALLEL_GROUP[group_name] = group + _EXPERT_PARALLEL_GROUP_RANKS[group_name] = ranks for ranks in expert_data_parallel_groups: group = dist.new_group(ranks) if rank in list(ranks): _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + _EXPERT_DATA_PARALLEL_GROUP_RANKS[group_name] = ranks def _get_max_expert_size(): @@ -284,6 +466,13 @@ def _get_expert_parallel_group(group_name): return _EXPERT_PARALLEL_GROUP[group_name] +def _get_expert_parallel_group_ranks(group_name): + """Get the ranks of the expert parallel group the caller rank belongs to.""" + assert group_name in _EXPERT_PARALLEL_GROUP_RANKS, \ + 'expert parallel group is not initialized' + return _EXPERT_PARALLEL_GROUP_RANKS[group_name] + + def _get_expert_parallel_group_dict(): """Get the expert parallel group dict.""" return _EXPERT_PARALLEL_GROUP @@ -296,6 +485,13 @@ def _get_expert_data_parallel_group(group_name): return _EXPERT_DATA_PARALLEL_GROUP[group_name] +def _get_expert_data_parallel_group_ranks(group_name): + """Get the ranks of the expert data parallel group the caller rank belongs to.""" + assert group_name in _EXPERT_DATA_PARALLEL_GROUP_RANKS, \ + 'expert data parallel group is not initialized' + return _EXPERT_DATA_PARALLEL_GROUP_RANKS[group_name] + + def _get_expert_data_parallel_group_dict(): """Get the expert data parallel group dict.""" return _EXPERT_DATA_PARALLEL_GROUP @@ -316,43 +512,94 @@ def _clone_world_group(): return _WORLD_GROUP +def _get_local_all_to_all_group(): + assert dist.is_initialized(), 'dist is not initialized' + global _ALL_TO_ALL_GROUP + device_per_node = get_accelerator().device_count() + num_local = dist.get_world_size() // device_per_node + if num_local == 0 and dist.get_world_size() > 0: + assert dist.get_world_size() >= 1, 'num_gpus must >=1, cannot initialize All-To-All' + cur_rank = [] + for i in range(dist.get_world_size()): + cur_rank.append(i) + _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=cur_rank) + elif num_local == 1: + assert dist.get_world_size( + ) == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All' + _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=[i for i in range(device_per_node)]) + else: + assert dist.get_world_size() > device_per_node, 'num_nodes<2 cannot initialize All-To-All' + for i in range(num_local): + local_rank = [j + device_per_node * i for j in range(device_per_node)] + _ALL_TO_ALL_GROUP[f"local_{i}"] = dist.new_group(ranks=local_rank) + + for i in range(device_per_node): + cur_rank = [] + for j in range(num_local): + cur_rank.append(i + j * device_per_node) + _ALL_TO_ALL_GROUP[f"global_{i}"] = dist.new_group(ranks=cur_rank) + return _ALL_TO_ALL_GROUP + + def _get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" - assert dist.is_initialized(), \ - 'dist is not initialized' + assert dist.is_initialized(), 'dist is not initialized' global mpu + if mesh_device is not None: + return mesh_device.get_group(mesh_dim="data_parallel") if mpu is not None: - return mpu.get_data_parallel_group() + if hasattr(mpu, 'initialize_sequence_parallel'): + return None + else: + return mpu.get_data_parallel_group() + # Return the clone of dist world group return _clone_world_group() +def _get_data_parallel_group_ranks(): + """Get the ranks of data parallel group the caller rank belongs to.""" + assert dist.is_initialized(), \ + 'dist is not initialized' + global mpu + if mpu is not None: + return mpu.get_data_parallel_group_ranks() + # Return all ranks + return range(dist.get_world_size()) + + def _get_broadcast_src_rank(): - return dist.get_global_rank(_get_data_parallel_group(), 0) + assert dist.is_initialized(), 'dist is not initialized' + return dist.get_global_rank(_get_sequence_data_parallel_group(), 0) def _get_expert_broadcast_src_rank(group_name): + assert dist.is_initialized(), 'dist is not initialized' return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0) def _get_expert_parallel_world_size(group_name): """Return world size for the expert parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' return dist.get_world_size(group=_get_expert_parallel_group(group_name)) def _get_expert_data_parallel_world_size(group_name): """Return world size for the expert data parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' return dist.get_world_size(group=_get_expert_data_parallel_group(group_name)) def _get_expert_parallel_rank(group_name): """Return my rank for the expert parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' return dist.get_rank(group=_get_expert_parallel_group(group_name)) def _get_expert_parallel_src_rank(group_name): """Calculate the global rank corresponding to a local rank zero in the expert parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' global_rank = dist.get_rank() local_world_size = _get_expert_parallel_world_size(group_name) return (global_rank // local_world_size) * local_world_size @@ -360,33 +607,153 @@ def _get_expert_parallel_src_rank(group_name): def _get_expert_data_parallel_rank(group_name): """Return my rank for the expert data parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' return dist.get_rank(group=_get_expert_data_parallel_group(group_name)) def _get_data_parallel_world_size(): """Return world size for the data parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' + if mesh_device is not None: + return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel")) global mpu if mpu is not None: - return mpu.get_data_parallel_world_size() + if hasattr(mpu, 'initialize_sequence_parallel'): + return None + else: + return mpu.get_data_parallel_world_size() return dist.get_world_size(group=_get_data_parallel_group()) def _get_model_parallel_world_size(): """Return world size for the model parallel group.""" global mpu - if mpu is not None: - return mpu.get_model_parallel_world_size() - return 1 + if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'): + return 1 + return mpu.get_model_parallel_world_size() def _get_data_parallel_rank(): """Return my rank for the data parallel group.""" - global mpu - if mpu is not None: - return mpu.get_data_parallel_rank() + assert dist.is_initialized(), 'dist is not initialized' return dist.get_rank(group=_get_data_parallel_group()) +def _get_sequence_parallel_world_size(): + """Return world size for the model parallel group.""" + assert dist.is_initialized(), 'dist is not initialized' + global mpu + if mesh_device is not None: + return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel")) + if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'): + return mpu.get_sequence_parallel_world_size() + return 1 + + +def _get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'): + return mpu.get_sequence_parallel_rank() + if mesh_device is not None: + return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel")) + return 0 + + +def _get_sequence_parallel_group(): + global mpu + if mpu is None or not hasattr(mpu, 'get_sequence_parallel_group'): + if mesh_device is None: + raise KeyError("No sequence parallel group found") + return mesh_device.get_group(mesh_dim="sequence_parallel") + return mpu.get_sequence_parallel_group() + + +def _get_sequence_data_parallel_world_size(): + """Return world size for the model parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_world_size'): + return mpu.get_sequence_data_parallel_world_size() + return _get_data_parallel_world_size() + + +def _get_sequence_data_parallel_rank(): + """Return my rank for the data parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_rank'): + return mpu.get_sequence_data_parallel_rank() + return _get_data_parallel_rank() + + +def _get_sequence_data_parallel_group(): + global mpu + # When sequence parallelism is enabled, the process group for zero sharding and + # gradient allreduce must be across both dimensions of data and sequence parallelism. + if mpu is not None and hasattr(mpu, 'get_sequence_data_parallel_group'): + return mpu.get_sequence_data_parallel_group() + return _get_data_parallel_group() + + def _get_expert_model_parallel_world_size(): global expert_tensor_parallel_world_size return expert_tensor_parallel_world_size + + +def _create_zero_param_parallel_group(group_size): + """ + Create parameter partitioning group within ZeRO data parallel groups. + + Example - ZP + D parallel + world_size = 16 + zero_hpz_partition_size = 2 # number of ranks with replicated params (dual partitioning) + zero_param_intra_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - segmented (subgroup) with rep partition + data_parallel_group = [0,1,...,15] - all reduce is on ZeRO model + """ + assert dist.is_initialized() + global _ZERO_PARAM_INTRA_PARALLEL_GROUP + # Only create group if it does not already exist + assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is None, \ + 'ZeRO parameter intra parallel group is already initialized' + + world_size = dist.get_world_size() + rank = dist.get_rank() + + zero_param_parallel_size_ = min(group_size, world_size) + _ensure_divisibility(world_size, zero_param_parallel_size_) + + # Build the ZeRO param intra parallel groups. + for i in range(world_size // zero_param_parallel_size_): + ranks = range(i * zero_param_parallel_size_, (i + 1) * zero_param_parallel_size_) + group = dist.new_group(ranks) + if i == (rank // zero_param_parallel_size_): + _ZERO_PARAM_INTRA_PARALLEL_GROUP = group + + +def _get_zero_param_intra_parallel_group(): + """Get the ZeRO parameter partitioning intra parallel group the caller rank belongs to.""" + #assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is not None, \ + # 'ZeRO parameter partitioning group is not initialized' + #TODO: Add warning + return _ZERO_PARAM_INTRA_PARALLEL_GROUP + + +def _zero_param_parallel_is_initialized(): + """Check if ZeRO data parallel with parameter partititioning groups are initialized.""" + ###TODO: assert that MPU is not set + if _ZERO_PARAM_INTRA_PARALLEL_GROUP is None and _DATA_PARALLEL_GROUP is None: + return False + + +def _get_zero_param_intra_parallel_rank_in_mygroup(): + """Return my rank for the ZeRO parameter inter parallel group.""" + return dist.get_rank(group=_get_zero_param_intra_parallel_group()) + + +def _get_zero_param_intra_parallel_group_world_size(): + """Return world size for the ZeRO parameter parallel group.""" + return dist.get_world_size(group=_get_zero_param_intra_parallel_group()) + + +def _get_zero_param_intra_parallel_group_ranks(): + """Return all ranks for the ZeRO parameter intra parallel group.""" + return dist.get_all_ranks_from_group(group=_get_zero_param_intra_parallel_group()) diff --git a/deepspeed/utils/logging.py b/deepspeed/utils/logging.py index 1e62d96e1032..8e3a8b6d5a5d 100644 --- a/deepspeed/utils/logging.py +++ b/deepspeed/utils/logging.py @@ -7,6 +7,8 @@ import logging import sys import os +import torch +from deepspeed.utils.torch import required_torch_version log_levels = { "debug": logging.DEBUG, @@ -20,7 +22,7 @@ class LoggerFactory: @staticmethod - def create_logger(name=None, level=logging.INFO): + def create_logger(name=None, level=logging.WARNING): """create a logger Args: @@ -44,10 +46,19 @@ def create_logger(name=None, level=logging.INFO): ch.setLevel(level) ch.setFormatter(formatter) logger_.addHandler(ch) + if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1": + excluded_set = { + item.strip() + for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",") + } + ignore_set = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'} - excluded_set + for method in ignore_set: + original_logger = getattr(logger_, method) + torch._dynamo.config.ignore_logger_methods.add(original_logger) return logger_ -logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO) +logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.WARNING) @functools.lru_cache(None) @@ -72,18 +83,20 @@ def print_configuration(args, name): logger.info(" {} {} {}".format(arg, dots, getattr(args, arg))) -def log_dist(message, ranks=None, level=logging.INFO): +def get_dist_msg(message, ranks=None): from deepspeed import comm as dist - """Log message when one of following condition meets + """Return a message with rank prefix when one of following conditions is met: - + not dist.is_initialized() - + dist.get_rank() in ranks if ranks is not None or ranks = [-1] + + not dist.is_initialized() + + dist.get_rank() in ranks if ranks is not None or ranks = [-1] + + If neither is met, `None` is returned. + + Example: "hello" => "[Rank 0] hello" Args: message (str) ranks (list) - level (int) - """ should_log = not dist.is_initialized() ranks = ranks or [] @@ -92,10 +105,54 @@ def log_dist(message, ranks=None, level=logging.INFO): should_log = ranks[0] == -1 should_log = should_log or (my_rank in set(ranks)) if should_log: - final_message = "[Rank {}] {}".format(my_rank, message) + return "[Rank {}] {}".format(my_rank, message) + else: + return None + + +def log_dist(message, ranks=None, level=logging.INFO): + """Log message when get_dist_msg() deems it should be logged, see its docstring for details. + + Args: + message (str) + ranks (list) + level (int) + """ + final_message = get_dist_msg(message, ranks) + if final_message is not None: logger.log(level, final_message) +def print_dist(message, ranks=None): + """print message when get_dist_msg() deems it should be logged, see its docstring for details. + + Use this function instead of `log_dist` when the log level shouldn't impact whether the message should be printed or not. + + Args: + message (str) + ranks (list) + """ + final_message = get_dist_msg(message, ranks) + if final_message is not None: + print(final_message) + + +@functools.lru_cache(None) +def _log_dist_once_cached(message, ranks_key, level): + ranks_arg = list(ranks_key) if ranks_key is not None else None + log_dist(message, ranks=ranks_arg, level=level) + + +def log_dist_once(message, ranks=None, level=logging.INFO): + # Identical to `log_dist`, but will emit each unique message only once per process. + # ranks is a list which is unhashable, so convert to tuple for caching + ranks_key = tuple(ranks) if ranks is not None else None + _log_dist_once_cached(message, ranks_key, level) + + +logger.log_dist_once = log_dist_once + + def print_json_dist(message, ranks=None, path=None): from deepspeed import comm as dist """Print message when one of following condition meets @@ -123,6 +180,31 @@ def print_json_dist(message, ranks=None, path=None): os.fsync(outfile) +def get_log_level_from_string(log_level_str): + """converts a log level string into its numerical equivalent. e.g. "info" => `logging.INFO` + """ + log_level_str = log_level_str.lower() + if log_level_str not in log_levels: + raise ValueError( + f"{log_level_str} is not one of the valid logging levels. Valid log levels are {log_levels.keys()}.") + return log_levels[log_level_str] + + +def set_log_level_from_string(log_level_str, custom_logger=None): + """Sets a log level in the passed `logger` and its handlers from string. e.g. "info" => `logging.INFO` + + Args: + log_level_str: one of 'debug', 'info', 'warning', 'error', 'critical' + custom_logger: if `None` will use the default `logger` object + """ + log_level = get_log_level_from_string(log_level_str) + if custom_logger is None: + custom_logger = logger + custom_logger.setLevel(log_level) + for handler in custom_logger.handlers: + handler.setLevel(log_level) + + def get_current_level(): """ Return logger's current log level @@ -145,8 +227,5 @@ def should_log_le(max_log_level_str): if not isinstance(max_log_level_str, str): raise ValueError(f"{max_log_level_str} is not a string") - max_log_level_str = max_log_level_str.lower() - if max_log_level_str not in log_levels: - raise ValueError(f"{max_log_level_str} is not one of the `logging` levels") - - return get_current_level() <= log_levels[max_log_level_str] + max_log_level = get_log_level_from_string(max_log_level_str) + return get_current_level() <= max_log_level diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py index ecc29e930954..c97515ca8fef 100644 --- a/deepspeed/utils/mixed_precision_linkage.py +++ b/deepspeed/utils/mixed_precision_linkage.py @@ -5,16 +5,23 @@ import types from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping +from deepspeed.utils import set_full_hp_param, set_full_hp_grad def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group): + param_group_index, partition_start, partition_size, dp_group): local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group) for lp_param, lp_start in local_lp_param_and_offset: lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, param_group_index, - partition_start, partition_size, partition_optimizer_state) + partition_start, partition_size) + + +def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state): + for lp in lp_param_list: + if lp._hp_mapping is not None: + lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition]) def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group): @@ -27,6 +34,8 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr lp_param._dp_group = dp_group lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param) + lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param) + lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param) # lp_param overlaps with partition if both are true # 1) current_offset < partition_end, diff --git a/deepspeed/utils/numa.py b/deepspeed/utils/numa.py new file mode 100644 index 000000000000..ded088c511fb --- /dev/null +++ b/deepspeed/utils/numa.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +# return a list of list for cores to numa mapping +# [ +# [ cores for numa 0 ] +# [ cores belong to numa 1 ] +# ... +# ] + +import os +import psutil +import shutil +import subprocess + + +# return a list of list for cores to numa mapping +# [ +# [ cores for numa 0 ] +# [ cores belong to numa 1 ] +# ... +# ] +def get_numa_cores(): + ret = [] + try: + output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8") + except Exception: + return [] + lines = output.split('\n') + for line in lines: + if line.startswith('available:'): + num_numas = int(line.split(' ')[1]) + break + for numa in range(num_numas): + for line in lines: + if line.startswith(f'node {numa} cpus:'): + cores = line.split(' ')[3:] + ret.append([int(core) for core in cores]) + return ret + + +def check_for_numactl_pkg(): + libs = dict( + dpkg=["-l", "numactl", "apt"], + pacman=["-Q", "numactl", "pacman"], + rpm=["-q", "numactl", "yum"], + ) + + found = False + for pkgmgr, data in libs.items(): + flag, lib, tool = data + path = shutil.which(pkgmgr) + if path is not None: + cmd = [pkgmgr, flag, lib] + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.wait() == 0: + found = True + else: + print(f"please install the {lib} package with {tool}") + break + return found + + +def parse_range(rng): + try: + value = int(rng) + return range(value, value + 1) + except ValueError: + # value is not a single number + parts = rng.split('-') + if len(parts) != 2: + raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" % + (rng, )) + start = int(parts[0]) + end = int(parts[1]) + if start > end: + raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, )) + return range(start, end + 1) + + +# parse comma and dash separated range list into list +# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6] +# rules: +# 1. Range list number be comma separated, each item are either a single number, +# or a range marked by two numbers (both number are included in the range) +# 2. Sub ranges must be in ascend order and not overlap with each other +# 3. No space in the range expression +def parse_range_list(range_str): + number_list = [] + last = -1 + range_list = range_str.split(',') + for sub_range in range_list: + sub_number_list = parse_range(sub_range) + if sub_number_list[0] <= last: + raise ValueError( + "Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" % + (range_str, )) + last = sub_number_list[-1] + number_list.extend(sub_number_list) + return number_list + + +def get_numactl_cmd(bind_core_list, num_local_procs, local_rank): + numactl_cmd = [] + check_for_numactl_pkg() + if 'KMP_AFFINITY' in os.environ.keys(): + raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl " + "because it interfere with how many CPU cores numactl can set. " + "Unset KMP_AFFINITY before launching deepspeed.\n\n" + "\t$ unset KMP_AFFINITY\n" + "\t$ deepspeed ") + if bind_core_list is not None: + core_list = parse_range_list(bind_core_list) + total_cores = len(core_list) + else: + total_cores = psutil.cpu_count(logical=False) + core_list = range(total_cores) + cores_per_rank = total_cores // num_local_procs + assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank" + core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)] + numactl_cmd.append("numactl") + + # check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter + numa_cores = get_numa_cores() + num_numas = len(numa_cores) + + numa_mode = "normal" + + non_empty_numa_list = [] + empty_numa_list = [] + previous_numa_cores = [] + numa_node_list = [] + numa_node_list_list = [] + for i in range(num_numas): + # look for empty numa which is HBM numa + if numa_cores[i] == []: + empty_numa_list.append(i) + else: + non_empty_numa_list.append(i) + + # check for fakenuma + if numa_cores[i] == previous_numa_cores: + if numa_node_list == []: + #first duplication, add previous node into list + numa_node_list.append(i - 1) + numa_node_list.append(i) + else: + if numa_node_list != []: + numa_node_list_list.append(numa_node_list) + numa_node_list = [] + previous_numa_cores = numa_cores[i] + if numa_node_list != []: + numa_node_list_list.append(numa_node_list) + + if empty_numa_list != [] and len(empty_numa_list) == len(non_empty_numa_list): + numa_mode = "flat_hbm" + numa_dict = dict(zip(non_empty_numa_list, empty_numa_list)) + elif numa_node_list_list != []: + numa_mode = "fake" + + if numa_mode == "normal": + for i in range(num_numas): + if set(core_list_for_rank) <= set(numa_cores[i]): + numactl_cmd.append("-m") + numactl_cmd.append(f"{i}") + break + elif numa_mode == "flat_hbm": + for i in range(num_numas): + if set(core_list_for_rank) <= set(numa_cores[i]): + numactl_cmd.append("-p") + numactl_cmd.append(f"{numa_dict[i]}") + break + elif numa_mode == "fake": + for i in range(num_numas): + if set(core_list_for_rank) <= set(numa_cores[i]): + for nodes in numa_node_list_list: + if i in nodes: + numactl_cmd.append("-m") + numactl_cmd.append(f"{','.join(map(str, nodes))}") + break + # the following construct break the outer loop if inner loop breaks + else: + continue + break + + numactl_cmd.append("-C") + last_core = core_list_for_rank[0] + first_core = last_core + core_list_str = f"{last_core}" + for core_id in core_list_for_rank[1:]: + if core_id == last_core + 1: + last_core = core_id + continue + else: + if first_core == last_core: + core_list_str = f"{core_list_str},{core_id}" + else: + core_list_str = f"{core_list_str}-{last_core},{core_id}" + first_core = core_id + last_core = core_id + if first_core != last_core: + core_list_str = f"{core_list_str}-{last_core}" + numactl_cmd.append(f"{core_list_str}") + return cores_per_rank, numactl_cmd diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py index 3823599e7bf2..20725acbdde6 100644 --- a/deepspeed/utils/nvtx.py +++ b/deepspeed/utils/nvtx.py @@ -4,16 +4,35 @@ # DeepSpeed Team from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.compiler import is_compiling + +enable_nvtx = True +DEEPSPEED_NVTX_DOMAIN = "DeepSpeed" + + +def _range_push(accelerator, msg): + if getattr(accelerator, "supports_nvtx_domain", False): + return accelerator.range_push(msg, domain=DEEPSPEED_NVTX_DOMAIN) + return accelerator.range_push(msg) + + +def _range_pop(accelerator): + if getattr(accelerator, "supports_nvtx_domain", False): + return accelerator.range_pop(domain=DEEPSPEED_NVTX_DOMAIN) + return accelerator.range_pop() def instrument_w_nvtx(func): - """decorator that causes an NVTX range to be recorded for the duration of the - function call.""" + """Decorator that records an NVTX range for the duration of the function call. + Skips NVTX instrumentation when torch.compile is active to avoid graph breaks. + """ def wrapped_fn(*args, **kwargs): - get_accelerator().range_push(func.__qualname__) + if enable_nvtx and not is_compiling(): + _range_push(get_accelerator(), func.__qualname__) ret_val = func(*args, **kwargs) - get_accelerator().range_pop() + if enable_nvtx and not is_compiling(): + _range_pop(get_accelerator()) return ret_val return wrapped_fn diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index d117defc9875..1947ec3d8853 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ -6,6 +6,7 @@ import torch from dataclasses import dataclass from deepspeed import comm as dist +from typing import Dict, List, Callable @dataclass @@ -20,11 +21,11 @@ class tensor_fragment: lp_fragment_address: fragment_address hp_fragment: torch.Tensor hp_fragment_address: fragment_address - optim_fragment: {} - gradient_dict: {} - offload_gradient_dict: {} + gradient_dict: Dict + offload_gradient_dict: Dict use_offload: bool param_group_index: int + optim_fragment: Dict = None def update_hp(self): self.hp_fragment.data.copy_(self.lp_fragment.data) @@ -38,42 +39,74 @@ def get_optim_state_fragment(self, key): else: raise ValueError(f'{key} not found in optimizer state fragment') + def set_optim_state_fragment(self, flat_hp_partition, optim_fragment): + self.optim_fragment = { + key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel) + for key, value in optim_fragment.items() + if torch.is_tensor(value) and value.shape == flat_hp_partition.shape + } + def get_hp_fragment_address(self): return self.hp_fragment_address def get_optim_state_keys(self): return list(self.optim_fragment.keys()) + def get_hp_fragment(self, optim_state_key=None): + if optim_state_key is None: + return self.hp_fragment + return self.get_optim_state_fragment(optim_state_key) + + def get_lp_grad_fragment(self, index_in_param_group): + if self.use_offload: + gradient_dict = self.offload_gradient_dict + else: + gradient_dict = self.gradient_dict + + if self.param_group_index not in gradient_dict or gradient_dict[self.param_group_index] is None: + raise ValueError("Gradients are only available immediately after backward and before engine step") + + return gradient_dict[self.param_group_index][index_in_param_group] + + +def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys): + for key in opt_keys: + hp_param = flat_hp_tensor + buffer = torch.zeros_like(hp_param) + + for lp in lp_tensors: + if lp._hp_mapping is not None: + hp_fragment_address = lp._hp_mapping.get_hp_fragment_address() + hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel) + hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data) + lp._hp_mapping.hp_fragment = hp_fragment + + optim_state[hp_param][key] = buffer + def get_full_hp_param(self, optim_state_key=None): reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() if self._hp_mapping is not None: lp_frag_address = self._hp_mapping.lp_fragment_address reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel) - if optim_state_key is None: - hp_fragment = self._hp_mapping.hp_fragment - else: - hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key) - + hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key) reduce_fragment.data.copy_(hp_fragment.data) dist.all_reduce(reduce_buffer, group=self._dp_group) return reduce_buffer.reshape_as(self) -def get_full_hp_grad(self): - reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() +def set_full_hp_param(self, value, optim_state_key=None): if self._hp_mapping is not None: - hp_mapping = self._hp_mapping - - if hp_mapping.use_offload: - gradient_dict = hp_mapping.offload_gradient_dict - else: - gradient_dict = hp_mapping.gradient_dict + lp_frag_address = self._hp_mapping.lp_fragment_address + value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel) + hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key) + hp_fragment.data.copy_(value_fragment.data) - if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None: - raise ValueError("Gradients are only available immediately after backward and before engine step") - lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group] +def get_full_hp_grad(self): + reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() + if self._hp_mapping is not None: + lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group) hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten() lp_frag_address = self._hp_mapping.lp_fragment_address @@ -88,11 +121,24 @@ def get_full_hp_grad(self): return reduce_buffer.reshape_as(self) +def set_full_hp_grad(self, value): + if self._hp_mapping is not None: + lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group) + lp_frag_address = self._hp_mapping.lp_fragment_address + value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel) + lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data)) + if hasattr(self, '_zero_optimizer'): + self._zero_optimizer.update_offload_overflow_tracker(value) + + def safe_get_full_fp32_param(param): """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter. Args: param (``torch.nn.Parameter``): A model parameter + + Returns: + Union[torch.Tensor, None]: A tensor on accelerator device """ # ZeRO stage 3 param if hasattr(param, 'ds_id'): @@ -104,12 +150,32 @@ def safe_get_full_fp32_param(param): return None +def safe_set_full_fp32_param(param, value): + """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter. + + Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): New value + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + param._z3_optimizer.set_full_hp_param(value, param) + + # ZeRO stage 1, 2, and bf16_optimizer params + if hasattr(param, '_hp_mapping'): + param.set_full_hp_param(value) + + def safe_get_full_optimizer_state(param, optim_state_key): """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter. Args: param (``torch.nn.Parameter``): A model parameter - """ + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) + + Returns: + Union[torch.Tensor, None]: A tensor on accelerator device +""" # ZeRO stage 3 param if hasattr(param, 'ds_id'): return param._z3_optimizer.get_full_hp_param(param, optim_state_key) @@ -120,12 +186,35 @@ def safe_get_full_optimizer_state(param, optim_state_key): return None +def safe_set_full_optimizer_state(param, value, optim_state_key): + """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter. + + Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): New value + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + param._z3_optimizer.set_full_hp_param(value, param, optim_state_key) + + # ZeRO stage 1, 2, and bf16_optimizer params + if hasattr(param, '_hp_mapping'): + param.set_full_hp_param(value, optim_state_key) + + # TODO: Figure out the correct return dtype def safe_get_full_grad(param): - """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter. + """ + Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter. + The return data type is that used for gradient accumulation. This is usually the param data type, + but could also be different (e.g., bf16 param training with fp32 gradient accumulation). Args: param (``torch.nn.Parameter``): A model parameter + + Returns: + Union[torch.Tensor, None]: A tensor on accelerator device """ if param.grad is not None: return param.grad @@ -141,8 +230,148 @@ def safe_get_full_grad(param): return None +def safe_set_full_grad(param, value): + """ + Update the partitioned gradient of a low-precision (e.g., fp16) parameter. + To avoid precision issues, the update value should have the data type of + gradient accumulation. + + Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): The un-partitioned new gradient value. + """ + if param.grad is not None: + param.grad.copy_(value) + elif hasattr(param, 'ds_id'): + # ZeRO stage 3 param + param._z3_optimizer.set_fp32_grad_for_param(value, param) + elif hasattr(param, '_hp_mapping'): + # ZeRO stage 1, 2, and bf16_optimizer params + param.set_full_hp_grad(value) + + +### Local API START ### +def safe_get_local_grad(param): + """ + Get the local gradient partition of a ZeRO-3 partitioned parameter. + The return data type is that used for gradient accumulation. This is usually the param data type, + but could also be different (e.g., bf16 param training with fp32 gradient accumulation). + + Args: + param (``torch.nn.Parameter``): A model parameter + + Returns: + Union[torch.Tensor, None]: A tensor on accelerator device + """ + assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters' + return param._z3_optimizer.get_local_fp32_grad_for_param(param) + + +def safe_set_local_grad(param, value): + """ + Update the local gradient partition of a ZeRO-3 partitioned parameter. + To avoid precision issues, the update value should have the data type of + gradient accumulation. + + Args: + param (``torch.nn.Parameter``): A model parameter. + value (``torch.Tensor``): New value of local gradient partition. + """ + assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters' + param._z3_optimizer.set_local_grad_for_param(value, param) + + +def safe_get_local_fp32_param(param): + """Get the local partition of a ZeRO-3 partitioned parameter in fp32 precision. + + Args: + param (``torch.nn.Parameter``): A model parameter. + + Returns: + Union[torch.Tensor, None]: A tensor on accelerator device + """ + assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters' + return param._z3_optimizer.get_local_fp32_param(param) + + +def safe_get_local_optimizer_state(param, optim_state_key): + """Get the local optimizer state partition of ZeRO-3 partitioned parameter in fp32 precision. + + Args: + param (``torch.nn.Parameter``): A model parameter + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) + + Returns: + Union[torch.Tensor, None]: A tensor on accelerator device + """ + assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters' + return param._z3_optimizer.get_local_fp32_param(param, optim_state_key) + + +def safe_set_local_optimizer_state(param, value, optim_state_key): + """Update the local optimizer state partition of a ZeRO-3 partitioned parameter. + + Args: + param (``torch.nn.Parameter``): A model parameter. + value (``torch.Tensor``): New value of local optimizer state partition. + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer). + """ + assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters' + param._z3_optimizer.set_local_hp_param(value, param, optim_state_key) + + +def safe_set_local_fp32_param(param, value): + """Update the local partition of ZeRO-3 partitioned parameter. + + Args: + param (``torch.nn.Parameter``): A model parameter. + value (``torch.Tensor``): New value of local parameter partition. + """ + assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters' + param._z3_optimizer.set_local_hp_param(value, param) + + +### Local API END ### + + +### VECTORIZED API BEGIN ### +def safe_update_full_grad_vectorized(param_list: List[torch.nn.Parameter], update_func: Callable): + """ + Vectorized update of the partitioned gradients of a list of low-precision (e.g., fp16) parameters. + To avoid precision issues, the update value should have the data type of + gradient accumulation. + + Args: + param_list (``List[torch.nn.Parameter]``): List of model parameters + update_func (``torch.Tensor``): A function that takes current full gradient value and returns new one. + """ + partitioned_grad_params = [] + for p in param_list: + if p.grad is not None: + p.grad.copy_(update_func(p.grad, p)) + elif p.requires_grad: + partitioned_grad_params.append(p) + + if not partitioned_grad_params: + return + + if hasattr(partitioned_grad_params[0], 'ds_id'): + # ZeRO stage 3 param + partitioned_grad_params[0]._z3_optimizer.update_fp32_grad_for_param_vectorized( + update_func, partitioned_grad_params) + elif hasattr(partitioned_grad_params[0], '_hp_mapping'): + # ZeRO stage 1, 2, and bf16_optimizer params + for p in partitioned_grad_params: + old_grad = safe_get_full_grad(p) + new_grad = update_func(old_grad, p) + p.set_full_hp_grad(new_grad) + + +### VECTORIZED API END ### + + def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, optimizer_state_dict): + param_group_index, partition_start, partition_size): lp_end = lp_param.numel() + lp_start hp_start = partition_start hp_end = partition_start + partition_size @@ -155,11 +384,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict fragment_numel = fragment_end - fragment_start hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel) hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel) - optim_fragment = { - key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel) - for key, value in optimizer_state_dict.items() - if torch.is_tensor(value) and value.shape == flat_hp_partition.shape - } lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel) lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel) @@ -168,7 +392,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict lp_fragment_address=lp_frag_address, hp_fragment=hp_fragment_tensor, hp_fragment_address=hp_frag_address, - optim_fragment=optim_fragment, gradient_dict=gradient_dict, offload_gradient_dict=offload_gradient_dict, use_offload=use_offload, diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index e52eb20d9602..0aa7be55d829 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -5,9 +5,20 @@ import time from numpy import mean -from deepspeed.utils.logging import log_dist +from deepspeed.utils.logging import print_dist from deepspeed.accelerator import get_accelerator -from deepspeed import comm as dist + +FORWARD_MICRO_TIMER = 'fwd_microstep' +FORWARD_GLOBAL_TIMER = 'fwd' +BACKWARD_MICRO_TIMER = 'bwd_microstep' +BACKWARD_GLOBAL_TIMER = 'bwd' +BACKWARD_INNER_MICRO_TIMER = 'bwd_inner_microstep' +BACKWARD_INNER_GLOBAL_TIMER = 'bwd_inner' +BACKWARD_REDUCE_MICRO_TIMER = 'bwd_allreduce_microstep' +BACKWARD_REDUCE_GLOBAL_TIMER = 'bwd_allreduce' +STEP_MICRO_TIMER = 'step_microstep' +STEP_GLOBAL_TIMER = 'step' +TIME_EPSILON = 1e-6 try: import psutil @@ -40,28 +51,43 @@ def __init__(self, name): self.name_ = name self.started_ = False self.event_timers = [] + self.use_host_timer = get_accelerator().use_host_timers() self.start_event = None self.elapsed_records = None + self.start_time = 0.0 + self.end_time = 0.0 def start(self): """Start the timer.""" assert not self.started_, f"{self.name_} timer has already been started" - self.start_event = get_accelerator().Event(enable_timing=True) - self.start_event.record() + if self.use_host_timer: + self.start_time = time.time() + else: + event_class = get_accelerator().Event + self.start_event = event_class(enable_timing=True) + self.start_event.record() self.started_ = True def stop(self, reset=False, record=False): """Stop the timer.""" assert self.started_, "timer is not started" - end_event = get_accelerator().Event(enable_timing=True) - end_event.record() - self.event_timers.append(CudaEventTimer(self.start_event, end_event)) - self.start_event = None + event_class = get_accelerator().Event + if self.use_host_timer: + self.end_time = time.time() + self.event_timers.append(self.end_time - self.start_time) + else: + event_class = get_accelerator().Event + end_event = event_class(enable_timing=True) + end_event.record() + self.event_timers.append(CudaEventTimer(self.start_event, end_event)) + self.start_event = None self.started_ = False def _get_elapsed_msec(self): - self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers] - self.event_timers.clear() + if self.use_host_timer: + self.elapsed_records = [et * 1000.0 for et in self.event_timers] + else: + self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers] return sum(self.elapsed_records) def reset(self): @@ -115,13 +141,14 @@ def memory_usage(): def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None): """Log a group of timers.""" assert normalizer > 0.0 - string = f"rank={dist.get_rank()} time (ms)" + string = "time (ms)" for name in names: if name in self.timers: elapsed_time = (self.timers[name].elapsed(reset=reset) / normalizer) string += " | {}: {:.2f}".format(name, elapsed_time) - log_dist(string, ranks=ranks or [0]) + # timers logging should be independent of the global log level it's already conditional on wall_clock_breakdown being True, so using use_logger=False will always print the stats + print_dist(string, ranks=ranks or [0]) def get_mean(self, names, normalizer=1.0, reset=True): """Get the mean of a group of timers.""" @@ -134,17 +161,46 @@ def get_mean(self, names, normalizer=1.0, reset=True): return means +class NoopTimer: + + class Timer: + + def start(self): + ... + + def reset(self): + ... + + def stop(self, **kwargs): + ... + + def elapsed(self, **kwargs): + return 0 + + def mean(self): + return 0 + + def __init__(self): + self.timer = self.Timer() + + def __call__(self, name): + return self.timer + + def get_timers(self): + return {} + + def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None): + ... + + def get_mean(self, names, normalizer=1.0, reset=True): + ... + + class ThroughputTimer: - def __init__( - self, - batch_size, - start_step=2, - steps_per_output=50, - monitor_memory=False, - logging_fn=None, - ): + def __init__(self, config, batch_size, start_step=2, steps_per_output=None, monitor_memory=False, logging_fn=None): from deepspeed.utils import logger + self.config = config self.start_time = 0 self.end_time = 0 self.started = False @@ -173,14 +229,22 @@ def _init_timer(self): self.initialized = True def start(self): + if not self.config.enabled: + return self._init_timer() self.started = True if self.global_step_count >= self.start_step: - get_accelerator().synchronize() + if self.config.synchronized: + get_accelerator().synchronize() self.start_time = time.time() + def _is_report_boundary(self): + if self.steps_per_output is None: + return False + return self.global_step_count % self.steps_per_output == 0 + def stop(self, global_step=False, report_speed=True): - if not self.started: + if not self.config.enabled or not self.started: return self.started = False self.micro_step_count += 1 @@ -188,14 +252,15 @@ def stop(self, global_step=False, report_speed=True): self.global_step_count += 1 if self.start_time > 0: - get_accelerator().synchronize() + if self.config.synchronized: + get_accelerator().synchronize() self.end_time = time.time() duration = self.end_time - self.start_time self.total_elapsed_time += duration self.step_elapsed_time += duration if global_step: - if report_speed and self.global_step_count % self.steps_per_output == 0: + if report_speed and self._is_report_boundary(): self.logging( "epoch={}/micro_step={}/global_step={}, RunningAvgSamplesPerSec={}, CurrSamplesPerSec={}, " "MemAllocated={}GB, MaxMemAllocated={}GB".format( @@ -203,7 +268,7 @@ def stop(self, global_step=False, report_speed=True): self.micro_step_count, self.global_step_count, self.avg_samples_per_sec(), - self.batch_size / self.step_elapsed_time, + self.batch_size / (self.step_elapsed_time + TIME_EPSILON), round(get_accelerator().memory_allocated() / 1024**3, 2), round(get_accelerator().max_memory_allocated() / 1024**3, 2), )) @@ -238,7 +303,7 @@ def trim_mean(data, trim_percent): Returns: float: Trimmed mean. """ - assert trim_percent >= 0.0 and trim_percent <= 1.0 + assert 0.0 <= trim_percent <= 1.0 n = len(data) # Account for edge case of empty list if len(data) == 0: diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py new file mode 100644 index 000000000000..e8c2831c4356 --- /dev/null +++ b/deepspeed/utils/torch.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import logging +import sys + +from packaging import version as pkg_version + +import torch + +_logger = logging.getLogger(__name__) + + +def required_torch_version(min_version=None, max_version=None): + assert min_version or max_version, "Must provide a min_version or max_version argument" + + torch_version = pkg_version.parse(torch.__version__) + + if min_version and pkg_version.parse(str(min_version)) > torch_version: + return False + + if max_version and pkg_version.parse(str(max_version)) < torch_version: + return False + + return True + + +def register_grad_hook(param, hook): + if required_torch_version(min_version=2.1): + return param.register_post_accumulate_grad_hook(hook) + else: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + return grad_acc.register_hook(hook) + + +def jit_script_compat(fn): + fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", repr(fn))) + + can_try_compile = (required_torch_version(min_version=2.0) and hasattr(torch, "compile") + and not (sys.version_info >= (3, 12) and not required_torch_version(min_version=2.4))) + + if can_try_compile: + try: + return torch.compile(fn) + except Exception: + _logger.debug( + "torch.compile failed for %s, falling back to torch.jit.script", + fn_name, + exc_info=True, + ) + + try: + return torch.jit.script(fn) + except Exception: + _logger.debug( + "torch.jit.script failed for %s, returning unmodified function", + fn_name, + exc_info=True, + ) + return fn diff --git a/deepspeed/utils/types.py b/deepspeed/utils/types.py index 2de4350fbd7a..96b5df625965 100644 --- a/deepspeed/utils/types.py +++ b/deepspeed/utils/types.py @@ -10,3 +10,18 @@ class ActivationFuncType(IntEnum): UNKNOWN = 0 GELU = 1 ReLU = 2 + GATED_GELU = 3 + GATED_SILU = 4 + + +GATED_ACTIVATION_TYPES = [ + ActivationFuncType.GATED_GELU, + ActivationFuncType.GATED_SILU, +] + + +class NormType(IntEnum): + UNKNOWN = 0 + LayerNorm = 1 + GroupNorm = 2 + RMSNorm = 3 diff --git a/deepspeed/utils/z3_leaf_module.py b/deepspeed/utils/z3_leaf_module.py new file mode 100644 index 000000000000..123763069138 --- /dev/null +++ b/deepspeed/utils/z3_leaf_module.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING + +from .logging import logger + +if TYPE_CHECKING: + from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig + + +def z3_leaf_module(model: torch.nn.Module) -> bool: + """Returns whether a module in `model` has been flagged as a 'leaf' module. + See `set_z3_leaf_modules` for more details. + Args: + model (torch.nn.Module): The model to which the leaf module flag will be applied. + Returns: + bool: Whether the module has been flagged as a 'leaf' module. + """ + return hasattr(model, '_z3_leaf') and model._z3_leaf + + +def z3_leaf_parameter(model: torch.nn.Parameter) -> bool: + """Returns whether a parameter belongs to a leaf module. + See `set_z3_leaf_modules` for more details. + Args: + model (torch.nn.Parameter): The parameter to which the leaf module flag will be applied. + Returns: + bool: Whether the parameter belongs to a leaf module. + """ + return hasattr(model, 'ds_z3_leaf_module') + + +def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]: + """Returns a list of modules in `model` that have been flagged as 'leaf' modules. + See `set_z3_leaf_modules` for more details. + Args: + model (torch.nn.Module): The model to which the leaf module flag will be applied. + Returns: + List[torch.nn.Module]: A list of modules that have been flagged as 'leaf' modules. + """ + return [module for module in model.modules() if z3_leaf_module(module)] + + +def set_z3_leaf_module(model: torch.nn.Module, flag: bool): + model._z3_leaf = flag + + +def _fully_qualified_class_name(module: torch.nn.Module) -> str: + cls = module.__class__ + return f"{cls.__module__}.{cls.__qualname__}" + + +def _do_set_z3_leaf_modules(model: torch.nn.Module, + leaf_module_classes: Union[List[Type], List[str]], + flag: bool, + raise_if_not_found: bool = True) -> List[torch.nn.Module]: + assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \ + f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}' + + leaf_modules: List[torch.nn.Module] = [] + + def _set_z3_leaf_flag(module_instance: torch.nn.Module): + nonlocal leaf_modules + for module in leaf_module_classes: + if isinstance(module, type) and isinstance(module_instance, module): + module_instance._z3_leaf = flag + leaf_modules.append(module_instance) + break + + if isinstance(module, str): + if (module_instance.__class__.__name__ == module + or _fully_qualified_class_name(module_instance) == module): + module_instance._z3_leaf = flag + leaf_modules.append(module_instance) + break + + model.apply(_set_z3_leaf_flag) + + if len(leaf_modules) == 0 and raise_if_not_found: + raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}') + + return leaf_modules + + +def set_z3_leaf_modules_by_name(model: torch.nn.Module, + module_names: List[str], + flag: bool = True, + raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]: + """Sets a leaf flag for modules referenced by their names in ``model.named_modules()``. + Args: + model (torch.nn.Module): The model containing the modules to update. + module_names (List[str]): Module names as returned by ``named_modules()``. + flag (bool): Desired flag state. + raise_if_not_found (bool): Whether to raise when no module matches a provided name. + Returns: + Tuple[List[torch.nn.Module], List[str]]: Matched modules and missing module names. + """ + modules_by_name = dict(model.named_modules()) + leaf_modules: List[torch.nn.Module] = [] + missing: List[str] = [] + + for name in module_names: + module = modules_by_name.get(name) + if module is None: + missing.append(name) + continue + module._z3_leaf = flag + leaf_modules.append(module) + + if missing and raise_if_not_found: + raise ValueError(f'No modules named {missing} found in model {model}') + + return leaf_modules, missing + + +def set_z3_leaf_modules_by_suffix(model: torch.nn.Module, + module_name_suffixes: List[str], + flag: bool = True, + raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]: + """Sets a leaf flag for modules referenced by suffixes of ``model.named_modules()`` names.""" + modules_by_name = dict(model.named_modules()) + leaf_modules: List[torch.nn.Module] = [] + missing: List[str] = [] + seen_ids = set() + + for suffix in module_name_suffixes: + matched = False + for name, module in modules_by_name.items(): + if name.endswith(suffix): + module._z3_leaf = flag + module_id = id(module) + if module_id not in seen_ids: + seen_ids.add(module_id) + leaf_modules.append(module) + matched = True + if not matched: + missing.append(suffix) + + if missing and raise_if_not_found: + raise ValueError(f'No modules matching suffixes {missing} found in model {model}') + + return leaf_modules, missing + + +def set_z3_leaf_modules(model: torch.nn.Module, + leaf_module_classes: Union[List[Type], List[str]], + raise_if_not_found: bool = True) -> List[torch.nn.Module]: + """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`. + This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module. + Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks. + Args: + model (torch.nn.Module): The model to which the leaf module flag will be applied. + leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules. + raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes + match a module inside ``model``. + Returns: + List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`. + """ + return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found) + + +def unset_z3_leaf_modules(model: torch.nn.Module, + leaf_module_classes: List[Type], + raise_if_not_found: bool = True) -> List[torch.nn.Module]: + """Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`. + See `set_z3_leaf_modules` for more details. + Args: + model (torch.nn.Module): The model to which the leaf module flag will be applied. + leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules. + raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes + match a module inside ``model``. + Returns: + List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`. + """ + return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found) + + +def apply_zero_leaf_module_config(model: torch.nn.Module, + leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]: + """Apply ZeRO leaf module configuration to ``model``. + + Args: + model (torch.nn.Module): Root module to update. + leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None`` + no changes are applied. + + Returns: + List[torch.nn.Module]: Modules flagged as leaves. + """ + if leaf_cfg is None: + return [] + + from deepspeed.runtime.zero.leaf_module_config import ( + DEFAULT_LEAF_MODULE_CLASSES, + DEFAULT_LEAF_MODULE_NAMES, + DEFAULT_LEAF_MODULE_NAME_SUFFIXES, + ) + + matched_modules: List[torch.nn.Module] = [] + matched_ids = set() + + customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES + customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES + customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES + + if leaf_cfg.classes: + class_matched = set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False) + for module in class_matched: + module_id = id(module) + if module_id not in matched_ids: + matched_ids.add(module_id) + matched_modules.append(module) + + if leaf_cfg.names: + name_matched, missing_names = set_z3_leaf_modules_by_name(model, + leaf_cfg.names, + flag=True, + raise_if_not_found=False) + for module in name_matched: + module_id = id(module) + if module_id not in matched_ids: + matched_ids.add(module_id) + matched_modules.append(module) + + if missing_names and customized_names: + logger.warning(f"ZeRO leaf module configuration contains unknown module names: {missing_names}") + + if leaf_cfg.name_suffixes: + suffix_matched, missing_suffixes = set_z3_leaf_modules_by_suffix(model, + leaf_cfg.name_suffixes, + flag=True, + raise_if_not_found=False) + for module in suffix_matched: + module_id = id(module) + if module_id not in matched_ids: + matched_ids.add(module_id) + matched_modules.append(module) + + if missing_suffixes and customized_suffixes: + logger.warning(f"ZeRO leaf module configuration contains unmatched module suffixes: {missing_suffixes}") + + if not matched_modules and (customized_classes or customized_names or customized_suffixes): + logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual") + + return matched_modules diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 57009f6e2f6b..5995d6e6f04e 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -5,12 +5,15 @@ # DeepSpeed Team -# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in # the future. Once extracted, the weights don't require DeepSpeed and can be used in any # application. # -# example: python zero_to_fp32.py . pytorch_model.bin +# example: +# python zero_to_fp32.py . output_dir/ +# or +# python zero_to_fp32.py . output_dir/ --safe_serialization import argparse import torch @@ -18,13 +21,30 @@ import math import os import re +import gc +import json +import numpy as np +from tqdm import tqdm from collections import OrderedDict +from dataclasses import dataclass # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # DeepSpeed data structures it has to be available in the current python environment. from deepspeed.utils import logger from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, - FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES) + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + debug = 0 @@ -50,7 +70,7 @@ def get_model_state_file(checkpoint_dir, zero_stage): raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") # there should be only one file - if zero_stage == 2: + if zero_stage <= 2: file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") elif zero_stage == 3: file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") @@ -61,59 +81,81 @@ def get_model_state_file(checkpoint_dir, zero_stage): return file -def get_optim_files(checkpoint_dir): +def get_checkpoint_files(checkpoint_dir, glob_pattern): # XXX: need to test that this simple glob rule works for multi-node setup too - optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")), key=natural_keys) + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) - if len(optim_files) == 0: - raise FileNotFoundError(f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") - return optim_files + return ckpt_files -def parse_model_state(file): - state_dict = torch.load(file, map_location=device) +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") - if BUFFER_NAMES not in state_dict: - raise ValueError(f"{file} is not a model state checkpoint") - buffer_names = state_dict[BUFFER_NAMES] - if debug: - print("Found buffers:", buffer_names) - - # recover just the buffers while restoring them to fp32 if they were saved in fp16 - buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} - param_shapes = state_dict[PARAM_SHAPES] - - # collect parameters that are included in param_shapes - param_names = [] - for s in param_shapes: - for name in s.keys(): - param_names.append(name) - - # record shared parameters so that they can be recovered based on partners - # this is because such parameters holding reference only are not saved by optimizer - shared_params = [] - for param in state_dict["module"]: - if param not in [*param_names, *buffer_names]: - for share_param in state_dict["module"]: - if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr() - and share_param != param): - shared_params.append([param, share_param]) - break - ds_version = state_dict.get(DS_VERSION, None) +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") - return buffers, param_shapes, shared_params, ds_version +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device, weights_only=False) -def parse_optim_states(files, ds_checkpoint_dir): + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): total_files = len(files) state_dicts = [] - for f in files: - state_dicts.append(torch.load(f, map_location=device)) - - if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + for f in tqdm(files, desc='Loading checkpoint shards'): + state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]: raise ValueError(f"{files[0]} is not a zero checkpoint") zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] @@ -132,30 +174,18 @@ def parse_optim_states(files, ds_checkpoint_dir): ) # the groups are named differently in each stage - if zero_stage == 2: + if zero_stage <= 2: fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS elif zero_stage == 3: fp32_groups_key = FP32_FLAT_GROUPS else: raise ValueError(f"unknown zero stage {zero_stage}") - if zero_stage == 2: - fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] - elif zero_stage == 3: - # if there is more than one param group, there will be multiple flattened tensors - one - # flattened tensor per group - for simplicity merge them into a single tensor - # - # XXX: could make the script more memory efficient for when there are multiple groups - it - # will require matching the sub-lists of param_shapes for each param group flattened tensor - - fp32_flat_groups = [ - torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) - ] - + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] return zero_stage, world_size, fp32_flat_groups -def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint @@ -169,19 +199,58 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") - model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) - buffers, param_shapes, shared_params, ds_version = parse_model_state(model_file) - print(f'Parsing checkpoint created by deepspeed=={ds_version}') + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') - if zero_stage == 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers, - shared_params) + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers, - shared_params) + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers, shared_params): +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes # Reconstruction protocol: # @@ -209,13 +278,6 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_fl print(f"Have {avail_numel} numels to process.") print(f"Need {wanted_numel} numels in {wanted_params} params.") - state_dict = OrderedDict() - - # buffers - state_dict.update(buffers) - if debug: - print(f"added {len(buffers)} buffers") - # params # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # out-of-core computing solution @@ -226,7 +288,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, param_shapes, fp32_fl avail_numel = full_single_fp32_vector.numel() for name, shape in shapes.items(): - unpartitioned_numel = shape.numel() + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) total_numel += unpartitioned_numel total_params += 1 @@ -257,12 +319,29 @@ def zero2_align(x): if offset != avail_numel: raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") - # recover shared parameters - for pair in shared_params: - state_dict[pair[0]] = state_dict[pair[1]] - print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + return state_dict @@ -273,12 +352,95 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size): return partitioned_numel, padding_numel -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_flat_groups, buffers, shared_params): +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +class GatheredTensor: + """ + A pseudo tensor that collects partitioned weights. + It is more memory efficient when there are multiple groups. + """ + + def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape): + self.flat_groups = flat_groups + self.flat_groups_offset = flat_groups_offset + self.offset = offset + self.partitioned_numel = partitioned_numel + self.shape = shape + self.dtype = self.flat_groups[0][0].dtype + + def contiguous(self): + """ + Merge partitioned weights from flat_groups into a single tensor. + """ + end_idx = self.offset + self.partitioned_numel + world_size = len(self.flat_groups) + pad_flat_param_chunks = [] + + for rank_i in range(world_size): + # for each rank, we need to collect weights from related group/groups + flat_groups_at_rank_i = self.flat_groups[rank_i] + start_group_id = None + end_group_id = None + for group_id in range(len(self.flat_groups_offset)): + if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]: + start_group_id = group_id + if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]: + end_group_id = group_id + break + # collect weights from related group/groups + for group_id in range(start_group_id, end_group_id + 1): + flat_tensor = flat_groups_at_rank_i[group_id] + start_offset = self.offset - self.flat_groups_offset[group_id] + end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id] + pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset]) + + # collect weights from all ranks + pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0) + param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous() + return param + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each # param, re-consolidating each param, while dealing with padding if any - avail_numel = fp32_flat_groups[0].numel() * world_size # merge list of dicts, preserving order param_shapes = {k: v for d in param_shapes for k, v in d.items()} @@ -289,15 +451,9 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_fl wanted_params = len(param_shapes) wanted_numel = sum(shape.numel() for shape in param_shapes.values()) # not asserting if there is a mismatch due to possible padding - print(f"Have {avail_numel} numels to process.") - print(f"Need {wanted_numel} numels in {wanted_params} params.") - - state_dict = OrderedDict() - - # buffers - state_dict.update(buffers) - if debug: - print(f"added {len(buffers)} buffers") + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") # params # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support @@ -305,23 +461,21 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_fl offset = 0 total_numel = 0 total_params = 0 - for name, shape in param_shapes.items(): - + flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])) + for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'): unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel total_params += 1 - partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) if debug: print( - f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" ) - # XXX: memory usage doubles here - state_dict[name] = torch.cat( - tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), - 0).narrow(0, 0, unpartitioned_numel).view(shape) + # memory efficient tensor + tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape) + state_dict[name] = tensor offset += partitioned_numel offset *= world_size @@ -330,16 +484,56 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, param_shapes, fp32_fl if offset != avail_numel: raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") - # recover shared parameters - for pair in shared_params: - state_dict[pair[0]] = state_dict[pair[1]] + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") - print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] return state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): +def to_torch_tensor(state_dict, return_empty_tensor=False): + """ + Convert state_dict of GatheredTensor to torch tensor + """ + torch_state_dict = {} + converted_tensors = {} + for name, tensor in state_dict.items(): + tensor_id = id(tensor) + if tensor_id in converted_tensors: # shared tensors + shared_tensor = torch_state_dict[converted_tensors[tensor_id]] + torch_state_dict[name] = shared_tensor + else: + converted_tensors[tensor_id] = name + if return_empty_tensor: + torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype) + else: + torch_state_dict[name] = tensor.contiguous() + return torch_state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag=None, + exclude_frozen_parameters=False, + lazy_mode=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example @@ -348,14 +542,13 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): Args: - ``checkpoint_dir``: path to the desired checkpoint folder - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient. + Convert the pesduo tensor to torch tensor by ``.contiguous()`` Returns: - pytorch ``state_dict`` - Note: this approach may not work if your application doesn't have sufficient free CPU memory and - you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with - the checkpoint. - A typical usage might be :: from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint @@ -371,6 +564,16 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + Note: the above usage may not work if your application doesn't have sufficient free CPU memory. + You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. Or you can load state_dict in lazy mode :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu + for name, lazy_tensor in state_dict.item(): + tensor = lazy_tensor.contiguous() # to cpu + print(name, tensor) + # del tensor to release memory if it no longer in use """ if tag is None: latest_path = os.path.join(checkpoint_dir, 'latest') @@ -385,23 +588,96 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + if lazy_mode: + return state_dict + else: + return to_torch_tensor(state_dict) -def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, + output_dir, + max_shard_size="5GB", + safe_serialization=False, + tag=None, + exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. Args: - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``output_dir``: directory to the pytorch fp32 state_dict output files + - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB + - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters """ - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) - print(f"Saving fp32 state dict to {output_file}") - torch.save(state_dict, output_file) + # Dependency pre-check + if safe_serialization: + try: + from safetensors.torch import save_file + except ImportError: + print('If you want to use `safe_serialization`, please `pip install safetensors`') + raise + if max_shard_size is not None: + try: + from huggingface_hub import split_torch_state_dict_into_shards + except ImportError: + print('If you want to use `max_shard_size`, please `pip install huggingface_hub`') + raise + + # Convert zero checkpoint to state_dict + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag, + exclude_frozen_parameters, + lazy_mode=True) + + # Shard the model if it is too big. + weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" + if max_shard_size is not None: + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + # an memory-efficient approach for sharding + empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True) + state_dict_split = split_torch_state_dict_into_shards(empty_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size) + else: + from collections import namedtuple + StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"]) + state_dict_split = StateDictSplit(is_sharded=False, + filename_to_tensors={weights_name: list(state_dict.keys())}) + + # Save the model by shard + os.makedirs(output_dir, exist_ok=True) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"): + shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors} + shard_state_dict = to_torch_tensor(shard_state_dict) + output_path = os.path.join(output_dir, shard_file) + if safe_serialization: + save_file(shard_state_dict, output_path, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, output_path) + # release the memory of current shard + for tensor_name in list(shard_state_dict.keys()): + del state_dict[tensor_name] + del shard_state_dict[tensor_name] + del shard_state_dict + gc.collect() + + # Save index if sharded + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): @@ -433,10 +709,10 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. """ - logger.info(f"Extracting fp32 weights") + logger.info("Extracting fp32 weights") state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) - logger.info(f"Overwriting model with fp32 weights") + logger.info("Overwriting model with fp32 weights") model = model.cpu() model.load_state_dict(state_dict, strict=False) @@ -444,18 +720,41 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument("checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument("output_dir", + type=str, + help="directory to the pytorch fp32 state_dict output files" + "(e.g. path/checkpoint-12-output/)") parser.add_argument( - "output_file", + "--max_shard_size", type=str, - help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + default="5GB", + help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size" + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`" + "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances" + "without CPU OOM issues.") + parser.add_argument( + "--safe_serialization", + default=False, + action='store_true', + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_dir, + max_shard_size=args.max_shard_size, + safe_serialization=args.safe_serialization, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9bcfedb8d8f3..263a30be27c5 100755 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,6 @@ -FROM nvidia/cuda:10.0-devel-ubuntu18.04 +FROM nvidia/cuda:12.2.2-devel-ubuntu20.04 + +ENV DEBIAN_FRONTEND noninteractive ############################################################################## # Temporary Installation Directory @@ -17,7 +19,7 @@ RUN apt-get update && \ curl wget vim tmux emacs less unzip \ htop iftop iotop ca-certificates openssh-client openssh-server \ rsync iputils-ping net-tools sudo \ - llvm-9-dev + llvm-dev ############################################################################## # Installation Latest Git @@ -38,20 +40,20 @@ RUN cp /etc/ssh/sshd_config ${STAGE_DIR}/sshd_config && \ ############################################################################## # Mellanox OFED ############################################################################## -ENV MLNX_OFED_VERSION=4.6-1.0.1.1 +ENV MLNX_OFED_VERSION=4.9-7.1.0.0 RUN apt-get install -y libnuma-dev RUN cd ${STAGE_DIR} && \ - wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64.tgz | tar xzf - && \ - cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64 && \ + wget -q -O - http://www.mellanox.com/downloads/ofed/MLNX_OFED-${MLNX_OFED_VERSION}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu20.04-x86_64.tgz | tar xzf - && \ + cd MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu20.04-x86_64 && \ ./mlnxofedinstall --user-space-only --without-fw-update --all -q && \ cd ${STAGE_DIR} && \ - rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu18.04-x86_64* + rm -rf ${STAGE_DIR}/MLNX_OFED_LINUX-${MLNX_OFED_VERSION}-ubuntu20.04-x86_64* ############################################################################## # nv_peer_mem ############################################################################## -ENV NV_PEER_MEM_VERSION=1.1 -ENV NV_PEER_MEM_TAG=1.1-0 +ENV NV_PEER_MEM_VERSION=1.2 +ENV NV_PEER_MEM_TAG=${NV_PEER_MEM_VERSION}-0 RUN mkdir -p ${STAGE_DIR} && \ git clone https://github.com/Mellanox/nv_peer_memory.git --branch ${NV_PEER_MEM_TAG} ${STAGE_DIR}/nv_peer_memory && \ cd ${STAGE_DIR}/nv_peer_memory && \ @@ -67,8 +69,8 @@ RUN mkdir -p ${STAGE_DIR} && \ ############################################################################## # OPENMPI ############################################################################## -ENV OPENMPI_BASEVERSION=4.0 -ENV OPENMPI_VERSION=${OPENMPI_BASEVERSION}.1 +ENV OPENMPI_BASEVERSION=4.1 +ENV OPENMPI_VERSION=${OPENMPI_BASEVERSION}.6 RUN cd ${STAGE_DIR} && \ wget -q -O - https://download.open-mpi.org/release/open-mpi/v${OPENMPI_BASEVERSION}/openmpi-${OPENMPI_VERSION}.tar.gz | tar xzf - && \ cd openmpi-${OPENMPI_VERSION} && \ @@ -95,7 +97,7 @@ ENV PYTHON_VERSION=3 RUN apt-get install -y python3 python3-dev && \ rm -f /usr/bin/python && \ ln -s /usr/bin/python3 /usr/bin/python && \ - curl -O https://bootstrap.pypa.io/get-pip.py && \ + curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py && \ python get-pip.py && \ rm get-pip.py && \ pip install --upgrade pip && \ @@ -104,12 +106,6 @@ RUN apt-get install -y python3 python3-dev && \ RUN pip install pyyaml RUN pip install ipython -############################################################################## -# TensorFlow -############################################################################## -ENV TENSORFLOW_VERSION=1.15.2 -RUN pip install tensorflow-gpu==${TENSORFLOW_VERSION} - ############################################################################## # Some Packages ############################################################################## @@ -136,33 +132,26 @@ RUN pip install psutil \ sentencepiece \ msgpack \ requests \ - pandas \ sphinx \ sphinx_rtd_theme \ scipy \ numpy \ - sklearn \ scikit-learn \ nvidia-ml-py3 \ - mpi4py \ - cupy-cuda100 + mpi4py ############################################################################## ## SSH daemon port inside container cannot conflict with host OS port ############################################################################### ENV SSH_PORT=2222 RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \ - sed "0,/^#Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config + sed "0,/^Port 22/s//Port ${SSH_PORT}/" ${STAGE_DIR}/sshd_config > /etc/ssh/sshd_config ############################################################################## # PyTorch ############################################################################## -ENV PYTORCH_VERSION=1.2.0 -ENV TORCHVISION_VERSION=0.4.0 -ENV TENSORBOARDX_VERSION=1.8 +ENV PYTORCH_VERSION=1.13.0 RUN pip install torch==${PYTORCH_VERSION} -RUN pip install torchvision==${TORCHVISION_VERSION} -RUN pip install tensorboardX==${TENSORBOARDX_VERSION} ############################################################################## # PyYAML build issue @@ -185,7 +174,7 @@ USER deepspeed ############################################################################## # DeepSpeed ############################################################################## -RUN git clone https://github.com/microsoft/DeepSpeed.git ${STAGE_DIR}/DeepSpeed +RUN git clone https://github.com/deepspeedai/DeepSpeed.git ${STAGE_DIR}/DeepSpeed RUN cd ${STAGE_DIR}/DeepSpeed && \ git checkout . && \ git checkout master && \ diff --git a/docker/gh-builder/Dockerfile.py311 b/docker/gh-builder/Dockerfile.py311 new file mode 100644 index 000000000000..603fb614314f --- /dev/null +++ b/docker/gh-builder/Dockerfile.py311 @@ -0,0 +1,35 @@ +# Start with NGC container +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +# Set noninteractive mode for apt-get +ARG DEBIAN_FRONTEND=noninteractive + +# Install necessary dependencies for building Python +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + libncursesw5-dev \ + libgdbm-dev \ + libc6-dev \ + libffi-dev \ + tk-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and install Python 3.11 +RUN wget https://www.python.org/ftp/python/3.11.9/Python-3.11.9.tgz \ + && tar xzf Python-3.11.9.tgz \ + && cd Python-3.11.9 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && cd .. \ + && rm -rf Python-3.11.9 Python-3.11.9.tgz + +# Set Python 3.11 as the default Python version +RUN update-alternatives --install /usr/bin/python python /usr/local/bin/python3.11 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.11 1 diff --git a/docker/gh-builder/Dockerfile.py312 b/docker/gh-builder/Dockerfile.py312 new file mode 100644 index 000000000000..a0a7193201d4 --- /dev/null +++ b/docker/gh-builder/Dockerfile.py312 @@ -0,0 +1,35 @@ +# Start with NGC container +FROM nvcr.io/nvidia/pytorch:24.03-py3 + +# Set noninteractive mode for apt-get +ARG DEBIAN_FRONTEND=noninteractive + +# Install necessary dependencies for building Python +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + libncursesw5-dev \ + libgdbm-dev \ + libc6-dev \ + libffi-dev \ + tk-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and install Python 3.12 +RUN wget https://www.python.org/ftp/python/3.12.5/Python-3.12.5.tgz \ + && tar xzf Python-3.12.5.tgz \ + && cd Python-3.12.5 \ + && ./configure --enable-optimizations \ + && make altinstall \ + && cd .. \ + && rm -rf Python-3.12.5 Python-3.12.5.tgz + +# Set Python 3.12 as the default Python version +RUN update-alternatives --install /usr/bin/python python /usr/local/bin/python3.12 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.12 1 diff --git a/docker/Dockerfile.rocm b/docker/rocm/Dockerfile similarity index 100% rename from docker/Dockerfile.rocm rename to docker/rocm/Dockerfile diff --git a/docs/CNAME b/docs/CNAME index 72033bc5f7fe..47f170e64eeb 100644 --- a/docs/CNAME +++ b/docs/CNAME @@ -1 +1 @@ -www.deepspeed.ai +www.deepspeed.ai \ No newline at end of file diff --git a/docs/Gemfile b/docs/Gemfile index 888e3c8dfd6a..f40c61e4575f 100644 --- a/docs/Gemfile +++ b/docs/Gemfile @@ -20,3 +20,5 @@ end # Performance-booster for watching directories on Windows gem "wdm", "~> 0.1.1", :install_if => Gem.win_platform? + +gem "webrick", "~> 1.8" diff --git a/docs/README.md b/docs/README.md index fbd9b68ac20e..7333a119c7be 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,7 +20,7 @@ Add these lines to your `.bashrc` or equivalent to ensure you have permissions t export GEM_HOME="$HOME/gems" export PATH="$HOME/gems/bin:$PATH" ``` -Don't forget to `source ~/.bashrc` afterwards 😊. +Don't forget to `source ~/.bashrc` afterward 😊. Now we can install Jekyll and [Bundler](https://bundler.io/): @@ -35,13 +35,23 @@ We now need to install the required Ruby packages for the website. > Could not locate Gemfile -**NOTE**: this step frequently hangs when connected to a VPN (including MSVPN). Simply disconnect for the package installation. +**NOTE**: This step frequently hangs when connected to a VPN (including MSVPN). Simply disconnect for the package installation. ``` bundle install ``` +Depending on your environment, you may need to add `webrick` to avoid the following [error](https://talk.jekyllrb.com/t/load-error-cannot-load-such-file-webrick/5417/6): + +> gems/gems/jekyll-3.9.5/lib/jekyll/commands/serve/servlet.rb:3:in `require': cannot load such file -- webrick (LoadError) + + +``` +bundle add webrick +``` + + You can now start a local webserver via: ``` bundle exec jekyll serve diff --git a/docs/_config.yml b/docs/_config.yml index 7127b8459fe2..ac8d9028e58f 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -41,6 +41,7 @@ collections: - cifar-10.md - curriculum-learning.md - data-efficiency.md + - ds4sci_evoformerattention.md - flops-profiler.md - pytorch-profiler.md - autotuning.md diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 6f7c443c7958..fe72f8890bab 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -8,7 +8,7 @@ main: - title: 'Documentation' url: https://deepspeed.readthedocs.io/ - title: 'GitHub' - url: https://github.com/microsoft/DeepSpeed + url: https://github.com/deepspeedai/DeepSpeed lnav: - title: 'Training' @@ -17,6 +17,8 @@ lnav: url: /inference/ - title: 'Compression' url: /compression/ + - title: 'Science' + url: /deepspeed4science/ - title: 'Getting Started' url: /getting-started/ - title: 'ds_config' @@ -39,7 +41,7 @@ lnav: - title: 'Flops Profiler' url: /docs/config-json/#flops-profiler - title: 'Monitoring' - url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv + url: /docs/config-json/#monitoring-module - title: 'Communication Logging' url: /docs/config-json/#communication-logging - title: 'Model Compression' @@ -53,8 +55,14 @@ lnav: url: /getting-started/ - title: 'Getting started on Azure' url: /tutorials/azure/ - - title: 'Automatic Tensor Parallelism' + - title: 'Accelerator Abstraction' + url: /tutorials/accelerator-abstraction-interface/ + - title: 'Accelerator Setup Guides' + url: /tutorials/accelerator-setup-guide/ + - title: 'Automatic Tensor Parallelism (Inference)' url: /tutorials/automatic-tensor-parallelism/ + - title: 'Automatic Tensor Parallelism (Training)' + url: /tutorials/autotp-training/ - title: 'Autotuning' url: /tutorials/autotuning/ - title: 'BingBertSQuAD Fine-tuning' @@ -67,6 +75,12 @@ lnav: url: /tutorials/curriculum-learning/ - title: 'Data Efficiency' url: /tutorials/data-efficiency/ + - title: 'DeepNVMe' + url: /tutorials/deepnvme/ + - title: 'Domino' + url: /tutorials/domino/ + - title: 'DS4Sci_EvoformerAttention' + url: /tutorials/ds4sci_evoformerattention/ - title: 'Flops Profiler' url: /tutorials/flops-profiler/ - title: 'PyTorch Profiler' @@ -109,9 +123,13 @@ lnav: url: /tutorials/sparse-attention/ - title: 'Transformer Kernel' url: /tutorials/transformer_kernel/ + - title: 'Arctic Long Sequence Training (ALST) for HF Transformers integration' + url: /tutorials/ulysses-alst-sequence-parallelism - title: 'ZeRO-Offload' url: /tutorials/zero-offload/ - title: 'ZeRO' url: /tutorials/zero/ + - title: 'ZeRO++' + url: /tutorials/zeropp/ - title: 'Contributing' url: /contributing/ diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 42186b12f4de..53d8adcadd93 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -6,7 +6,7 @@ toc_label: "Contents" ### Batch Size Related Parameters -**Note:** **train_batch_size** must be equal to **train_micro_batch_size_per_gpu** * **gradient_accumulation** * number of GPUs. For simplicity, you can choose to only specify two of the three parameters, the last one will be inferred automatically by DeepSpeed. +**Note:** **train_batch_size** must be equal to **train_micro_batch_size_per_gpu** * **gradient_accumulation_steps** * number of GPUs. For simplicity, you can choose to only specify two of the three parameters, the last one will be inferred automatically by DeepSpeed. {: .notice--warning} **train_batch_size**: [integer] @@ -36,9 +36,22 @@ toc_label: "Contents" | Fields | Value | Example | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------- | -| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, and **OneBitLamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | +| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, **OneBitLamb**, and **Muon** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | | params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | +Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe. + +Muon supports the following params: + +| "params" key | Description | Default | +| -------------- | -------------------------------------------------------------------------------------------------------------------- | --------- | +| lr | Learning rate for all parameters. Overridden by `muon_lr` / `adam_lr` if set. | 0.001 | +| momentum | Momentum coefficient for the Muon update. | 0.95 | +| weight\_decay | Weight decay (AdamW-style). | 0.0 | +| muon\_lr | Learning rate override for Muon parameters. Defaults to `lr` if not set. | - | +| adam\_lr | Learning rate override for non-Muon (Adam) parameters. Defaults to `lr` if not set. | - | +| ns\_method | Newton-Schulz orthogonalization method: `"gram"` for Gram NS (~2x faster on rectangular matrices), `"standard"` for the original iteration. Use `"standard"` to fall back if you encounter convergence issues. | `"gram"` | + Example of **optimizer** with Adam ```json @@ -62,6 +75,25 @@ The Adam optimizer also supports the following two params keys/values in additio | torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false | | adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true | +Example of **optimizer** with Muon +If not set, muon_lr will default to lr. +```json +"optimizer": { + "type": "Muon", + "params": { + "lr": 0.001, + "momentum": 0.9, + "weight_decay": 0.0, + "muon_lr": 0.001, + "ns_method": "gram" + } + }, + "zero_optimization": { + "stage": 3, + "save_muon_momentum_buffer_in_memory": true + } +``` + Another example of **optimizer** with 1-bit Adam specific parameters is as follows. ```json @@ -181,7 +213,7 @@ Example of **scheduler** ### Communication options -**communication_data_type**: [boolean] +**communication_data_type**: [string] | Description | Default | | ----------------------------------------------------------------------------------------------------------------------------- | ------- | @@ -224,7 +256,9 @@ Example of **scheduler** "initial_scale_power": 16, "loss_scale_window": 1000, "hysteresis": 2, - "min_loss_scale": 1 + "consecutive_hysteresis": false, + "min_loss_scale": 1, + "fp16_master_weights_and_grads": false } ``` @@ -250,7 +284,7 @@ Example of **scheduler** | Description | Default | | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | -| **initial_scale_power** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2**initial_scale_power**. | `32` | +| **initial_scale_power** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2**initial_scale_power**. | `16` | **fp16:loss_scale_window**: [integer] @@ -264,12 +298,32 @@ Example of **scheduler** | --------------------------------------------------------------------------------------------------- | ------- | | **hysteresis** is a **fp16** parameter representing the delay shift in dynamic loss scaling. | `2` | +**fp16:consecutive_hysteresis**: [boolean] + +| Description | Default | +| --------------------------------------------------------------------------------------------------- | ------- | +| **consecutive_hysteresis** is a **fp16** parameter representing whether to refill the hysteresis if we reach an iteration that doesn't overflow | `false` | + **fp16:min_loss_scale**: [integer] | Description | Default | | ----------------------------------------------------------------------------------------------------- | ------- | | **min_loss_scale** is a **fp16** parameter representing the minimum dynamic loss scale value. | `1` | +**fp16:fp16_master_weights_and_grads**: [boolean] + +| Description | Default | +| ----------- | ------- | +| Keep master parameters/gradients in fp16 instead of fp32 for ZeRO optimizer state. Requires ZeRO Stage 2 or 3 with ZeRO-Offload and `DeepSpeedCPUAdam` so optimizer states can remain in fp32. | `false` | + +**Support matrix (fp16 master weights/gradients)** + +| ZeRO stage | Offload required? | Notes | +| ---------- | ----------------- | ----- | +| 0 | Not supported | | +| 1/2/3 | Yes (`offload_optimizer` with `DeepSpeedCPUAdam`) | Optimizer states stay fp32 on CPU. | + + ### BFLOAT16 training options **Note:** this mode cannot be combined with the `amp` mode described below. @@ -286,7 +340,9 @@ Example of **scheduler** ```json "bf16": { - "enabled": true + "enabled": true, + "bf16_master_weights_and_grads": true, + "bf16_optimizer_states": true } ``` @@ -296,6 +352,24 @@ Example of **scheduler** |--------------------------------------------------------------------| ------- | | **enabled** indicates whether BFLOAT16 training is enabled. | `false` | +**bf16:bf16_master_weights_and_grads**: [boolean] + +| Description | Default | +| ----------- | ------- | +| Keep ZeRO master parameters/gradients in bf16 instead of fp32. Supported with ZeRO Stages 1, 2, or 3. If you leave optimizer states in fp32, ZeRO-Offload with `DeepSpeedCPUAdam` is required. | `false` | + +**bf16:bf16_optimizer_states**: [boolean] + +| Description | Default | +| ----------- | ------- | +| Keep optimizer states in bf16 as well. Requires `bf16_master_weights_and_grads=true`. Offload is optional: without `offload_optimizer` the bf16 states stay on the GPU; with `offload_optimizer` (`DeepSpeedCPUAdam`) they are offloaded to CPU memory in bf16. The offloaded state (bf16 master weights plus the two bf16 Adam moments) is then ~6 bytes/param, versus ~10 bytes/param when the moments are kept in fp32. | `false` | + +**Support matrix (bf16 master weights/gradients)** + +| ZeRO stage | bf16_optimizer_states=False | bf16_optimizer_states=True | +| ---------- | --------------------------- | -------------------------- | +| 0 | Not supported | Not supported | +| 1/2/3 | Requires ZeRO-Offload + `DeepSpeedCPUAdam` (optimizer states stay fp32 on CPU) | On GPU without offload, or on CPU with `offload_optimizer` + `DeepSpeedCPUAdam`; optimizer states kept in bf16 either way | ### Automatic mixed precision (AMP) training options @@ -329,6 +403,29 @@ Example of **scheduler** | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | | Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None | +### PyTorch Automatic Mixed Precision (torch.autocast) training options + +**torch_autocast**: [dictionary] + +| Description | Default | +| --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Configuration for using PyTorch's native automatic mixed precision training via [torch.autocast](https://pytorch.org/docs/stable/amp.html). For detailed usage instructions, see the [Mixed Precision Training](https://deepspeed.readthedocs.io/en/latest/training.html#mixed-precision-training) documentation. | None | + +```json +"torch_autocast": { + "enabled": true, + "dtype": "bfloat16", + "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"] +} +``` + +| Parameter | Type | Default | Description | +| --------- | ---- | ------- | ----------- | +| **enabled** | boolean | `false` | Enable torch.autocast (no manual `torch.autocast` call needed in your code). | +| **dtype** | string | `"bfloat16"` | Lower precision dtype (`"bfloat16"` or `"float16"`). Also used for gradient/parameter communication of `lower_precision_safe_modules`. | +| **lower_precision_safe_modules** | list | `["torch.nn.Linear", "torch.nn.Conv1d", "torch.nn.Conv2d", "torch.nn.Conv3d"]` | Module types for lower-precision communication (all-reduce/all-gather). | + + ### Gradient Clipping **gradient_clipping**: [float] @@ -364,8 +461,12 @@ Enabling and configuring ZeRO memory optimizations "sub_group_size" : 1e12, "elastic_checkpoint" : [true|false], "stage3_gather_16bit_weights_on_model_save": [true|false], - "ignore_unused_parameters": [true|false] - "round_robin_gradients": [true|false] + "ignore_unused_parameters": [true|false], + "round_robin_gradients": [true|false], + "zero_hpz_partition_size": 1, + "zero_quantized_weights": [true|false], + "zero_quantized_gradients": [true|false], + "log_trace_cache_warnings": [true|false], } ``` @@ -417,6 +518,12 @@ Enabling and configuring ZeRO memory optimizations | ------------------------------------------------------------------------------------------------------------------- | ------- | | Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. | `True` | +**load_from_fp32_weights**: [boolean] + +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- | +| Initialize fp32 master weights from fp32 copies in checkpoint (no precision loss) or from model's fp16 copies (with precision loss). This can be used to initialize optimizer state even when checkpoint is missing optimizer state. | `True` | + **grad_hooks**: [boolean] | Description | Default | @@ -464,7 +571,7 @@ Enabling and configuring ZeRO memory optimizations | Description | Default | | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | -| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` | +| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e5` | ***stage3_gather_16bit_weights_on_model_save***: [boolean] @@ -473,6 +580,35 @@ Enabling and configuring ZeRO memory optimizations |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- | | Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` | +***stage3_module_granularity_threshold***: [integer] + +| Description | Default | +|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- | +| The granularity of a module is determined by the ratio of `parameter_count` / `(1 + descendant_count)`. ZeRO3 classifies modules with a granularity below the threshold as fine-grained, treating them as integral units during parameter fetching. This reduces host and communication overhead from separate hooks. | `0` | + +***zero_hpz_partition_size***: [integer] + +| Description | Default | +| ----------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Number of ranks in hiearchical partitioning ZeRO (hpZ) secondary tensor group of ZeRO++, default is 1 meaning no hpZ, ideal is number of ranks (gpus) per node. | `1` | + +***zero_quantized_weights***: [boolean] + +| Description | Default | +| ----------------------------------------------------------------------------------------------------------------------------------- | ------- | +|Boolean indicating whether to enable communication efficient quantized weights of ZeRO++. | `False` | + +***zero_quantized_gradients***: [boolean] + +| Description | Default | +| ----------------------------------------------------------------------------------------------------------------------------------- | ------- | +|Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` | + +**log_trace_cache_warnings**: [boolean] + +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------- | ------- | +| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. | `False` | ***cpu_offload***: [boolean] @@ -543,6 +679,7 @@ Note that if the value of "device" is not specified or not supported, an asserti "device": "[cpu|nvme]", "nvme_path": "/local_nvme", "pin_memory": [true|false], + "ratio": 0.3, "buffer_count": 4, "fast_init": false } @@ -565,6 +702,12 @@ Note that if the value of "device" is not specified or not supported, an asserti | ---------------------------------------------------------------------------------------------------- | ------- | | Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. | `false` | +***ratio***: [float] + +| Description | Default | +| ------------------------------------------------------------------- | ------- | +| the ratio of parameters updating (i.e. optimizer step) on CPU side. | 1 | + ***buffer_count***: [integer] | Description | Default | @@ -619,11 +762,103 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s | -------------------------------------------------------------------------------------------------------------- | ------- | | Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. | `true` | +### Tensor Parallel (AutoTP) +Configure AutoTP tensor parallelism for training via the DeepSpeed config and hybrid TP + ZeRO. AutoTP supports ZeRO stages 0, 1, and 2 (stage 3 is not supported). `deepspeed.tp_model_init()` remains supported for backward compatibility but is not required when `tensor_parallel` is set in the config. + +When a HuggingFace model provides a built-in `tp_plan` (via `model.config.base_model_tp_plan`), DeepSpeed automatically detects and uses it. In this case, neither `preset_model` nor `partition_config` is required -- just set `autotp_size`. If `partition_config` is also provided, it takes precedence over the model's `tp_plan`. +```json + "tensor_parallel": { + "autotp_size": 4, + "preset_model": "llama", + "tp_overlap_comm": false, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + } + ] + } + } +``` +**tensor_parallel**: [dictionary] + +| Description | Default | +| ------------------------------------------------------------------------------------------ | ------- | +| Enable AutoTP tensor parallelism and configure preset or custom partitioning rules. | `{}` | + +***autotp_size***: [integer] + +| Description | Default | +| --------------------------------------------------------------------------- | ------- | +| Tensor-parallel degree. Set to `0` to disable AutoTP. | `0` | + +***preset_model***: [string] + +| Description | Default | +| ----------------------------------------------------------------------------------------------------- | ------- | +| Built-in model presets: `llama`, `bloom`, `chatglm`, `mixtral`, `deepseek_v2`, `qwen2`, `phi3`. | `null` | + +***tp_overlap_comm***: [boolean] + +| Description | Default | +| -------------------------------------------------------------------------------------------------------- | ------- | +| Overlap tensor-parallel allreduce communication with computation (training only). | `false` | + +***partition_config***: [dictionary] + +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Custom AutoTP layer partitioning rules. Use with or without `preset_model` to customize sharding patterns. | `null` | + +***use_default_specs***: [boolean] + +| Description | Default | +| -------------------------------------------------------------------------------------------------------------------- | ------- | +| Merge custom `layer_specs` with preset defaults when `preset_model` is set; otherwise use only custom specs. | `true` | + +***layer_specs***: [list] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Ordered list of pattern rules that define how to partition matching parameters. | `[]` | + +***patterns***: [list of strings] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Regex patterns to match parameter names for this partition rule. | `[]` | + +***partition_type***: [string] + +| Description | Default | +| ---------------------------------------------------------------------------- | ------- | +| Partition type for matching parameters: `row`, `column`, or `skip`. | `column` | + +***shape***: [list] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Optional sub-parameter shape for fused weights before TP partitioning (e.g., `[2, -1]`). | `null` | + +***partition_dim***: [integer] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Dimension to split when `shape` is provided (e.g., `0` for fused QKV or gate/up). | `null` | + +***model_types***: [list of strings] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------- | ------- | +| Optional model type filters (from `model.config.model_type`) for shared configs. | `null` | + ***ignore_unused_parameters***: [boolean] | Description | Default | | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | -| Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to `False` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. | `True` | +| Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to `True` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. | `True` | ### Logging @@ -1099,15 +1334,16 @@ DeepSpeed Data Efficiency Library includes two techniques: curriculum learning a | ---------------------------------------------------------------------------------------------------------------------------- | ------- | | List of which step to change difficulty level. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A | -### Monitoring Module (TensorBoard, WandB, CSV) +### Monitoring Module **Note:** Deepspeed logs to TensorBoard through PyTorch. Logging to TensorBoard requires that the `tensorboard` package is installed (read more in the [PyTorch documentation](https://pytorch.org/docs/1.8.0/tensorboard.html)). {: .notice--warning} **Note:** Logging to WandB requires that the `wandb` package is installed (read more in the [WandB documentation](https://docs.wandb.ai/quickstart)). {: .notice--warning} +**Note:** Logging to Comet requires that the `comet_ml` package is installed (read more in the [Comet documentation](https://www.comet.com/docs/v2/guides/quickstart/#1-install-and-configure-the-comet-ml-sdk)). +{: .notice--warning} - -Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), or to simple CSV files. Below is an overview of what DeepSpeed will log automatically. +Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), to [Comet](https://www.comet.com/site/?utm_source=deepseed&utm_medium=docs&utm_content=docs) or to simple CSV files. Below is an overview of what DeepSpeed will log automatically. | Field | Description |Conditions | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | @@ -1117,7 +1353,7 @@ Deepspeed's Monitor module can log training details into a [Tensorboard](https:/ | `Train/Eigenvalues/ModelBlockParam_{i}` | Eigen values per param block. | `eigenvalue` must be enabled. | | `Train/Samples/elapsed_time_ms_forward` | The global duration of the forward pass. | `flops_profiler.enabled` or `wall_clock_breakdown`. | | `Train/Samples/elapsed_time_ms_backward` | The global duration of the forward pass. | `flops_profiler.enabled` or `wall_clock_breakdown`. | -| `Train/Samples/elapsed_time_ms_backward_inner` | The backward time that does not include the the gradient reduction time. Only in cases where the gradient reduction is not overlapped, if it is overlapped then the inner time should be about the same as the entire backward time. | `flops_profiler.enabled` or `wall_clock_breakdown`. | +| `Train/Samples/elapsed_time_ms_backward_inner` | The backward time that does not include the gradient reduction time. Only in cases where the gradient reduction is not overlapped, if it is overlapped then the inner time should be about the same as the entire backward time. | `flops_profiler.enabled` or `wall_clock_breakdown`. | | `Train/Samples/elapsed_time_ms_backward_allreduce` | The global duration of the allreduce operation. | `flops_profiler.enabled` or `wall_clock_breakdown`. | | `Train/Samples/elapsed_time_ms_step` | The optimizer step time | `flops_profiler.enabled` or `wall_clock_breakdown`. | @@ -1161,6 +1397,36 @@ Example of **wandb** configuration: } ``` +**comet**: [dictionary] + +| Fields | Value | Default | +|--- |--- |--- | +| enabled | Whether logging to [Comet](https://www.comet.com/site/) is enabled. | `false` | +| workspace | Comet workspace name. | `None` | +| project | Comet project name. | `None` | +| samples_log_interval | Metrics will be submitted to Comet after processing every `samples_log_intervas` samples. | `100` | +| experiment_name | The name for comet experiment to be used for logging. | `None` | +| api_key | Comet API key. It's not recommended to save the Comet API Key in code. | `None` | +| experiment_key | The key for comet experiment to be used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters. | `None` | +| online | If True, the data will be logged to Comet server, otherwise it will be stored locally in offline experiment. Default is `True`. | `None` | +| mode | Control how the Comet experiment is started. "get": Continue logging to an existing experiment identified by the `experiment_key` value. "create": Always creates of a new experiment, useful for HPO sweeps. "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one. | `None` | + + +Example of **comet** configuration: + +```json +"comet": { + "enabled": true, + "workspace": "my_workspace", + "project": "my_project", + "samples_log_interval": 50, + "experiment_name": "llama-fine-tuning", + "experiment_key": "0c4a1c4a90664f2a8084e600b19a9d7", + "online": false, + "mode": "get", +} +``` + **csv_monitor**: [dictionary] | Fields | Value |Default | @@ -1435,6 +1701,25 @@ Different quantization sets, this is used for different quantization parameters. } ``` +```json +"compression_training": { + "sparse_pruning":{ + "shared_parameters":{ + "enabled": true, + "schedule_offset": 30, + "schedule_offset_end": 90, + "schedule_offset_stride": 15, + "method": "snip_momentum", + "block_pattern": "4x1", + "dense_ratio": 0.4, + "excluded_modules": ['classifier', 'pooler'] + }, + "different_groups":{ + } + } +} +``` + **shared_parameters**: [dictionary] Shared parameters for all sparse pruning groups. @@ -1443,11 +1728,17 @@ Shared parameters for all sparse pruning groups. | ----- | ----- | ----- | | **enabled**: [boolean] | Enable sparse pruning or not. | `false` | | **schedule_offset**: [integer] | Enable sparse pruning after scheduled steps (can be treated as warmup steps). | `0` | -| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` | +| **schedule_offset_end**: [integer] | Disable sparse pruning after scheduled steps, mandotory for `snip_momentum`. | `0` | +| **schedule_offset_stride**: [integer] | The stride of pruning on training steps, mandotory for `snip_momentum`. | `"1"` | +| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based), topk (dynamic, learnable) or snip_momentum (structured pruning). | `"l1"` | +| **block_pattern**: [string] | Choose different structured pruning block patterns, NxM or N:M (N and M are integers). For instance, "4x1" or "2:4" are common block patterns, mandotory for `snip_momentum`. | `"4x1"` | +| **dense_ratio**: [float] | Used to get the targeted global sparsity ratio, mandotory for `snip_momentum`. | `"0.1"` | +| **excluded_modules**: [list] | Excluded pruning scope on some special modules like output layer. | `[]` | **different_groups**: [dictionary] Different pruning sets, this is used for different pruning parameters. In this example, we give one set. In practice, you can choose the number of sets based on your requirements. +Note for `snip_momentum` method, you can leave it as empty. | Fields | Value | Default | | ----- | ----- | ----- | @@ -1639,6 +1930,26 @@ Different pruning sets, this is used for different pruning parameters. In this e | ------------------------------------------------------------- | ------- | | Use pipeline stages to parallelize the writing of checkpoints.| `false` | +### AutoSP options + +DeepSpeed provides compiler-based optimization passes through the `compile` configuration. This includes enabling Ulysses-styled sequence paralllelism and a custom heuristic selective activation checkpointing pass. To enable Automatic Sequence Parallelism (AutoSP), configure the `compile` section: + +```json +{ + "zero_optimization": {"stage": 0}, + "compile": { + "deepcompile": true, + "passes": ["autosp"], + } +} +``` + +**passes**: [array of strings] + +| Description | Default | +| ------------------------------------------------------------------------ | ------- | +| List of compiler passes to apply. Currently supported: `["autosp"]`. | `[]` | + ### Data Type options ```json diff --git a/docs/_pages/deepspeed4science.md b/docs/_pages/deepspeed4science.md new file mode 100755 index 000000000000..b1aa706ad180 --- /dev/null +++ b/docs/_pages/deepspeed4science.md @@ -0,0 +1,50 @@ +--- +title: "DeepSpeed4Science Overview and Tutorial" +permalink: /deepspeed4science/ +toc: true +toc_label: "Contents" +toc_sticky: true +--- + +In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. This page serves as an overview page for all technologies released (or to be released in the future) as part of the DeepSpeed4Science initiative, making it easier for scientists to shop for techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we will introduce what is it for, when to use it, links to how to use it, and existing scientific applications of the techniques (we welcome users to contribute more showcases if you apply our techniques in your scientific research). + +To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610): + +``` +@article{song2023deepspeed4science, + title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies}, + author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others}, + journal={arXiv preprint arXiv:2310.04610}, + year={2023} +} +``` + +* [2023/09] We are releasing two techniques: [DeepSpeed4Science large-scale training framework](#new-megatron-deepspeed-for-large-scale-ai4science-model-training), [DS4Sci_EvoformerAttention](#memory-efficient-evoformerattention-kernels) and their scientific applications in structural biology research. + + +## New Megatron-DeepSpeed for Large-Scale AI4Science Model Training + +We are proud to introduce [new Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed), which is an updated framework for large-scale model training. We rebased and enabled DeepSpeed with the newest Megatron-LM for long sequence support and many other capabilities. With the new Megatron-DeepSpeed, users can now train their large AI4Science models like GenSLMs with much longer sequences via a synergetic combination of ZeRO-style data parallelism, tensor parallelism, sequence parallelism, pipeline parallelism, model state offloading, and several newly added memory optimization techniques such as attention mask offloading and position embedding partitioning. + +![new Megatron-DeepSpeed](/assets/images/new-megatron-ds.png){: .align-center} +

+The figure depicts system capability in terms of enabling long sequence lengths for training a 33B parameter GPT-like model using our new Megatron-DeepSpeed framework. The results show that the new Megatron-DeepSpeed enables 9x longer sequence lengths than NVIDIA's Megatron-LM without triggering out-of-memory error. +

+ +To see how the new Megatron-DeepSpeed helps enabling new system capabilities, such as training models with massive sequences length, please read our [tutorial](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support). + +Meanwhile, our new Megatron-DeepSpeed has been applied to genome-scale foundation model [GenSLMs](https://github.com/ramanathanlab/genslm), which is a 2022 [ACM Gordon Bell award](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022) winning genome-scale language model from Argonne National Lab. To achieve their scientific goal, GenSLMs and similar models require very long sequence support for both training and inference that is beyond generic LLM's long-sequence strategies. By leveraging DeepSpeed4Science's new Megatron-DeepSpeed, GenSLMs team is able to train their 25B model with 512K sequence length, much longer than their original 42K sequence length. Detailed information about the methodology can be found at [our website](https://deepspeed4science.ai/2023/09/18/model-showcase-genslms/). GenSLMs team also hosts an [example](https://github.com/ramanathanlab/genslm/tree/main/examples/long-sequences) about how to use DeepSpeed4Science in the GenSLMs repo. + + +## Memory-Efficient EvoformerAttention Kernels + +[Evoformer](https://www.nature.com/articles/s41586-021-03819-2) is a key building block for scientific models such as DeepMind's AlphaFold. However, EvoFormer's multiple sequence alignment (MSA) attention frequently runs into memory explosion problems during training/inference, such as in protein structure prediction models. Existing techniques such as FlashAttention cannot effectively support Evoformer because EvoFormerAttention uses row-wise/column-wise/triangle attention, which are different from standard Transformer self-attention and cross-attention that require custom optimizations. To mitigate the memory explosion problem, we introduce `DS4Sci_EvoformerAttention` kernels, a collection of kernels that improve the memory efficiency of variants of EvoFormer. `DS4Sci_EvoformerAttention` is easy-to-use. To see how you can use it, please refer to our [tutorial](/tutorials/ds4sci_evoformerattention/). + +`DS4Sci_EvoformerAttention` has already been applied to [OpenFold](https://github.com/aqlaboratory/openfold), which is a community reproduction of DeepMind's AlphaFold2 that makes it possible to train or finetune AlphaFold2 on new datasets. With DS4Sci_EvoformerAttention kernels, OpenFold team is able to reduce the peak memory requirement by 13x without accuracy loss. Detailed information about the methodology can be found at [our website](https://deepspeed4science.ai/2023/09/18/model-showcase-openfold/). + + + +![DS4Sci_EvoformerAttention](/assets/images/evoformer.png){: .align-center} +

+The figure shows that DeepSpeed's EvoFormerAttention kernels help reduce OpenFold’s peak memory requirement for training by 13X. +

diff --git a/docs/_pages/inference.md b/docs/_pages/inference.md index d63604e1f022..fb3534872439 100755 --- a/docs/_pages/inference.md +++ b/docs/_pages/inference.md @@ -6,8 +6,10 @@ toc: true toc_label: "Contents" --- +>**DeepSpeed-Inference v2 is here and it's called DeepSpeed-FastGen! For the best performance, latest features, and newest model support please see our [DeepSpeed-FastGen release blog](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen)!** + DeepSpeed-Inference introduces several features to efficiently serve transformer-based PyTorch models. It supports model parallelism (MP) to fit large models that would otherwise not fit in GPU memory. Even for smaller models, MP can be used to reduce latency for inference. To further reduce latency and cost, we introduce inference-customized kernels. Finally, we propose a novel approach to quantize models, called MoQ, to both shrink the model and reduce the inference cost at production. For more details on the inference related optimizations in DeepSpeed, please refer to our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/). -DeepSpeed provides a seamless inference mode for compatible transformer based models trained using DeepSpeed, Megatron, and HuggingFace, meaning that we don’t require any change on the modeling side such as exporting the model or creating a different checkpoint from your trained checkpoints. To run inference on multi-GPU for compatible models, provide the model parallelism degree and the checkpoint information or the model which is already loaded from a checkpoint, and DeepSpeed will do the rest. It will automatically partition the model as necessary, inject compatible high performance kernels into your model and manage the inter-gpu communication. For list of compatible models please see [here](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py). +DeepSpeed provides a seamless inference mode for compatible transformer based models trained using DeepSpeed, Megatron, and HuggingFace, meaning that we don’t require any change on the modeling side such as exporting the model or creating a different checkpoint from your trained checkpoints. To run inference on multi-GPU for compatible models, provide the model parallelism degree and the checkpoint information or the model which is already loaded from a checkpoint, and DeepSpeed will do the rest. It will automatically partition the model as necessary, inject compatible high performance kernels into your model and manage the inter-gpu communication. For list of compatible models please see [here](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py). To get started with DeepSpeed-Inference, please checkout our [tutorial](https://www.deepspeed.ai/tutorials/inference-tutorial/). diff --git a/docs/_pages/training.md b/docs/_pages/training.md index e5eee86564d3..e31651cc487a 100644 --- a/docs/_pages/training.md +++ b/docs/_pages/training.md @@ -44,7 +44,7 @@ optimizations on advanced hyperparameter tuning and optimizers. For example: | 64 V100 GPUs | DeepSpeed | **8.68** hr| | 16 V100 GPUs | DeepSpeed | **33.22** hr| - *BERT codes and tutorials will be available soon.* + *BERT code and tutorials will be available soon.* * DeepSpeed trains GPT2 (1.5 billion parameters) 3.75x faster than state-of-art, NVIDIA Megatron on Azure GPUs. @@ -201,6 +201,7 @@ Enable 16-bit (FP16) training by in the `deepspeed_config` JSON. "loss_scale": 0, "loss_scale_window": 1000, "hysteresis": 2, + "consecutive_hysteresis": false, "min_loss_scale": 1 } ``` @@ -484,7 +485,7 @@ The flops profiler can also be used as a standalone package. Please refer to the ### Autotuning -The DeepSpeed Autotuner uses model information, system information, and heuristics to efficiently tune Zero stage, micro batch size, and other Zero configurations. Using the autotuning feature requires no code change from DeepSpeed users. While `"autotuning": {"enabled": true}` is the minimal required to enable auotuning, there are other parameters users can define to configure the autotuning process. Below shows major parameters and their default values in the autotuning configuration. Please refer to the [Autotuning](/tutorials/autotuning) tutorial for more details. +The DeepSpeed Autotuner uses model information, system information, and heuristics to efficiently tune Zero stage, micro batch size, and other Zero configurations. Using the autotuning feature requires no code change from DeepSpeed users. While `"autotuning": {"enabled": true}` is the minimal required to enable autotuning, there are other parameters users can define to configure the autotuning process. Below shows major parameters and their default values in the autotuning configuration. Please refer to the [Autotuning](/tutorials/autotuning) tutorial for more details. ```json { diff --git a/docs/_posts/2020-02-13-release.md b/docs/_posts/2020-02-13-release.md index 792ff7bfee67..a97a4ba9ccf1 100644 --- a/docs/_posts/2020-02-13-release.md +++ b/docs/_posts/2020-02-13-release.md @@ -3,5 +3,5 @@ title: "ZeRO & DeepSpeed: New system optimizations enable training models with o date: 2020-02-13 link: https://www.microsoft.com/en-us/research/blog/ZeRO-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/ excerpt: "" -tags: training ZeRO +tags: training ZeRO English --- diff --git a/docs/_posts/2020-02-13-turing-nlg.md b/docs/_posts/2020-02-13-turing-nlg.md index 0da59aa8fee3..240f6d78ad02 100644 --- a/docs/_posts/2020-02-13-turing-nlg.md +++ b/docs/_posts/2020-02-13-turing-nlg.md @@ -3,5 +3,5 @@ title: "Turing-NLG: A 17-billion-parameter language model by Microsoft" date: 2020-02-13 link: https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/ excerpt: "DeepSpeed was used to train the world's largest language model." -tags: training +tags: training English --- diff --git a/docs/_posts/2020-03-17-reduce-scatter.md b/docs/_posts/2020-03-17-reduce-scatter.md index 1753a22e3aa7..329409dfefab 100644 --- a/docs/_posts/2020-03-17-reduce-scatter.md +++ b/docs/_posts/2020-03-17-reduce-scatter.md @@ -1,6 +1,7 @@ --- title: "ZeRO stage 1 with reduced communication" sneak_preview: true +tags: training ZeRO English excerpt: "Partition-aware ZeRO with up to 2x reduction in communication time!" --- diff --git a/docs/_posts/2020-05-19-bert-record.md b/docs/_posts/2020-05-19-bert-record.md index 4c2a93e5be86..67d00280348e 100644 --- a/docs/_posts/2020-05-19-bert-record.md +++ b/docs/_posts/2020-05-19-bert-record.md @@ -1,10 +1,9 @@ --- title: "The Fastest and Most Efficient BERT Training through Optimized Transformer Kernels" excerpt: "" -tags: training date: 2020-05-19 00:00:00 toc: false -tags: training +tags: training English --- We introduce new technology to accelerate single GPU performance via kernel @@ -20,4 +19,4 @@ the same number and generation of GPUs. * Brief overview, see our [press release](https://www.microsoft.com/en-us/research/blog/ZeRO-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/). * Detailed technology deep dive, see our [blog post](https://www.deepspeed.ai/2020/05/27/fastest-bert-training.html). * Tutorial on how to reproduce our results, see our [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/). -* The source code for our transformer kernels can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed) and BERT pre-training code can be found in the [DeepSpeedExamples repo](https://github.com/microsoft/deepspeedexamples). +* The source code for our transformer kernels can be found in the [DeepSpeed repo](https://github.com/deepspeedai/deepspeed) and BERT pre-training code can be found in the [DeepSpeedExamples repo](https://github.com/deepspeedai/deepspeedexamples). diff --git a/docs/_posts/2020-05-19-press-release.md b/docs/_posts/2020-05-19-press-release.md index 9022a7db40c5..a6611b11cb59 100644 --- a/docs/_posts/2020-05-19-press-release.md +++ b/docs/_posts/2020-05-19-press-release.md @@ -2,6 +2,6 @@ title: "ZeRO-2 & DeepSpeed: Shattering Barriers of Deep Learning Speed & Scale" excerpt: "" link: https://www.microsoft.com/en-us/research/blog/ZeRO-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/ -tags: training ZeRO +tags: training ZeRO English date: 2020-05-19 02:00:00 --- diff --git a/docs/_posts/2020-05-19-zero-stage2.md b/docs/_posts/2020-05-19-zero-stage2.md index 4f35012d9aae..44f6cc194bc2 100644 --- a/docs/_posts/2020-05-19-zero-stage2.md +++ b/docs/_posts/2020-05-19-zero-stage2.md @@ -1,7 +1,7 @@ --- title: "An Order-of-Magnitude Larger and Faster Training with ZeRO-2" excerpt: "" -tags: training ZeRO +tags: training ZeRO English date: 2020-05-19 01:00:00 toc: false --- diff --git a/docs/_posts/2020-05-28-fastest-bert-training.md b/docs/_posts/2020-05-28-fastest-bert-training.md index 99d132c1e53d..2154c36fe279 100644 --- a/docs/_posts/2020-05-28-fastest-bert-training.md +++ b/docs/_posts/2020-05-28-fastest-bert-training.md @@ -1,7 +1,7 @@ --- title: "Microsoft DeepSpeed achieves the fastest BERT training time" excerpt: "" -tags: training +tags: training English date: 2020-05-28 00:00:00 --- @@ -284,7 +284,7 @@ and faster convergence. To try out these optimizations and training recipe, please check out our [BERT training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/) and source code at the [DeepSpeed GitHub -repo](https://github.com/microsoft/deepspeed). +repo](https://github.com/deepspeedai/deepspeed). ### References diff --git a/docs/_posts/2020-07-24-deepspeed-webinar.md b/docs/_posts/2020-07-24-deepspeed-webinar.md index be4ee777ed61..a5b4aa15bef5 100644 --- a/docs/_posts/2020-07-24-deepspeed-webinar.md +++ b/docs/_posts/2020-07-24-deepspeed-webinar.md @@ -1,7 +1,7 @@ --- title: "DeepSpeed Microsoft Research Webinar on August 6th, 2020" excerpt: "" -tags: presentations +tags: presentations English link: https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html image: /assets/images/webinar-aug2020.png date: 2020-07-24 00:00:00 diff --git a/docs/_posts/2020-08-07-webinar-on-demand.md b/docs/_posts/2020-08-07-webinar-on-demand.md index 983e17eca36b..8b258e88a9b2 100644 --- a/docs/_posts/2020-08-07-webinar-on-demand.md +++ b/docs/_posts/2020-08-07-webinar-on-demand.md @@ -1,7 +1,7 @@ --- title: "DeepSpeed Microsoft Research Webinar is now on-demand" excerpt: "" -tags: presentations +tags: presentations English link: https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html date: 2020-08-07 00:00:00 --- diff --git a/docs/_posts/2020-09-08-sparse-attention-news.md b/docs/_posts/2020-09-08-sparse-attention-news.md index c5f2a104374d..b9a0aeb88d9b 100644 --- a/docs/_posts/2020-09-08-sparse-attention-news.md +++ b/docs/_posts/2020-09-08-sparse-attention-news.md @@ -1,7 +1,7 @@ --- title: "Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention" excerpt: "" -tags: training +tags: training English date: 2020-09-09 00:00:00 toc: false --- @@ -11,4 +11,4 @@ DeepSpeed offers sparse attention kernels, an instrumental technology to support * Brief overview, see our [press release]({{ site.press_release_v3 }}). * Detailed technology deep dive, see our [blog post](https://www.deepspeed.ai/2020/09/08/sparse-attention.html). * Tutorial on how to use sparse attention, see our [Sparse attention tutorial](https://www.deepspeed.ai/tutorials/sparse-attention/). -* The source code for our sparse attention kernels can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed) and BERT pre-training code using sparse attention can be found in the [DeepSpeedExamples repo](https://github.com/microsoft/deepspeedexamples). +* The source code for our sparse attention kernels can be found in the [DeepSpeed repo](https://github.com/deepspeedai/deepspeed) and BERT pre-training code using sparse attention can be found in the [DeepSpeedExamples repo](https://github.com/deepspeedai/deepspeedexamples). diff --git a/docs/_posts/2020-09-09-ZeRO-Offload.md b/docs/_posts/2020-09-09-ZeRO-Offload.md index c270ceadf381..e0626f791a4e 100755 --- a/docs/_posts/2020-09-09-ZeRO-Offload.md +++ b/docs/_posts/2020-09-09-ZeRO-Offload.md @@ -2,7 +2,7 @@ title: "10x bigger model training on a single GPU with ZeRO-Offload" excerpt: "" date: 2020-09-09 00:00:00 -tags: training ZeRO +tags: training ZeRO English toc: false --- @@ -10,4 +10,4 @@ We introduce a new technology called ZeRO-Offload to enable **10X bigger model t * For more information on ZeRO-Offload, see our [press release]( {{ site.press_release_v3 }} ). * For more information on how to use ZeRO-Offload, see our [ZeRO-Offload tutorial](https://www.deepspeed.ai/tutorials/ZeRO-offload/). -* The source code for ZeRO-Offload can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed). +* The source code for ZeRO-Offload can be found in the [DeepSpeed repo](https://github.com/deepspeedai/deepspeed). diff --git a/docs/_posts/2020-09-09-onebit-adam-blog-post.md b/docs/_posts/2020-09-09-onebit-adam-blog-post.md index 413a3d0c1afb..8152190f24d0 100644 --- a/docs/_posts/2020-09-09-onebit-adam-blog-post.md +++ b/docs/_posts/2020-09-09-onebit-adam-blog-post.md @@ -2,7 +2,7 @@ title: "DeepSpeed with 1-bit Adam: 5x less communication and 3.4x faster training" excerpt: "" date: 2020-09-09 00:00:00 -tags: training +tags: training English --- ## 1. Introduction diff --git a/docs/_posts/2020-09-09-onebit-adam-news.md b/docs/_posts/2020-09-09-onebit-adam-news.md index 4ec3c3c85ba8..1fd8ef89edce 100644 --- a/docs/_posts/2020-09-09-onebit-adam-news.md +++ b/docs/_posts/2020-09-09-onebit-adam-news.md @@ -2,7 +2,7 @@ title: "Up to 5x less communication and 3.4x faster training through 1-bit Adam" excerpt: "" date: 2020-09-09 00:00:00 -tags: training +tags: training English toc: false --- @@ -17,4 +17,4 @@ its efficient implementation in DeepSpeed. 1-bit Adam offers the ***same converg * Brief overview, see our [press release]({{ site.press_release_v3 }}). * Detailed technology deep dive, see our [blog post](https://www.deepspeed.ai/2020/09/08/onebit-adam-blog-post.html). * Tutorial on how to reproduce our results, see our [1-bit Adam tutorial](/tutorials/onebit-adam/). -* The source code for 1-bit Adam can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed). The implementation of 1-bit Adam is in [onebit_adam.py](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/fp16/onebit_adam.py) and CUDA-Aware communication for 1-bit Adam is in [custom_collectives.py](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/custom_collectives.py). Example codes to try this feature can be found in the [DeepSpeedExamples repo](https://github.com/microsoft/deepspeedexamples) as shown in the [tutorial](/tutorials/onebit-adam/). +* The source code for 1-bit Adam can be found in the [DeepSpeed repo](https://github.com/deepspeedai/deepspeed). The implementation of 1-bit Adam is in [onebit_adam.py](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/fp16/onebit_adam.py) and CUDA-Aware communication for 1-bit Adam is in [custom_collectives.py](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/custom_collectives.py). Example codes to try this feature can be found in the [DeepSpeedExamples repo](https://github.com/deepspeedai/deepspeedexamples) as shown in the [tutorial](/tutorials/onebit-adam/). diff --git a/docs/_posts/2020-09-09-pipeline-parallelism.md b/docs/_posts/2020-09-09-pipeline-parallelism.md index 4f2e53ed80ee..fe708bc4d50d 100644 --- a/docs/_posts/2020-09-09-pipeline-parallelism.md +++ b/docs/_posts/2020-09-09-pipeline-parallelism.md @@ -2,7 +2,7 @@ title: "Training a Trillion Parameters with Pipeline Parallelism" excerpt: "" date: 2020-09-09 00:00:00 -tags: training +tags: training English --- DeepSpeed includes new support for pipeline parallelism! DeepSpeed's training @@ -14,5 +14,5 @@ low-bandwidth network by up to 7x. * For a brief overview and results including trillion-parameter capabilities, see our [press release]({{ site.press_release_v3 }}). * To get started with pipeline parallel training in DeepSpeed, we recommend our [tutorial](/tutorials/pipeline/). -* See our AlexNet example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples). +* See our AlexNet example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples). * Read our API documentation on [readthedocs](https://deepspeed.readthedocs.io/en/latest/pipeline.html). diff --git a/docs/_posts/2020-09-09-sparse-attention.md b/docs/_posts/2020-09-09-sparse-attention.md index aa0fa0bb60d4..1ab827d6fc8e 100644 --- a/docs/_posts/2020-09-09-sparse-attention.md +++ b/docs/_posts/2020-09-09-sparse-attention.md @@ -2,10 +2,10 @@ title: "DeepSpeed Sparse Attention" excerpt: "" date: 2020-09-09 01:00:00 -tags: training inference +tags: training inference English --- -Attention-based deep learning models such as the transformers are highly effective in capturing relationship between tokens in an input sequence, even across long distances. As a result, they are used with text, image, and sound-based inputs, where the sequence length can be in thousands of tokens. However, despite the effectiveness of attention modules to capture long term dependencies, in practice, their application to long sequence input is limited by compute and memory requirements of the attention computation that grow quadratically, `O(n^2)`, with the sequence length `n`. +Attention-based deep learning models such as the transformers are highly effective in capturing the relationship between tokens in an input sequence, even across long distances. As a result, they are used with text, image, and sound-based inputs, where the sequence length can be in thousands of tokens. However, despite the effectiveness of attention modules to capture long term dependencies, in practice, their application to long sequence input is limited by compute and memory requirements of the attention computation that grow quadratically, `O(n^2)`, with the sequence length `n`. To address this limitation, DeepSpeed offers a suite of sparse attention kernels --an instrumental technology that can reduce the compute and memory requirement of attention computation by orders-of-magnitude via block-sparse computation. The suite not only alleviates the memory bottleneck of attention calculation, but also performs sparse computation efficiently. Its APIs allow convenient integration with any transformer-based models. Along with providing a wide spectrum of sparsity structures, it has the flexibility of handling any user-defined block-sparse structures. More specifically, sparse attention (SA) can be designed to compute local attention between nearby tokens, or global attention via summary tokens computed with local attention. Moreover, SA can also allow random attention, or any combination of local, global, and random attention as shown in the following figure with blue, orange, and green blocks, respectively. As a result, SA decreases the memory footprint to `O(wn)`, in which `1 < w < n` is a parameter, whose value depends on the attention structure. @@ -27,8 +27,8 @@ In a pre-training experiment, we ran BERT model under three settings: dense, den ![Maximum sequence runnable on BERT](/assets/images/sa_maximum_sequence_runnable_on_bert.png){: .align-center} -* **up to 6.3x faster computation** -We continued the pre-training experiment for different batch sizes and sequence lengths, using [BERT base/large](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert) and [Megatron GPT2](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). In this experiment we let the training to continue for 100 iteration and recorded the average time per last 30 iterations. SA reduces total computation comparing with dense and improves training speed: the boost is higher with increased sequence length and it is up to 6.3x faster for BERT base, 5.3x for BERT large, and 6.1x for GPT2. Following charts show these results. +* **Up to 6.3x faster computation** +We continued the pre-training experiment for different batch sizes and sequence lengths, using [BERT base/large](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/bing_bert) and [Megatron GPT2](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/Megatron-LM). In this experiment we let the training to continue for 100 iteration and recorded the average time per last 30 iterations. SA reduces total computation comparing with dense and improves training speed: the boost is higher with increased sequence length and it is up to 6.3x faster for BERT base, 5.3x for BERT large, and 6.1x for GPT2. Following charts show these results. ![Training time for BERT base with varying sequence length](/assets/images/sa_bert_base_time_result.png){: .align-center} @@ -36,14 +36,14 @@ We continued the pre-training experiment for different batch sizes and sequence ![Training time for GPT2 with varying sequence length](/assets/images/sa_gpt2_time_result.png){: .align-center} -* **higher accuracy** +* **Higher accuracy** Related works along the line of sparse attention ([Sparse Transformer](https://arxiv.org/pdf/1904.10509.pdf), [Longformer](https://arxiv.org/pdf/2004.05150.pdf), [BigBird](https://arxiv.org/pdf/2007.14062.pdf)) have shown comparable or higher accuracy than full attention. Our experience is well aligned. In addition to lower memory overhead and faster computation, we also observe cases in production where SA reaches higher accuracy and faster convergence. The following chart illustrates accuracy of training a production model based on BERT for long document comprehension (2,048 sequence length). The experiment is performed in three settings: dense starting from scratch, SA starting from scratch, and SA continued training from a checkpoint of using dense with sequence length of 512. We have observed that, for pre-training from scratch, SA converges faster with higher accuracy comparing with dense. Furthermore, SA continuing training from a pre-trained checkpoint performs even better, with respect to both time and accuracy. ![Accuracy of long document comprehension application](/assets/images/sa_long_document_comprehension_result.png){: .align-center} -* **comparison with state of the art, Longformer** +* **Comparison with state of the art, Longformer** We compared SA with Longformer, a state-of-the-art sparse structure and implementation. In our experiment, SA uses `Fixed` sparsity, and two implementations have comparable accuracy. On system performance, SA outperforms Longformer both in training and inference: * **1.47x** faster execution pre-training MLM on Wikitext103 We ran an experiment following the [notebook](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) offered by Longformer. In this experiment, we pre-train an MLM model using RoBERTa-base checkpoint. This is done on 8 V100-SXM2 GPU. Following table shows the details of the result in which using DeepSpeed Sparse Attention shows 1.47x speed up. @@ -73,7 +73,7 @@ Through our Long Document Comprehension application we described above, we also |32 |1.24 | |16 |1.23 | -* **flexibility to handle any block-sparse structure** +* **Flexibility to handle any block-sparse structure** DeepSpeed Sparse Attention suite does not target at any specific sparse structure but enables model scientists to explore any block sparse structure with efficient system support. Currently, we have added popular sparse structure like: * [Fixed](https://arxiv.org/pdf/1904.10509.pdf) (from OpenAI Sparse Transformer) * [BigBird](https://arxiv.org/pdf/2007.14062.pdf) (from Google) diff --git a/docs/_posts/2020-10-28-progressive-layer-dropping-news.md b/docs/_posts/2020-10-28-progressive-layer-dropping-news.md index 9664e4de94e7..da07edd7b922 100755 --- a/docs/_posts/2020-10-28-progressive-layer-dropping-news.md +++ b/docs/_posts/2020-10-28-progressive-layer-dropping-news.md @@ -2,7 +2,7 @@ title: "Progressive Layer Dropping" excerpt: "" date: 2020-10-29 00:00:00 -tags: training +tags: training English toc: false --- @@ -10,4 +10,4 @@ We introduce a new technology called progressive layer dropping (PLD) to speedup * For detailed technology deep dive, see our [technical report](https://arxiv.org/pdf/2010.13369.pdf). * For more information on how to use PLD, see our [Progressive layer dropping tutorial](https://www.deepspeed.ai/tutorials/progressive_layer_dropping/). - * The source code for PLD is now available at the [DeepSpeed repo](https://github.com/microsoft/deepspeed). + * The source code for PLD is now available at the [DeepSpeed repo](https://github.com/deepspeedai/deepspeed). diff --git a/docs/_posts/2021-03-08-zero3-offload.md b/docs/_posts/2021-03-08-zero3-offload.md index 9008ebc9f6fa..2bca2bdd826a 100644 --- a/docs/_posts/2021-03-08-zero3-offload.md +++ b/docs/_posts/2021-03-08-zero3-offload.md @@ -2,7 +2,7 @@ title: "DeepSpeed ZeRO-3 Offload" excerpt: "" date: 2021-03-08 00:00:00 -tags: training ZeRO +tags: training ZeRO English --- Today we are announcing the release of ZeRO-3 Offload, a highly efficient and easy to use implementation of ZeRO Stage 3 and ZeRO Offload combined, geared towards our continued goal of democratizing AI by making efficient large-scale DL training available to everyone. The key benefits of ZeRO-3 Offload are: diff --git a/docs/_posts/2021-05-05-MoQ.md b/docs/_posts/2021-05-05-MoQ.md index e6f7872a4007..5dd5006e886f 100644 --- a/docs/_posts/2021-05-05-MoQ.md +++ b/docs/_posts/2021-05-05-MoQ.md @@ -2,7 +2,7 @@ title: "Mixture-of-Quantization: A novel quantization approach for reducing model size with minimal accuracy impact" excerpt: "" date: 2021-05-05 00:00:00 -tags: inference +tags: inference English --- ## A unified suite for quantization-aware training and inference diff --git a/docs/_posts/2021-05-05-inference-kernel-optimization.md b/docs/_posts/2021-05-05-inference-kernel-optimization.md index 63e3ac669e22..991295de9759 100644 --- a/docs/_posts/2021-05-05-inference-kernel-optimization.md +++ b/docs/_posts/2021-05-05-inference-kernel-optimization.md @@ -2,7 +2,7 @@ title: "DeepSpeed Inference: Multi-GPU inference with customized inference kernels and quantization support" excerpt: "" date: 2021-03-16 00:00:00 -tags: inference +tags: inference English --- While DeepSpeed supports training advanced large-scale models, using these trained models in the desired application scenarios is still challenging due to three major limitations in existing inference solutions: 1) lack of support for multi-GPU inference to fit large models and meet latency requirements, 2) limited GPU kernel performance when running inference with small batch sizes, and 3) difficulties in exploiting quantization, which includes both quantizing the model to reduce the model size and latency as well as supporting high-performance inference of quantized models without specialized hardware. diff --git a/docs/_posts/2021-05-14-inference-release.md b/docs/_posts/2021-05-14-inference-release.md index fd5cca2e0259..14c300d0bc9f 100644 --- a/docs/_posts/2021-05-14-inference-release.md +++ b/docs/_posts/2021-05-14-inference-release.md @@ -3,5 +3,5 @@ title: "DeepSpeed: Accelerating large-scale model inference and training via sys date: 2021-05-14 link: https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/ excerpt: "" -tags: inference +tags: inference English --- diff --git a/docs/_posts/2021-08-18-deepspeed-moe.md b/docs/_posts/2021-08-18-deepspeed-moe.md index 5bd9667f2a7f..665c09751b55 100644 --- a/docs/_posts/2021-08-18-deepspeed-moe.md +++ b/docs/_posts/2021-08-18-deepspeed-moe.md @@ -3,5 +3,5 @@ title: "DeepSpeed powers 8x larger MoE model training with high performance" excerpt: "" link: https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/ date: 2021-08-18 00:00:00 -tags: training +tags: training English --- diff --git a/docs/_posts/2021-11-15-autotuning.md b/docs/_posts/2021-11-15-autotuning.md index ee48d44c5bdf..410e32c878a3 100644 --- a/docs/_posts/2021-11-15-autotuning.md +++ b/docs/_posts/2021-11-15-autotuning.md @@ -2,14 +2,14 @@ title: "Autotuning: Automatically discover the optimal DeepSpeed configuration that delivers good training speed" excerpt: "" date: 2021-11-16 10:00:00 -tags: training +tags: training English toc: false --- We introduce a new feature called Autotuning to automatically discover the optimal DeepSpeed configuration that delivers good training speed. One pain point in model training is to figure out good performance-relevant configurations such as micro-batch size to fully utilize the hardware and achieve a high throughput number. This configuration exploring process is commonly done manually but is important since model training is repeated many times and benefits from using a good configuration. Not only is the hand-tuning process time-consuming, but the outcome is hardware-dependent. This means that a good configuration on one hardware might not be the best on another different hardware. The user thus has to hand tune the configuration again. With DeepSpeed, there are more configuration parameters that could potentially affect the training speed, thus making it more tedious to manually tune the configuration. -The DeepSpeed Autotuner mitigates this pain point and automatically discovers the optimal DeepSpeed configuration that delivers good training speed. It not only reduces the time and resources users spend on tuning, but also can discover configurations better than hand-tuned methods. [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/autotuning) would demonstrate the effectiveness of autotuning across different models. +The DeepSpeed Autotuner mitigates this pain point and automatically discovers the optimal DeepSpeed configuration that delivers good training speed. It not only reduces the time and resources users spend on tuning, but also can discover configurations better than hand-tuned methods. [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/autotuning) would demonstrate the effectiveness of autotuning across different models. * For a brief overview, see the [Autotuning tutorial](https://www.deepspeed.ai/tutorials/autotuning/). -* For more information on how to use Autotuning, see the [Autotuning README](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning#deepspeed-autotuning). -* The source code can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed). +* For more information on how to use Autotuning, see the [Autotuning README](https://github.com/deepspeedai/DeepSpeed/tree/master/deepspeed/autotuning#deepspeed-autotuning). +* The source code can be found in the [DeepSpeed repo](https://github.com/deepspeedai/deepspeed). diff --git a/docs/_posts/2021-12-09-deepspeed-moe-nlg.md b/docs/_posts/2021-12-09-deepspeed-moe-nlg.md index 6402202cca3b..69fef131d3c0 100644 --- a/docs/_posts/2021-12-09-deepspeed-moe-nlg.md +++ b/docs/_posts/2021-12-09-deepspeed-moe-nlg.md @@ -2,7 +2,7 @@ title: "DeepSpeed-MoE for NLG: Reducing the training cost of language models by 5 times" excerpt: "" date: 2021-12-09 22:00:00 -tags: training +tags: training English --- Autoregressive transformer-based natural language generation (referred to as @@ -170,9 +170,9 @@ high quality language models accessible to a broad audience, even with limited compute resources. To this end we are releasing our [end-to-end pipeline for training MoE based -NLG models](https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training), +NLG models](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/moe-training), along with [specific example -scripts](https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training/examples/MoE) +scripts](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/moe-training/examples_deepspeed/MoE) and [tutorial](/tutorials/mixture-of-experts-nlg) to help get started with our pipeline. We look forward to the application and the innovations that this may bring to the deep learning community. diff --git a/docs/_posts/2022-01-19-moe-inference.md b/docs/_posts/2022-01-19-moe-inference.md index f2ac1c6de2e1..66ff5b51ad2d 100644 --- a/docs/_posts/2022-01-19-moe-inference.md +++ b/docs/_posts/2022-01-19-moe-inference.md @@ -3,5 +3,5 @@ title: "DeepSpeed: Advancing MoE inference and training to power next-generation excerpt: "" link: https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/ date: 2022-01-19 00:00:00 -tags: inference +tags: inference English --- diff --git a/docs/_posts/2022-03-21-amd-support.md b/docs/_posts/2022-03-21-amd-support.md index ba8917bc386a..01b2a52c7ca4 100644 --- a/docs/_posts/2022-03-21-amd-support.md +++ b/docs/_posts/2022-03-21-amd-support.md @@ -3,5 +3,5 @@ title: "Supporting efficient large model training on AMD Instinct GPUs with Deep excerpt: "" link: https://cloudblogs.microsoft.com/opensource/2022/03/21/supporting-efficient-large-model-training-on-amd-instinct-gpus-with-deepspeed/ date: 2022-03-21 00:00:00 -tags: training ZeRO +tags: training ZeRO English --- diff --git a/docs/_posts/2022-07-26-deepspeed-azure.md b/docs/_posts/2022-07-26-deepspeed-azure.md index 128cbf4a416e..540f74d4be1b 100644 --- a/docs/_posts/2022-07-26-deepspeed-azure.md +++ b/docs/_posts/2022-07-26-deepspeed-azure.md @@ -2,7 +2,7 @@ title: "Azure empowers easy-to-use, high-performance, and hyperscale model training using DeepSpeed" excerpt: "" date: 2022-07-26 00:09:00 -tags: training azure +tags: training azure English --- ## Introduction @@ -19,7 +19,7 @@ In this extended post, we share the details of how DeepSpeed users can train tri ## Making distributed training faster and easier on Azure using DeepSpeed -We compare the existing manual and error-prone workflow with our proposed easy-to-use workflow for DeepSpeed on Azure in *Figure 2*. Customers can now use easy-to-use [training pipelines](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples) to launch training jobs at scale. The new workflow reduces the number of steps from 11 to just 1 if users rely on the recommended [AzureML](https://azure.microsoft.com/en-us/services/machine-learning/) [recipes](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azureml). +We compare the existing manual and error-prone workflow with our proposed easy-to-use workflow for DeepSpeed on Azure in *Figure 2*. Customers can now use easy-to-use [training pipelines](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed) to launch training jobs at scale. The new workflow reduces the number of steps from 11 to just 1 if users rely on the recommended [AzureML](https://azure.microsoft.com/en-us/services/machine-learning/) [recipes](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azureml). ![Workflow](/assets/images/old-vs-new-azure.png){: .align-center} @@ -29,7 +29,7 @@ We compare the existing manual and error-prone workflow with our proposed easy-t For users who have custom environments built using Azure VMs or [Azure VMSS](https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/overview), only two steps are needed: - 1) Run the cluster setup script (to be released in the next few weeks) -- 2) Use the Azure VMSS [recipes](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azure) to launch training. +- 2) Use the Azure VMSS [recipes](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azure) to launch training. ## Key Performance Benefits We already shared a summary of our key performance results in the Azure [announcement](https://azure.microsoft.com/en-us/blog/azure-empowers-easytouse-highperformance-and-hyperscale-model-training-using-deepspeed/). We enable the capability to train 2x larger model sizes (2 trillion vs. 1 trillion parameters), scale to 2x more GPUs (1024 vs. 512), and offer up to 1.8x higher compute throughput/GPU (150 TFLOPs vs. 81 TFLOPs) compared to other [cloud providers](https://medium.com/pytorch/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff). @@ -48,7 +48,7 @@ We share the details of our experimental setup and some of the best practices we We used [NDm A100 v4-series](https://docs.microsoft.com/en-us/azure/virtual-machines/ndm-a100-v4-series) instances in our experiments. Each instance includes two socket AMD EPYC 7V12 64-Core CPUs, 1.7TB main memory and eight A100 80GB GPUs. The system has a balanced PCIe topology connecting 4 GPU devices to each CPU socket. Each GPU within the VM is provided with its own dedicated, topology-agnostic 200 Gb/s NVIDIA Mellanox HDR InfiniBand connection providing an accelerated 200 Gbps high speed fabric. The DeepSpeed library exploits offload capabilities where the activation and optimizer states are allocated in the main memory. Hence, 1.7TB memory capacity per node helps us to scale to large model sizes. ### Training setup using AzureML -Users can directly use the AzureML studio and use our published [recipes](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azureml) to run experiments without any additional setup. This is the easiest and recommended way of running experiments on Azure. +Users can directly use the AzureML studio and use our published [recipes](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azureml) to run experiments without any additional setup. This is the easiest and recommended way of running experiments on Azure. ### Training setup using Azure VMSS @@ -59,7 +59,7 @@ A cluster is created using Azure Virtual Machine Scale Sets (VMSS) to provision | ------------------------------: | :----------------: | | PyTorch | 1.10.2 (installed from source) | | DeepSpeed | 0.6.2 (installed from source) | -| Megatron-LM | [https://github.com/microsoft/Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed) | +| Megatron-LM | [https://github.com/deepspeedai/Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed) | | Apex | 0.1 | | NCCL | 2.12.10 | | CUDNN | 8.2.4.15 | @@ -122,9 +122,9 @@ The 2T parameter model consists of 160 layers, 32k hidden dimension, and 128 att We recognize that DeepSpeed users are diverse and have different environments. In this tutorial, our focus is on making things simpler for users who plan to run large model training experiments on Azure. -> The easiest way to do model training on Azure is via the Azure ML recipes. The job submission and data preparation scripts have been made available [here](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azureml). Users simply need to setup their Azure ML workspace following the [guide](https://github.com/Azure/azureml-examples/tree/main/python-sdk#set-up) and submit experiment using the aml_submit.py file. +> The easiest way to do model training on Azure is via the Azure ML recipes. The job submission and data preparation scripts have been made available [here](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azureml). Users simply need to setup their Azure ML workspace following the [guide](https://github.com/Azure/azureml-examples/tree/main/python-sdk#set-up) and submit experiment using the aml_submit.py file. -Some users have customized environments built on top of Azure VMs and VMSS based clusters. To simplify training on such setups, we are working on an easy-to-use cluster setup script that will be published in the next few weeks. If you already have a cluster setup running, you can use the [azure recipes](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azure) for the 175B and the 1T model. The recipes can easily be modified to train other model configurations. +Some users have customized environments built on top of Azure VMs and VMSS based clusters. To simplify training on such setups, we are working on an easy-to-use cluster setup script that will be published in the next few weeks. If you already have a cluster setup running, you can use the [azure recipes](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azure) for the 175B and the 1T model. The recipes can easily be modified to train other model configurations. ## Acknowledgement diff --git a/docs/_posts/2022-09-10-zero-inference.md b/docs/_posts/2022-09-10-zero-inference.md index dd718b9f8839..3c588e39c1dc 100644 --- a/docs/_posts/2022-09-10-zero-inference.md +++ b/docs/_posts/2022-09-10-zero-inference.md @@ -2,7 +2,7 @@ title: "ZeRO-Inference: Democratizing massive model inference" excerpt: "" date: 2022-09-10 00:09:00 -tags: inference ZeRO +tags: inference ZeRO English --- ## Introduction @@ -83,7 +83,7 @@ Next, we measure the impact on generation throughput using four V100-32GB GPUs. We briefly discuss how users can determine when ZeRO-Inference is suitable for their application and how to enable ZeRO-Inference in DeepSpeed. ### When to use ZeRO-Inference -ZeRO-Inference is designed for inference applications that require GPU acceleration but lack sufficient GPU memory to host the model. Also, ZeRO-Inference is optimized for inference applications that are **throughput-oriented** and allow **large batch sizes**. Alternative techniques, such as [Accelerate](https://github.com/huggingface/accelerate), [DeepSpeed-Inference](https://www.deepspeed.ai/inference/), and [DeepSpeed-MII](https://github.com/microsoft/deepspeed-mii) that fit the entire model into GPU memory, possibly using multiple GPUs, are more suitable for inference applications that are latency sensitive or have small batch sizes. +ZeRO-Inference is designed for inference applications that require GPU acceleration but lack sufficient GPU memory to host the model. Also, ZeRO-Inference is optimized for inference applications that are **throughput-oriented** and allow **large batch sizes**. Alternative techniques, such as [Accelerate](https://github.com/huggingface/accelerate), [DeepSpeed-Inference](https://www.deepspeed.ai/inference/), and [DeepSpeed-MII](https://github.com/deepspeedai/deepspeed-mii) that fit the entire model into GPU memory, possibly using multiple GPUs, are more suitable for inference applications that are latency sensitive or have small batch sizes. ### How to use ZeRO-Inference ZeRO-Inference is available in the DeepSpeed library versions >= 0.6.6. Integrating ZeRO-Inference into token generation pipelines, such as [Hugging Face generate](https://huggingface.co/docs/transformers/main_classes/text_generation), requires updating the DeepSpeed configuration to set [ZeRO optimization](https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training) to stage 3 and [parameter offloading](https://www.deepspeed.ai/docs/config-json/#parameter-offloading) to CPU or NVMe. diff --git a/docs/_posts/2022-10-11-mii.md b/docs/_posts/2022-10-11-mii.md index 8a3973175965..324b7ffbad33 100644 --- a/docs/_posts/2022-10-11-mii.md +++ b/docs/_posts/2022-10-11-mii.md @@ -2,7 +2,7 @@ title: "DeepSpeed-MII: instant speedup on 24,000+ open-source DL models with up to 40x cheaper inference" excerpt: "" date: 2022-10-11 00:09:00 -tags: inference +tags: inference English --- [ ![Text Generation Models](/assets/images/mii/hero.png) ](/assets/images/mii/hero.png){: .align-center} @@ -11,7 +11,7 @@ The Deep Learning (DL) open-source community has seen tremendous growth in the l There has been significant progress in system optimizations for DL model inference that can drastically reduce both latency and cost, but those are not easily accessible. The main reason for this limited accessibility is that the DL model inference landscape is diverse with models varying in size, architecture, system performance characteristics, hardware requirements, etc. Identifying the appropriate set of system optimizations applicable to a given model and applying them correctly is often beyond the scope of most data scientists, making low latency and low-cost inference mostly inaccessible. -[DeepSpeed Model Implementations for Inference (MII)](https://github.com/microsoft/DeepSpeed-MII) is a new open-source python library from DeepSpeed, aimed towards making low-latency, low-cost inference of powerful models not only feasible but also easily accessible. +[DeepSpeed Model Implementations for Inference (MII)](https://github.com/deepspeedai/DeepSpeed-MII) is a new open-source python library from DeepSpeed, aimed towards making low-latency, low-cost inference of powerful models not only feasible but also easily accessible. * MII offers access to highly optimized implementations of **thousands of widely used DL models.** * MII supported models achieve significantly lower latency and cost compared to their original implementation. @@ -33,7 +33,7 @@ Under-the-hood MII is powered by [DeepSpeed-Inference](https://arxiv.org/abs/220 MII supports a growing list of tasks such as text generation, question-answering, text classification, etc, across thousands of transformer models available through multiple open-sourced model repositories such as Hugging Face, FairSeq, EluetherAI, etc. It supports dense models based on BERT, RoBERTa, GPT, OPT, and BLOOM architectures ranging from a few hundred million parameters in size to hundreds of billions of parameters in size. At the same time, it supports recent image generation models such as Stable Diffusion. -See the MII GitHub repo for an up-to-date list of [models and tasks supported by MII](https://github.com/microsoft/deepspeed-mii#supported-models-and-tasks). +See the MII GitHub repo for an up-to-date list of [models and tasks supported by MII](https://github.com/deepspeedai/deepspeed-mii#supported-models-and-tasks). # Inference Optimizations with MII @@ -133,7 +133,7 @@ mii.deploy(task="text-to-image", deployment_type=DeploymentType.AML) ``` -To learn more about these deployment options and get started with MII, please the [MII getting started guide](https://github.com/microsoft/deepspeed-mii#getting-started-with-mii). +To learn more about these deployment options and get started with MII, please the [MII getting started guide](https://github.com/deepspeedai/deepspeed-mii#getting-started-with-mii). # Concluding Remarks diff --git a/docs/_posts/2022-12-12-data-efficiency.md b/docs/_posts/2022-12-12-data-efficiency.md index 3b6adb4d7dab..82931a30e167 100644 --- a/docs/_posts/2022-12-12-data-efficiency.md +++ b/docs/_posts/2022-12-12-data-efficiency.md @@ -2,7 +2,7 @@ title: "DeepSpeed Data Efficiency: A composable library that makes better use of data, increases training efficiency, and improves model quality" excerpt: "" date: 2022-12-12 00:09:00 -tags: training +tags: training English --- [ ![DeepSpeed Data Efficiency](/assets/images/data_efficiency/data_efficiecy_fig0.png) ](/assets/images/data_efficiency/data_efficiecy_fig0.png){: .align-center} @@ -141,4 +141,4 @@ The composed DeepSpeed Data Efficiency solution leverages both data efficiency t # Concluding Remarks -We are very excited to share DeepSpeed Data Efficiency library with the community and improve it with your feedback. Please find the code, tutorial, and documents at the [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed), and [website](/tutorials/data-efficiency/). And for more technical details please read our [Random-LTD paper](https://arxiv.org/abs/2211.11586) and [DeepSpeed Data Efficiency paper](https://arxiv.org/abs/2212.03597). We believe that our composable library and novel data efficiency techniques will help users reduce training cost while maintaining model quality or achieve better quality under similar cost. And we hope DeepSpeed Data Efficiency could become a platform that motivates and accelerates future research on deep learning data efficiency. +We are very excited to share DeepSpeed Data Efficiency library with the community and improve it with your feedback. Please find the code, tutorial, and documents at the [DeepSpeed GitHub](https://github.com/deepspeedai/DeepSpeed), and [website](/tutorials/data-efficiency/). And for more technical details please read our [Random-LTD paper](https://arxiv.org/abs/2211.11586) and [DeepSpeed Data Efficiency paper](https://arxiv.org/abs/2212.03597). We believe that our composable library and novel data efficiency techniques will help users reduce training cost while maintaining model quality or achieve better quality under similar cost. And we hope DeepSpeed Data Efficiency could become a platform that motivates and accelerates future research on deep learning data efficiency. diff --git a/docs/_posts/2023-03-31-multi-modal.md b/docs/_posts/2023-03-31-multi-modal.md index 045c92719521..63ea2f94f850 100644 --- a/docs/_posts/2023-03-31-multi-modal.md +++ b/docs/_posts/2023-03-31-multi-modal.md @@ -2,7 +2,7 @@ title: "Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE " excerpt: "" date: 2023-03-31 00:09:00 -tags: training +tags: training English --- The field of Artificial Intelligence-Generated Content (AIGC) is rapidly growing, with the goal of making content creation more efficient and accessible. One of the most exciting areas of AIGC is the development of large-scale multi-modal models like [Flamingo](https://arxiv.org/abs/2204.14198), [BLIP](https://arxiv.org/abs/2301.12597), and [GPT4](https://arxiv.org/abs/2303.08774), which can accept inputs from multiple resources, e.g., image, text, audio, etc., and generate a variety of formats as outputs. For example, image creation can be made through stable diffusion and DALLE using the prompt text, and the new feature in the coming Office can create slides with texts, images, animations, etc., by leveraging the power of the new Microsoft Office Copilot. @@ -34,4 +34,4 @@ Specifically, we incorporate the MoE structure into the classical single-tower m A sophisticated MoE model design requires a highly efficient and scalable training system that can support multi-dimensional parallelism and efficient memory management. [DeepSpeed MoE](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/) training system offers such advanced capabilities including easy-to-use APIs enabling flexible combinations of data, tensor, and expert parallelism. Furthermore, DeepSpeed MoE enables larger model scale than state-of-the-art systems by exploiting expert parallelism and [ZeRO optimizations](https://arxiv.org/abs/1910.02054) together. By leveraging the DeepSpeed MoE system, VL-MoE Base with 32 experts achieves similar model quality as VLMO-dense Large with about 2.5x training speedup. -[DeepSpeed MoE](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/) system is already open-sourced and can be easily used as plug-and-play component to achieve high-performance low-cost training for any large-scale MoE models. The tutorial of how to use DeepSpeed MoE is available [here](https://www.deepspeed.ai/tutorials/mixture-of-experts/). VL-MoE is currently in the process of being integrated as a model example of [DeepSpeed Examples](https://github.com/microsoft/DeepSpeedExamples). Please stay tuned for our upcoming updates on this thread. +[DeepSpeed MoE](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/) system is already open-sourced and can be easily used as plug-and-play component to achieve high-performance low-cost training for any large-scale MoE models. The tutorial of how to use DeepSpeed MoE is available [here](https://www.deepspeed.ai/tutorials/mixture-of-experts/). VL-MoE is currently in the process of being integrated as a model example of [DeepSpeed Examples](https://github.com/deepspeedai/DeepSpeedExamples). Please stay tuned for our upcoming updates on this thread. diff --git a/docs/_posts/2023-04-24-deepspeed-chat-chinese.md b/docs/_posts/2023-04-24-deepspeed-chat-chinese.md new file mode 100644 index 000000000000..57a77caab32d --- /dev/null +++ b/docs/_posts/2023-04-24-deepspeed-chat-chinese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-chat/chinese/README.md +date: 2023-04-24 00:00:00 +tags: training ZeRO RLHF Chinese +--- diff --git a/docs/_posts/2023-04-24-deepspeed-chat-japanese.md b/docs/_posts/2023-04-24-deepspeed-chat-japanese.md new file mode 100644 index 000000000000..ee3c8dca00fa --- /dev/null +++ b/docs/_posts/2023-04-24-deepspeed-chat-japanese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed Chat: ChatGPTライクなモデルを簡単・高速・低コストに、あらゆるスケールで学習" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-chat/japanese/README.md +date: 2023-04-24 00:00:00 +tags: training ZeRO RLHF Japanese +--- diff --git a/docs/_posts/2023-04-24-deepspeed-chat.md b/docs/_posts/2023-04-24-deepspeed-chat.md new file mode 100644 index 000000000000..f6cad798ca99 --- /dev/null +++ b/docs/_posts/2023-04-24-deepspeed-chat.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-chat/README.md +date: 2023-04-24 00:00:00 +tags: training ZeRO RLHF English +--- diff --git a/docs/_posts/2023-06-07-deepspeed-overview-japanese.md b/docs/_posts/2023-06-07-deepspeed-overview-japanese.md new file mode 100644 index 000000000000..8f42093991c8 --- /dev/null +++ b/docs/_posts/2023-06-07-deepspeed-overview-japanese.md @@ -0,0 +1,8 @@ +--- +title: "DeepSpeed主要技術の概要紹介" +excerpt: "" +date: 2023-06-07 00:00:00 +tags: inference training ZeRO RLHF Japanese presentations +--- + +我々が研究開発しているDeepSpeedについて、主要技術を日本語で説明した資料を公開しました。GPT3やChatGPTのような生成型AIのための大規模言語モデルを含む、様々な深層学習の訓練や推論に容易に適用でき、モデルの大規模化、高速化、コスト削減を可能にします。[こちら](/assets/files/DeepSpeed_Overview_Japanese_2023Jun7th.pdf)よりダウンロードしてください。 diff --git a/docs/_posts/2023-06-22-zeropp-chinese.md b/docs/_posts/2023-06-22-zeropp-chinese.md new file mode 100644 index 000000000000..71dc2d51cb70 --- /dev/null +++ b/docs/_posts/2023-06-22-zeropp-chinese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed ZeRO++:降低4倍网络通信,显著提高大模型及类ChatGPT模型训练效率" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md +date: 2023-06-22 00:00:00 +tags: training ZeRO RLHF Chinese +--- diff --git a/docs/_posts/2023-06-22-zeropp-japanese.md b/docs/_posts/2023-06-22-zeropp-japanese.md new file mode 100644 index 000000000000..e81013d11aba --- /dev/null +++ b/docs/_posts/2023-06-22-zeropp-japanese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed ZeRO++: LLMやチャットモデルの訓練を劇的に高速化 – 通信オーバヘッドを1/4に大幅削減 -" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md +date: 2023-06-22 00:00:00 +tags: training ZeRO RLHF Japanese +--- diff --git a/docs/_posts/2023-06-22-zeropp.md b/docs/_posts/2023-06-22-zeropp.md new file mode 100644 index 000000000000..d301942a00cd --- /dev/null +++ b/docs/_posts/2023-06-22-zeropp.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed ZeRO++: A leap in speed for LLM and chat model training with 4X less communication" +excerpt: "" +link: https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/ +date: 2023-06-22 00:00:00 +tags: training ZeRO RLHF English +--- diff --git a/docs/_posts/2023-08-24-ulysses-chinese.md b/docs/_posts/2023-08-24-ulysses-chinese.md new file mode 100644 index 000000000000..f8d269217b7a --- /dev/null +++ b/docs/_posts/2023-08-24-ulysses-chinese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed Ulysses: 训练极长序列Transformer模型的系统优化" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md +date: 2023-08-24 00:00:00 +tags: training ZeRO Chinese +--- diff --git a/docs/_posts/2023-08-24-ulysses-japanese.md b/docs/_posts/2023-08-24-ulysses-japanese.md new file mode 100644 index 000000000000..291407a5523e --- /dev/null +++ b/docs/_posts/2023-08-24-ulysses-japanese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed Ulysses: Transformerモデルを非常に長いシーケンスで訓練するための最適化" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md +date: 2023-08-24 00:00:00 +tags: training ZeRO Japanese +--- diff --git a/docs/_posts/2023-08-24-ulysses.md b/docs/_posts/2023-08-24-ulysses.md new file mode 100644 index 000000000000..c10b2d599f02 --- /dev/null +++ b/docs/_posts/2023-08-24-ulysses.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md +date: 2023-08-24 00:00:00 +tags: training ZeRO English +--- diff --git a/docs/_posts/2023-09-12-ZeRO-Inference.md b/docs/_posts/2023-09-12-ZeRO-Inference.md new file mode 100644 index 000000000000..04a6347bec59 --- /dev/null +++ b/docs/_posts/2023-09-12-ZeRO-Inference.md @@ -0,0 +1,6 @@ +title: "ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md +date: 2023-09-12 00:09:00 +tags: inference ZeRO quantization English +--- diff --git a/docs/_posts/2023-09-19-deepspeed4science-chinese.md b/docs/_posts/2023-09-19-deepspeed4science-chinese.md new file mode 100644 index 000000000000..651d61a3b79c --- /dev/null +++ b/docs/_posts/2023-09-19-deepspeed4science-chinese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed4Science:利用先进的AI系统优化技术实现科学发现" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md +date: 2023-09-19 00:00:00 +tags: training inference science Chinese +--- diff --git a/docs/_posts/2023-09-19-deepspeed4science-japanese.md b/docs/_posts/2023-09-19-deepspeed4science-japanese.md new file mode 100644 index 000000000000..20d83c8e0b5a --- /dev/null +++ b/docs/_posts/2023-09-19-deepspeed4science-japanese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed4Scienceイニシアティブ: 洗練されたAIシステムのテクノロジーにより大規模な科学的発見を可能に" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md +date: 2023-09-19 00:00:00 +tags: training inference science Japanese +--- diff --git a/docs/_posts/2023-09-19-deepspeed4science.md b/docs/_posts/2023-09-19-deepspeed4science.md new file mode 100644 index 000000000000..faeaa1331944 --- /dev/null +++ b/docs/_posts/2023-09-19-deepspeed4science.md @@ -0,0 +1,7 @@ +--- +title: "Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies" +excerpt: "" +link: https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/ +date: 2023-09-19 00:00:00 +tags: training inference science English +--- diff --git a/docs/_posts/2023-10-04-deepspeed-visualchat-chinese.md b/docs/_posts/2023-10-04-deepspeed-visualchat-chinese.md new file mode 100644 index 000000000000..1e0ef0bed34b --- /dev/null +++ b/docs/_posts/2023-10-04-deepspeed-visualchat-chinese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed-VisualChat:多轮图像+文字,为你展现不一样的AI聊天魅力" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md +date: 2023-10-04 00:00:00 +tags: training Chinese +--- diff --git a/docs/_posts/2023-10-04-deepspeed-visualchat-japanese.md b/docs/_posts/2023-10-04-deepspeed-visualchat-japanese.md new file mode 100644 index 000000000000..745e9052358e --- /dev/null +++ b/docs/_posts/2023-10-04-deepspeed-visualchat-japanese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed-VisualChat: 複数ラウンド・複数画像の入力が可能なAIチャット体験を実現" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md +date: 2023-10-04 00:00:00 +tags: training Japanese +--- diff --git a/docs/_posts/2023-10-04-deepspeed-visualchat.md b/docs/_posts/2023-10-04-deepspeed-visualchat.md new file mode 100644 index 000000000000..8226597290b2 --- /dev/null +++ b/docs/_posts/2023-10-04-deepspeed-visualchat.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md +date: 2023-10-04 00:00:00 +tags: training English +--- diff --git a/docs/_posts/2023-11-06-deepspeed-fastgen-chinese.md b/docs/_posts/2023-11-06-deepspeed-fastgen-chinese.md new file mode 100644 index 000000000000..ec936bb6d79e --- /dev/null +++ b/docs/_posts/2023-11-06-deepspeed-fastgen-chinese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed-FastGen:通过 MII 和 DeepSpeed-Inference 实现 LLM 高吞吐量文本生成" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md +date: 2023-11-06 00:00:00 +tags: inference Chinese +--- diff --git a/docs/_posts/2023-11-06-deepspeed-fastgen-japanese.md b/docs/_posts/2023-11-06-deepspeed-fastgen-japanese.md new file mode 100644 index 000000000000..a64b29c88163 --- /dev/null +++ b/docs/_posts/2023-11-06-deepspeed-fastgen-japanese.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed-FastGen: MIIとDeepSpeed-InferenceによるLLMのための高速なテキスト生成" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md +date: 2023-11-06 00:00:00 +tags: inference Japanese +--- diff --git a/docs/_posts/2023-11-06-deepspeed-fastgen.md b/docs/_posts/2023-11-06-deepspeed-fastgen.md new file mode 100644 index 000000000000..d9062ce56da3 --- /dev/null +++ b/docs/_posts/2023-11-06-deepspeed-fastgen.md @@ -0,0 +1,7 @@ +--- +title: "DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference" +excerpt: "" +link: https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen +date: 2023-11-06 00:00:00 +tags: inference English +--- diff --git a/docs/_sass/minimal-mistakes/_sidebar.scss b/docs/_sass/minimal-mistakes/_sidebar.scss index 63cef338c583..5f657e112d82 100644 --- a/docs/_sass/minimal-mistakes/_sidebar.scss +++ b/docs/_sass/minimal-mistakes/_sidebar.scss @@ -76,7 +76,7 @@ @include breakpoint($large) { position: absolute; - top: 0; + top: auto; right: 0; width: $right-sidebar-width-narrow; margin-right: -1.5 * $right-sidebar-width-narrow; @@ -94,7 +94,7 @@ @include breakpoint($x-large) { width: $right-sidebar-width; - margin-right: -1.5 * $right-sidebar-width; + margin-right: -1.5 * $right-sidebar-width-narrow; } } diff --git a/docs/_tutorials/MoQ-tutorial.md b/docs/_tutorials/MoQ-tutorial.md index 587ef4b1ed67..ffee37b158f6 100644 --- a/docs/_tutorials/MoQ-tutorial.md +++ b/docs/_tutorials/MoQ-tutorial.md @@ -3,7 +3,7 @@ title: "DeepSpeed Mixture-of-Quantization (MoQ)" tags: training quantization --- -DeepSpeed introduces new support for model compression using quantization, called Mixture-of-Quantization (MoQ). MoQ is designed on top of QAT (Quantization-Aware Training), with the difference that it schedules various data precisions across the training process. It starts with quantizing the model with a high precision, such as FP16 or 16-bit quantization, and reduce the precision through a pre-defined schedule until reaching the target quantization bits (like 8-bit). Moreover, we use second-order information of the model parameters to dynamically adjust the quantization schedule for each of layer of the network separately. We have seen that by adding such schedule and using various data precision in the training process, we can quantize the model with better quality and preserve accuracy. For a better understanding of MoQ methodology, please refer to MoQ deep-dive, [here](https://www.deepspeed.ai/2021/05/04/MoQ.html). +DeepSpeed introduces new support for model compression using quantization, called Mixture-of-Quantization (MoQ). MoQ is designed on top of QAT (Quantization-Aware Training), with the difference that it schedules various data precisions across the training process. It starts with quantizing the model with a high precision, such as FP16 or 16-bit quantization, and reduce the precision through a pre-defined schedule until reaching the target quantization bits (like 8-bit). Moreover, we use second-order information of the model parameters to dynamically adjust the quantization schedule for each layer of the network separately. We have seen that by adding such schedule and using various data precision in the training process, we can quantize the model with better quality and preserve accuracy. For a better understanding of MoQ methodology, please refer to MoQ deep-dive, [here](https://www.deepspeed.ai/2021/05/04/MoQ.html). Below, we use fine-tune for the GLUE tasks as an illustration of how to use MoQ. @@ -71,7 +71,7 @@ Before fine-tuning the GLUE tasks using DeepSpeed MoQ, you need: ### DeepSpeed Configuration File -Prepare a config file `test.json` as below, please note following important parameters for quantization training: +Prepare a config file `test.json` as below, please note the following important parameters for quantization training: ``` { @@ -134,7 +134,7 @@ python text-classification/run_glue.py \ --deepspeed test.json ``` -Running this script will get `MPRC` accuracy and F1 metric results with MoQ quantization. +Running this script will get `MRPC` accuracy and F1 metric results with MoQ quantization. ### Quantization with dynamic schedule using second-order information (Eigenvalue) diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md new file mode 100644 index 000000000000..30a362b82d25 --- /dev/null +++ b/docs/_tutorials/accelerator-abstraction-interface.md @@ -0,0 +1,97 @@ +--- +title: DeepSpeed Accelerator Abstraction Interface +tags: getting-started training accelerator +--- + +# Contents +- [Contents](#contents) +- [Introduction](#introduction) +- [Write accelerator agnostic models](#write-accelerator-agnostic-models) + - [Port accelerator runtime calls](#port-accelerator-runtime-calls) + - [Port accelerator device name](#port-accelerator-device-name) + - [Tensor operations](#tensor-operations) + - [Communication backend](#communication-backend) +- [Run DeepSpeed model on different accelerators](#run-deepspeed-model-on-different-accelerators) +- [Implement new accelerator extension](#implement-new-accelerator-extension) + +# Introduction +The DeepSpeed Accelerator Abstraction allows user to run large language model seamlessly on various Deep Learning acceleration hardware with DeepSpeed. It offers a set of accelerator runtime and accelerator op builder interface which can be implemented for different hardware. This means user can write large language model code without hardware specific code. With DeepSpeed Accelerator Abstraction, the same large language model can run on different hardware platform, without the need to rewrite model code. This makes running large language model on different hardware easier. + +This document covers three topics related to DeepSpeed Accelerator Abstraction Interface: +1. Write accelerator agnostic models using DeepSpeed Accelerator Abstraction Interface. +2. Run DeepSpeed model on different accelerators. +3. Implement new accelerator extension for DeepSpeed Accelerator Abstraction Interface. + +# Write accelerator agnostic models +In this part, you will learn how to write a model that does not contain HW specific code, or how to port a model that run on a specific HW only to be accelerator agnostic. To do this, we first import `get_accelerator` from `deepspeed.accelerator` +``` +from deepspeed.accelerator import get_accelerator +``` +Note: `get_accelerator()` is the entrance to DeepSpeed Accelerator Abstraction Interface +## Port accelerator runtime calls +First we need to port accelerator runtime calls. On CUDA device, accelerator runtime call appears in the form of `torch.cuda.(...)`. With DeepSpeed Accelerator Abstract Interface, such accelerator runtime call can be written in the form of `get_accelerator().(...)` which will be accelerator agnostic. + +A typical conversion looks like the following example: + +``` +if torch.cuda.is_available(): + ... +``` +--> +``` +if get_accelerator().is_available(): + ... +``` + +For most `torch.cuda.(...)` call, we can literally replace `torch.cuda` with `get_accelerator()`. However, there are some exceptions that needs attention: +1. For `torch.cuda.current_device()`, we need to know whether calling this interface is to get device index, or supply the return value as a device. If we want to use the return value as a device string, we need to call `get_accelerator().current_device_name()`. For example: +``` +torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name()) +``` +However, if we wish to get device index as a number, we should call `get_accelerator().current_device()` +``` +local_rank = get_accelerator().current_device() +``` +2. For `torch.cuda.default_generators[index]`, convert to `get_accelerator().default_generator(index)` + +## Port accelerator device name +For CUDA specific device name such as `'cuda'` or `'cuda:0'`, or `'cuda:1'`, we convert them to `get_accelerator().device_name()`, `get_accelerator().device_name(0)`, and `get_accelerator().device_name(1)`. + +A device name without index can be used if model need to do specific thing for certain accelerator. We suggest to make as less as such usage only for situations can not be resolve other way. + +## Tensor operations +CUDA specific tensor operations needs to be converted according to the following rules: +- When we convert a torch tensor to accelerator device such as `my_tensor.cuda()`, we use `my_tensor.to(get_accelerator().device_name())` + +- When we check whether a torch tensor is on accelerator device such as `my_tensor.is_cuda`, we use `get_accelerator().on_accelerator(my_tensor)` + +- When pin a tensor to GPU memory such as `my_tensor.pin_memory()`, we use `get_accelerator().pin_memory(my_tensor)` + +## Communication backend +When a communication backend string is used, the interface `get_accelerator().communication_backend_name()` is used get get communication backend name. So instead of: +``` +torch.distributed.init_process_group('nccl') +``` +, we use: +``` +torch.distributed.init_process_group(get_accelerator().communication_backend_name()) +``` + +# Run DeepSpeed model on different accelerators +[Accelerator Setup Guide](accelerator-setup-guide.md) provides a guide on how to setup different accelerators for DeepSpeed. It also comes with simple example how to run deepspeed for different accelerators. The following guides are provided: +1. Run DeepSpeed model on CPU +2. Run DeepSpeed model on XPU +3. Run DeepSpeed model on Huawei Ascend NPU + +# Implement new accelerator extension +It is possible to implement a new DeepSpeed accelerator extension to support new accelerator in DeepSpeed. An example to follow is _[Intel Extension For DeepSpeed](https://github.com/intel/intel-extension-for-deepspeed/)_. An accelerator extension contains the following components: +1. XYZ_Accelerator(DeepSpeedAccelerator) class definition, where 'XYZ' is the accelerator name, such as 'XPU' or 'CPU'. +This class implements `class DeepSpeedAccelerator` and will be returned by `get_accelerator()` in DeepSpeed. +2. Op builders following https://github.com/intel/intel-extension-for-deepspeed/tree/main/intel_extension_for_deepspeed/op_builder. All op builders needs to inherit `deepspeed.ops.op_builder.builder.OpBuilder` directly or indirectly. A common practice is to implement a base op builder (SYCLOpBuilder in the case of Intel Extension for DeepSpeed) and inherit this base op builder instead. +3. Op kernels as in the following [link](https://github.com/intel/intel-extension-for-deepspeed/tree/main/intel_extension_for_deepspeed/op_builder/csrc). + +Note that an extension does not have to implement all op builders under https://github.com/deepspeedai/DeepSpeed/tree/master/op_builder all at a time. A missing op builder usually means certain DeepSpeed functionality cannot be used for that Accelerator, but models that does not use that functionality can still run. + +When implementing op builder for an accelerator extension, one thing needs to be noted is that the op builder native code is being built by DeepSpeed jit load mechanism. This mean the native source file being built needs to be in DeepSpeed installation directory. However these files are defined in accelerator extension installation directory, which cannot be built by DeepSpeed directly. To solve this, follow the example in https://github.com/intel/intel-extension-for-deepspeed/blob/main/intel_extension_for_deepspeed/op_builder/cpu_adam.py to use 'sycl_kernel_path' and 'sycl_kernel_include' (User can change 'sycl' to other prefix in their own accelerator extension) to allow native code be built during DeepSpeed jit load. + +When accelerator extension is installed in the environment, it can be used by either explicit call deepspeed.accelerator.set_accelerator(XYZ_Accelerator()) following the example in https://github.com/deepspeedai/DeepSpeed/blob/master/accelerator/real_accelerator.py, or add an implicit detection code in get_accelerator in the same file above. diff --git a/docs/_tutorials/accelerator-setup-guide.md b/docs/_tutorials/accelerator-setup-guide.md new file mode 100644 index 000000000000..20e667170eaa --- /dev/null +++ b/docs/_tutorials/accelerator-setup-guide.md @@ -0,0 +1,278 @@ +--- +title: DeepSpeed Accelerator Setup Guides +tags: getting-started training accelerator +--- + +# Contents +- [Contents](#contents) +- [Introduction](#introduction) +- [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu) +- [Intel XPU](#intel-xpu) +- [Huawei Ascend NPU](#huawei-ascend-npu) +- [Intel Gaudi](#intel-gaudi) + +# Introduction +DeepSpeed supports different accelerators from different companies. Setup steps to run DeepSpeed on certain accelerators might be different. This guide allows user to lookup setup instructions for the accelerator family and hardware they are using. + +# Intel Architecture (IA) CPU +DeepSpeed supports CPU with Intel Architecture instruction set. It is recommended to have the CPU support at least AVX2 instruction set and recommend AMX instruction set. + +DeepSpeed has been verified on the following CPU processors: +* 4th Gen Intel® Xeon® Scalarable Processors +* 5th Gen Intel® Xeon® Scalarable Processors +* 6th Gen Intel® Xeon® Scalarable Processors + +## Installation steps for Intel Architecture CPU +To install DeepSpeed on Intel Architecture CPU, use the following steps: +1. Install gcc compiler +DeepSpeed requires gcc-9 or above to build kernels on Intel Architecture CPU, install gcc-9 or above. + +2. Install numactl +DeepSpeed use `numactl` for fine grain CPU core allocation for load-balancing, install numactl on your system. +For example, on Ubuntu system, use the following command: +`sudo apt-get install numactl` + +3. Install PyTorch +`pip install torch` + +4. Install DeepSpeed +`pip install deepspeed` + +## How to launch DeepSpeed on Intel Architecture CPU +DeepSpeed can launch on Intel Architecture CPU with default deepspeed command. However, for compute intensive workloads, Intel Architecture CPU works best when each worker process runs on different set of physical CPU cores, so worker process does not compete CPU cores with each other. To bind cores to each worker (rank), use the following command line switch for better performance. +``` +deepspeed --bind_cores_to_rank +``` +This switch would automatically detect the number of CPU NUMA node on the host, launch the same number of workers, and bind each worker to cores/memory of a different NUMA node. This improves performance by ensuring workers do not interfere with each other, and that all memory allocation is from local memory. + +If a user wishes to have more control on the number of workers and specific cores that can be used by the workload, user can use the following command line switches. +``` +deepspeed --num_accelerators --bind_cores_to_rank --bind_core_list +``` +For example: +``` +deepspeed --num_accelerators 4 --bind_cores_to_rank --bind_core_list <0-27,32-59> inference.py +``` +This would start 4 workers for the workload. The core list range will be divided evenly between 4 workers, with worker 0 take 0-13, worker 1, take 14-27, worker 2 take 32-45, and worker 3 take 46-59. Core 28-31,60-63 are left out because there might be some background process running on the system, leaving some idle cores will reduce performance jitting and straggler effect. + +Launching DeepSpeed model on multiple CPU nodes is similar to other accelerators. We need to specify `impi` as launcher and specify `--bind_cores_to_rank` for better core binding. Also specify `slots` number according to number of CPU sockets in host file. + +``` +# hostfile content should follow the format +# worker-1-hostname slots=<#sockets> +# worker-2-hostname slots=<#sockets> +# ... + +deepspeed --hostfile= --bind_cores_to_rank --launcher impi --master_addr +``` + +## Install with Intel Extension for PyTorch and oneCCL +Although not mandatory, Intel Extension for PyTorch and Intel oneCCL provide better optimizations for LLM models. Intel oneCCL also provide optimization when running LLM model on multi-node. To use DeepSpeed with Intel Extension for PyTorch and oneCCL, use the following steps: +1. Install Intel Extension for PyTorch. This is suggested if you want to get better LLM inference performance on CPU. +`pip install intel-extension-for-pytorch` + +The following steps are to install oneCCL binding for PyTorch. This is suggested if you are running DeepSpeed on multiple CPU node, for better communication performance. On single node with multiple CPU socket, these steps are not needed. + +2. Install Intel oneCCL binding for PyTorch +`python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu` + +3. Install Intel oneCCL, this will be used to build direct oneCCL kernels (CCLBackend kernels) +``` +pip install oneccl-devel +pip install impi-devel +``` +Then set the environment variables for Intel oneCCL (assuming using conda environment). +``` +export CPATH=${CONDA_PREFIX}/include:$CPATH +export CCL_ROOT=${CONDA_PREFIX} +export I_MPI_ROOT=${CONDA_PREFIX} +export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/ccl/cpu:${CONDA_PREFIX}/lib/libfabric:${CONDA_PREFIX}/lib +``` + +## Optimize LLM inference with Intel Extension for PyTorch +Intel Extension for PyTorch compatible with DeepSpeed AutoTP tensor parallel inference. It allows CPU inference to benefit from both DeepSpeed Automatic Tensor Parallelism, and LLM optimizations of Intel Extension for PyTorch. To use Intel Extension for PyTorch, after calling deepspeed.init_inference, call +``` +ipex_model = ipex.llm.optimize(deepspeed_model) +``` +to get model optimized by Intel Extension for PyTorch. + +## More examples for using DeepSpeed on Intel CPU +Refer to [LLM examples](https://github.com/intel/intel-extension-for-pytorch/tree/main/examples/cpu/llm) for more code samples of running inference with DeepSpeed on Intel CPU. + + +# Intel XPU +DeepSpeed XPU accelerator supports Intel® discrete GPUs with XPU backend through PyTorch. + +DeepSpeed has been verified on the following GPU products: +* Intel® Data Center GPU Max 1100 +* Intel® Data Center GPU Max 1550 +* Intel® Arc Pro B60 + +## Installation steps for Intel XPU +To install DeepSpeed on Intel XPU, use the following steps: + +1. Install PyTorch with XPU support \ +Install the XPU variant of PyTorch from the official PyTorch repository: +``` +pip install torch --index-url https://download.pytorch.org/whl/xpu +``` + +2. Install the Intel® oneAPI DPC++/C++ Compiler (`icpx`) \ +The `icpx` compiler is required at runtime to JIT-compile DeepSpeed's SYCL kernels (e.g. FusedAdam). + +**Important: The `icpx` version must match the SYCL runtime version bundled with +your PyTorch XPU wheel.** A mismatch between the compiler and runtime versions can +cause symbol resolution errors (e.g. unresolved `__devicelib_*` symbols) or subtle +ABI incompatibilities. + +To find out which SYCL runtime version your PyTorch was built with: +``` +pip show intel-sycl-rt +``` +Then install the **same version** of the Intel® oneAPI DPC++/C++ Compiler. For +example, if `intel-sycl-rt` shows version `2025.3.1`, install oneAPI compiler +version `2025.3`. For download and details, see the +[Intel oneAPI DPC++/C++ Compiler](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compiler.html) +page, or install via the +[Intel oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html). + +3. Install DeepSpeed \ +`pip install deepspeed` + +## How to use DeepSpeed on Intel XPU +DeepSpeed can be launched on Intel XPU with the `deepspeed` launch command. Before +launching, activate the oneAPI environment so that `icpx` is on `PATH`: +``` +source /setvars.sh +``` + + +To validate the XPU availability and if the XPU accelerator is correctly chosen, here is an example: +``` +$ python +>>> import torch; print('torch:', torch.__version__) +torch: 2.10.0+xpu +>>> print('XPU available:', torch.xpu.is_available()) +XPU available: True +>>> from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name) +accelerator: xpu +``` + + +# Huawei Ascend NPU + +DeepSpeed has been verified on the following Huawei Ascend NPU products: +* Atlas 300T A2 + +## Installation steps for Huawei Ascend NPU + +The following steps outline the process for installing DeepSpeed on an Huawei Ascend NPU: +1. Install the Huawei Ascend NPU Driver and Firmware +
+ Click to expand + + Before proceeding with the installation, please download the necessary files from [Huawei Ascend NPU Driver and Firmware](https://www.hiascend.com/en/hardware/firmware-drivers/commercial?product=4&model=11). + + The following instructions below are sourced from the [Ascend Community](https://www.hiascend.com/document/detail/en/canncommercial/700/quickstart/quickstart/quickstart_18_0002.html) (refer to the [Chinese version](https://www.hiascend.com/document/detail/zh/canncommercial/700/quickstart/quickstart/quickstart_18_0002.html)): + + - Execute the following command to install the driver: + ``` + ./Ascend-hdk--npu-driver_x.x.x_linux-{arch}.run --full --install-for-all + ``` + + - Execute the following command to install the firmware: + ``` + ./Ascend-hdk--npu-firmware_x.x.x.x.X.run --full + ``` +
+ +2. Install CANN +
+ Click to expand + + Prior to installation, download the [CANN Toolkit](https://www.hiascend.com/en/software/cann/commercial). + + - Install third-party dependencies. + - Ubuntu (The operations are the same for Debian, UOS20, and Linux.) + ``` + apt-get install -y gcc g++ make cmake zlib1g zlib1g-dev openssl libsqlite3-dev libssl-dev libffi-dev unzip pciutils net-tools libblas-dev gfortran libblas3 + ``` + - openEuler (The operations are the same for EulerOS, CentOS, and BC-Linux.) + ``` + yum install -y gcc gcc-c++ make cmake unzip zlib-devel libffi-devel openssl-devel pciutils net-tools sqlite-devel lapack-devel gcc-gfortran + ``` + - Install the required Python dependencies: + ``` + pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions + ``` + - Install the CANN Toolkit. + ``` + ./Ascend-cann-toolkit_x.x.x_linux-{arch}.run --install + ``` +
+ +3. Install PyTorch \ + `pip install torch torch_npu` + +4. Install DeepSpeed \ + `pip install deepspeed` + +You can view the installation results using the `ds_report` command, Here is an example: +``` +-------------------------------------------------- +DeepSpeed C++/CUDA extension op report +-------------------------------------------------- +NOTE: Ops not installed will be just-in-time (JIT) compiled at + runtime if needed. Op compatibility means that your system + meet the required dependencies to JIT install the op. +-------------------------------------------------- +JIT compiled ops requires ninja +ninja .................. [OKAY] +-------------------------------------------------- +op name ................ installed .. compatible +-------------------------------------------------- +deepspeed_not_implemented [NO] ....... [OKAY] +async_io ............... [NO] ....... [OKAY] +cpu_adagrad ............ [NO] ....... [OKAY] +cpu_adam ............... [NO] ....... [OKAY] +cpu_lion ............... [NO] ....... [OKAY] +fused_adam ............. [NO] ....... [OKAY] +transformer_inference .. [NO] ....... [OKAY] +-------------------------------------------------- +DeepSpeed general environment info: +torch install path ............... ['/root/miniconda3/envs/ds/lib/python3.10/site-packages/torch'] +torch version .................... 2.2.0 +deepspeed install path ........... ['/root/miniconda3/envs/ds/lib/python3.10/site-packages/deepspeed'] +deepspeed info ................... 0.14.4, unknown, unknown +deepspeed wheel compiled w. ...... torch 2.2 +torch_npu install path ........... ['/root/miniconda3/envs/ds/lib/python3.10/site-packages/torch_npu'] +torch_npu version ................ 2.2.0 +ascend_cann version .............. 8.0.RC2.alpha002 +shared memory (/dev/shm) size .... 20.00 GB +``` + +## How to launch DeepSpeed on Huawei Ascend NPU + +To validate the Huawei Ascend NPU availability and if the accelerator is correctly chosen, here is an example(Huawei Ascend NPU detection is automatic starting with DeepSpeed v0.12.6): +``` +>>> import torch +>>> print('torch:',torch.__version__) +torch: 2.2.0 +>>> import torch_npu +>>> print('torch_npu:',torch.npu.is_available(),",version:",torch_npu.__version__) +torch_npu: True ,version: 2.2.0 +>>> from deepspeed.accelerator import get_accelerator +>>> print('accelerator:', get_accelerator()._name) +accelerator: npu +``` + +## Multi-card parallel training using Huawei Ascend NPU + +To perform model training across multiple Huawei Ascend NPU cards using DeepSpeed, see the examples provided in [DeepSpeed Examples](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/training/cifar/cifar10_deepspeed.py). + +# Intel Gaudi +PyTorch models can be run on Intel® Gaudi® AI accelerator using DeepSpeed. Refer to the following user guides to start using DeepSpeed with Intel Gaudi: +* [Getting Started with DeepSpeed](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/Getting_Started_with_DeepSpeed/Getting_Started_with_DeepSpeed.html#getting-started-with-deepspeed) +* [DeepSpeed User Guide for Training](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/DeepSpeed_User_Guide/DeepSpeed_User_Guide.html#deepspeed-user-guide) +* [Optimizing Large Language Models](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/Optimizing_LLM.html#llms-opt) +* [Inference Using DeepSpeed](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/Inference_Using_DeepSpeed.html#deepspeed-inference-user-guide) diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md index 09a8e6d56234..b251485f8988 100755 --- a/docs/_tutorials/advanced-install.md +++ b/docs/_tutorials/advanced-install.md @@ -27,7 +27,7 @@ ds_report ## Pre-install DeepSpeed Ops -**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ pre-compiling any DeepSpeed c++/cuda ops. However, this is not required if using the default mode of JIT compilation of ops. +**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ pre-compiling any DeepSpeed C++/CUDA ops. However, this is not required if using the default mode of JIT compilation of ops. {: .notice--info} Sometimes we have found it useful to pre-install either some or all DeepSpeed @@ -56,16 +56,22 @@ DS_BUILD_FUSED_LAMB=1 pip install deepspeed ``` Available `DS_BUILD` options include: -* `DS_BUILD_OPS` toggles all ops -* `DS_BUILD_CPU_ADAM` builds the CPUAdam op -* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)) -* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op -* `DS_BUILD_SPARSE_ATTN` builds the sparse attention op -* `DS_BUILD_TRANSFORMER` builds the transformer op -* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op -* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op -* `DS_BUILD_UTILS` builds various optimized utilities -* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op +* `DS_BUILD_OPS` toggles all ops. +* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op. +* `DS_BUILD_CCL_COMM` builds the communication collective libs. +* `DS_BUILD_CPU_ADAM` builds the CPUAdam op. +* `DS_BUILD_CPU_LION` builds the CPULion op. +* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/)). +* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)). +* `DS_BUILD_FUSED_LION` builds the FusedLion op. +* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op. +* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op. +* `DS_BUILD_QUANTIZER` builds the quantizer op. +* `DS_BUILD_RANDOM_LTD` builds the random ltd op. +* `DS_BUILD_SPARSE_ATTN` builds the sparse attention op. +* `DS_BUILD_TRANSFORMER` builds the transformer op. +* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op. +* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op. To speed up the build-all process, you can parallelize the compilation process with: @@ -75,10 +81,10 @@ DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option This should complete the full build 2-3 times faster. You can adjust `-j` to specify how many cpu-cores are to be used during the build. In the example it is set to 8 cores. -You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, pytorch, python, etc.) +You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, PyTorch, Python, etc.) ```bash -DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel +DS_BUILD_OPS=1 python -m build --wheel --no-isolation --config-setting="--build-option=build_ext" --config-setting="--build-option=-j8" ``` This will create a pypi binary wheel under `dist`, e.g., ``dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`` and then you can install it directly on multiple machines, in our example: @@ -100,8 +106,8 @@ pip install . For installs spanning multiple nodes we find it useful to install DeepSpeed using the -[install.sh](https://github.com/microsoft/DeepSpeed/blob/master/install.sh) -script in the repo. This will build a python wheel locally and copy it to all +[install.sh](https://github.com/deepspeedai/DeepSpeed/blob/master/install.sh) +script in the repo. This will build a Python wheel locally and copy it to all the nodes listed in your hostfile (either given via `--hostfile`, or defaults to `/job/hostfile`). @@ -112,7 +118,7 @@ extensions will be loaded form that directory. If you use multiple virtual environments this could be a problem, since by default there is only one `torch_extensions` directory, but different virtual environments may use different setups (e.g., different -python or cuda versions) and then the loading of a CUDA extension built by another environment will +Python or CUDA versions) and then the loading of a CUDA extension built by another environment will fail. Therefore, if you need to you can override the default location with the help of the `TORCH_EXTENSIONS_DIR` environment variable. So in each virtual environment you can point it to a unique directory and DeepSpeed will use it to save and load CUDA extensions. @@ -123,6 +129,16 @@ fail. Therefore, if you need to you can override the default location with the h TORCH_EXTENSIONS_DIR=./torch-extensions deepspeed ... ``` +### Conda environment for building from source + +If you encounter difficulties during compilation using the default system environment, you can try the conda environment provided, which includes the necessary compilation toolchain and PyTorch. + +```bash +conda env create -n deepspeed -f environment.yml --force +``` + +and try above install commands after activating it. + ## Building for the correct architectures If you're getting the following error: @@ -130,9 +146,9 @@ If you're getting the following error: ``` RuntimeError: CUDA error: no kernel image is available for execution on the device ``` -when running deepspeed, that means that the cuda extensions weren't built for the card you're trying to use it for. +when running deepspeed, that means that the CUDA extensions weren't built for the card you're trying to use it for. -When building from source deepspeed will try to support a wide range of architectures, but under jit-mode it'll only +When building from source DeepSpeed will try to support a wide range of architectures, but under jit-mode it'll only support the architectures visible at the time of building. You can build specifically for a desired range of architectures by setting a `TORCH_CUDA_ARCH_LIST` env variable: @@ -143,16 +159,43 @@ TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... It will also make the build faster when you only build for a few architectures. -This is also recommended to ensure your exact architecture is used. Due to a variety of technical reasons, a distributed pytorch binary isn't built to fully support all architectures, skipping binary compatible ones, at a potential cost of underutilizing your full card's compute capabilities. To see which architectures get included during the deepspeed build from source - save the log and grep for `-gencode` arguments. +This is also recommended to ensure your exact architecture is used. Due to a variety of technical reasons, a distributed PyTorch binary isn't built to fully support all architectures, skipping binary compatible ones, at a potential cost of underutilizing your full card's compute capabilities. To see which architectures get included during the DeepSpeed build from source - save the log and grep for `-gencode` arguments. + +The full list of Nvidia GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus). + +## CUDA version mismatch + +If you're getting the following error: + +``` +Exception: >- DeepSpeed Op Builder: Installed CUDA version {VERSION} does not match the version torch was compiled with {VERSION}, unable to compile cuda/cpp extensions without a matching cuda version. +``` +You have a misaligned version of CUDA installed compared to the version of CUDA +used to compile Torch. A mismatch in the major version is likely to result in +errors or unexpected behavior. + +The easiest fix for this error is changing the CUDA version installed (check +with `nvcc --version`) or updating the torch version to match the installed +CUDA version (check with `python3 -c "import torch; print(torch.__version__)"`). -The full list of nvidia GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus). +We only require that the major version matches (e.g., 11.1 and 11.8). However, +note that even a mismatch in the minor version _may still_ result in unexpected +behavior and errors, so it's recommended to match both major and minor versions. +When there's a minor version mismatch, DeepSpeed will log a warning. + +If you want to skip this check and proceed with the mismatched CUDA versions, +use the following environment variable, but beware of unexpected behavior: + +```bash +DS_SKIP_CUDA_CHECK=1 +``` ## Feature specific dependencies Some DeepSpeed features require specific dependencies outside the general dependencies of DeepSpeed. * Python package dependencies per feature/op please -see our [requirements directory](https://github.com/microsoft/DeepSpeed/tree/master/requirements). +see our [requirements directory](https://github.com/deepspeedai/DeepSpeed/tree/master/requirements). * We attempt to keep the system level dependencies to a minimum, however some features do require special system-level packages. Please see our `ds_report` tool output to see if you are missing any system-level packages for a given feature. diff --git a/docs/_tutorials/automatic-tensor-parallelism.md b/docs/_tutorials/automatic-tensor-parallelism.md old mode 100644 new mode 100755 index 6991d5caf925..94f08757e111 --- a/docs/_tutorials/automatic-tensor-parallelism.md +++ b/docs/_tutorials/automatic-tensor-parallelism.md @@ -3,10 +3,13 @@ title: "Automatic Tensor Parallelism for HuggingFace Models" tags: inference --- +> **Note:** This tutorial covers AutoTP for **inference**. For **training** with tensor parallelism and ZeRO optimization, see [Automatic Tensor Parallelism (Training)](/tutorials/autotp-training/). + # Contents * [Introduction](#introduction) * [Example Script](#example-script) * [Launching](#launching) + * [T5 11B Inference Performance Comparison](#t5-11b-inference-performance-comparison) * [OPT 13B Inference Performance Comparison](#opt-13b-inference-performance-comparison) * [Supported Models](#supported-models) * [Unsupported Models](#unsupported-models) @@ -65,7 +68,7 @@ With automatic tensor parallelism, we do not need to provide the injection polic # Example Script -We can observe performance improvement with automatic tensor parallelism using the [inference test suite](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py). The script includes per token latency, bandwidth, throughput and memory checks for comparison. See the [README](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation#deepspeed-huggingface-text-generation-examples) for more information. +We can observe performance improvement with automatic tensor parallelism using the [inference test suite](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py). This script is for testing text-generation models and includes per token latency, bandwidth, throughput and memory checks for comparison. See the [README](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/inference/huggingface/text-generation#deepspeed-huggingface-text-generation-examples) for more information. ## Launching @@ -83,19 +86,31 @@ To enable tensor parallelism, you need to use the flag `ds_inference` for the co deepspeed --num_gpus DeepSpeedExamples/inference/huggingface/text-generation/inference-test.py --name --batch_size --test_performance --ds_inference ``` -## OPT 13B Inference Performance Comparison +## T5 11B Inference Performance Comparison The following results were collected using V100 SXM2 32GB GPUs. -### Max New Tokens = 50 +### Latency -| Test | Memory Allocated per GPU | Max Batch Size | Max Throughput per GPU | -| ---------- | -------------------------- | ---------------- | ------------------------ | -| No TP | 23.94 GB | 64 | 18.84 TFlops | -| 2 GPU TP | 12.23 GB | 320 | 27.17 TFlops | -| 4 GPU TP | 6.36 GB | 664 | 27.63 TFlops | +![T5 Latency Graph](/assets/images/auto-tp-chart-latency.png){: .align-center} + +### Throughput + +![T5 Throughput Graph](/assets/images/auto-tp-chart-throughput.png){: .align-center} + +### Memory + +| Test | Memory Allocated per GPU | Max Batch Size | Max Throughput per GPU | +| -------------- | -------------------------- | -------------- | ---------------------- | +| No TP or 1 GPU | 21.06 GB | 64 | 9.29 TFLOPS | +| 2 GPU TP | 10.56 GB | 320 | 13.04 TFLOPS | +| 4 GPU TP | 5.31 GB | 768 | 14.04 TFLOPS | -### Max New Tokens = 1024 +## OPT 13B Inference Performance Comparison + +The following results were collected using V100 SXM2 32GB GPUs. + +![OPT Throughput Graph](/assets/images/auto-tp-chart-opt-throughput.png){: .align-center} | Test | Memory Allocated per GPU | Max Batch Size | Max Throughput per GPU | | ---------- | -------------------------- | ---------------- | ------------------------ | @@ -108,42 +123,62 @@ The following results were collected using V100 SXM2 32GB GPUs. The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet. - albert +- arctic +- baichuan - bert - bigbird_pegasus +- bloom - camembert +- chatglm2 +- chatglm3 +- codegen +- codellama - deberta_v2 - electra - ernie - esm +- falcon +- glm - gpt-j - gpt-neo - gpt-neox - longt5 - luke +- llama +- llama2 - m2m_100 - marian +- mistral +- mixtral +- mpt - mvp - nezha - openai - opt - pegasus - perceiver +- phi - plbart +- qwen +- qwen2 +- qwen2-moe +- qwen2.5 +- qwen3 - reformer - roberta - roformer - splinter +- starcode - t5 - xglm - xlm_roberta - yoso +- yuan # Unsupported Models The following models are not currently supported with automatic tensor parallelism. They may still be compatible with other DeepSpeed features (e.g., kernel injection for Bloom): -- bloom -- codegen - deberta - flaubert - fsmt diff --git a/docs/_tutorials/autotp-training.md b/docs/_tutorials/autotp-training.md new file mode 100644 index 000000000000..7f8d2cbd52df --- /dev/null +++ b/docs/_tutorials/autotp-training.md @@ -0,0 +1,224 @@ +--- +title: "Automatic Tensor Parallelism (Training)" +tags: training tensor-parallelism +--- + +This tutorial covers **Automatic Tensor Parallelism** for combining tensor parallelism with ZeRO optimization during training. For inference-only tensor parallelism, see [Automatic Tensor Parallelism (Inference)](/tutorials/automatic-tensor-parallelism/). + +## Contents +- [Introduction](#introduction) +- [Quick Start](#quick-start) +- [HuggingFace tp_plan Support](#huggingface-tp_plan-support) +- [Custom Layer Specifications](#custom-layer-specifications) +- [Limitations](#limitations) + +## Introduction + +The AutoTP Training API enables hybrid parallelism by combining: +- **Tensor Parallelism (TP)**: Split model weights across GPUs within a node +- **Data Parallelism (DP)**: Replicate model across GPU groups +- **ZeRO Optimization**: Memory-efficient optimizer states (Stage 0, 1, or 2) + +Tensor parallelism (TP) splits the computations and parameters of large layers +across multiple GPUs so each rank holds only a shard of the weight matrix. This +is an efficient way to train large-scale transformer models by reducing per-GPU +memory pressure while keeping the layer math distributed across the TP group. + + +## Quick Start + +### Basic Usage + +AutoTP training can be enabled entirely through the DeepSpeed config. When +`tensor_parallel` is set in the config, `deepspeed.initialize(...)` applies +AutoTP sharding during engine initialization, so the training loop itself does +not change. + +```python +import torch +import deepspeed + +# 1. Create your model +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B") + +# 2. Define the DeepSpeed config with tensor_parallel settings +ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": {"stage": 2}, + "bf16": {"enabled": True}, + "tensor_parallel": {"autotp_size": 4}, +} + +# 3. Initialize DeepSpeed with AutoTP + ZeRO +engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + mpu=mpu # Model parallel unit (optional if you provide tp_group elsewhere) +) + +# 4. Train as usual +for batch in dataloader: + outputs = engine(input_ids=batch["input_ids"], labels=batch["labels"]) + engine.backward(outputs.loss) + engine.step() +``` + +Compatibility note: For backward compatibility, you can still call +`set_autotp_mode(training=True)` and `deepspeed.tp_model_init(...)`, but they +are not required when the DeepSpeed config provides the necessary +`tensor_parallel` settings. + +### Preset-based Sharding + +If your model matches a built-in preset, set `tensor_parallel.preset_model` in the DeepSpeed config: + +```json +{ + "train_batch_size": 8, + "train_micro_batch_size_per_gpu": 1, + "bf16": { "enabled": true }, + "zero_optimization": { "stage": 2 }, + "tensor_parallel": { + "autotp_size": 4, + "preset_model": "llama" + } +} +``` + +For the list of available presets, see [supported models](/code-docs/training#autotp-supported-models). + + + +## HuggingFace tp_plan Support + +Many HuggingFace models (e.g. Llama, Qwen, Gemma2) ship with a built-in +`base_model_tp_plan` in their model config that describes how each layer +should be partitioned for tensor parallelism. DeepSpeed can automatically +detect and use this plan, so you do not need to configure `preset_model` or +`partition_config` for these models. + +When `tensor_parallel` is set in the DeepSpeed config, the initialization +follows this priority: + +1. **Custom `partition_config`** (highest): User-defined regex patterns. +2. **HuggingFace `tp_plan`**: Automatically extracted from + `model._tp_plan` or `model.config.base_model_tp_plan`. +3. **AutoTP heuristics** (lowest): Built-in parser based on module structure. + +For models that define a `tp_plan`, you only need a minimal config: + +```json +{ + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { "stage": 2 }, + "bf16": { "enabled": true }, + "tensor_parallel": { "autotp_size": 4 } +} +``` + +DeepSpeed will read the model's `tp_plan` at initialization and convert it to +internal partition rules. Currently `colwise` and `rowwise` partition types +are supported. Additional types defined by HuggingFace (such as +`colwise_rep`, `local_colwise`, `local_rowwise`, etc.) are not yet handled +and will raise an error if encountered. + +If you need to override the model's built-in `tp_plan`, provide a +`partition_config` in the DeepSpeed config -- it takes precedence. + + +## Custom Patterns + +If you are training a custom model, define regex-based patterns and partition rules in `tensor_parallel.partition_config`: + +```json +{ + "tensor_parallel": { + "autotp_size": 4, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + } +} +``` + +## Custom Layer Specifications + +For models not covered by presets, define custom layer specs: + +```json +{ + "tensor_parallel": { + "autotp_size": 4, + "partition_config": { + "use_default_specs": false, + "layer_specs": [ + { + "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"], + "partition_type": "row" + }, + { + "patterns": [".*\\.[qkv]_proj\\.weight$"], + "partition_type": "column" + }, + { + "patterns": [".*\\.gate_up_proj\\.weight$"], + "partition_type": "column", + "shape": [2, -1], + "partition_dim": 0 + } + ] + } + } +} +``` + +### Fused Layers with Unequal Sub-parameters (GQA) + +For Grouped Query Attention with different Q/K/V sizes: + +```json +{ + "tensor_parallel": { + "partition_config": { + "layer_specs": [ + { + "patterns": [".*\\.qkv_proj\\.weight$"], + "partition_type": "column", + "shape": [[q_size, kv_size, kv_size], -1], + "partition_dim": 0 + } + ] + } + } +} +``` + +## Limitations + +1. **ZeRO Stage 3 not supported**: AutoTP currently only works with ZeRO stages 0, 1, and 2. + +2. **TP size must divide model dimensions**: The tensor parallel size must evenly divide the attention head count and hidden dimensions. + + +## See Also + +- [Automatic Tensor Parallelism (Inference)](/tutorials/automatic-tensor-parallelism/) +- [ZeRO Optimization](/tutorials/zero/) +- [DeepSpeed Configuration](/docs/config-json/) diff --git a/docs/_tutorials/autotuning.md b/docs/_tutorials/autotuning.md index 38648daa89f2..2935f38946ac 100644 --- a/docs/_tutorials/autotuning.md +++ b/docs/_tutorials/autotuning.md @@ -8,23 +8,23 @@ Make sure you've read the DeepSpeed tutorials on [Getting Started](https://www.d One pain point in model training is to figure out good performance-relevant configurations such as micro-batch size to fully utilize the hardware and achieve a high throughput number. This configuration exploring process is commonly done manually but is important since model training is repeated many times and benefits from using a good configuration. Not only is the hand-tuning process time-consuming, but the outcome is hardware-dependent. This means that a good configuration on one hardware might not be the best on another different hardware. The user thus has to hand tune the configuration again. With DeepSpeed, there are more configuration parameters that could potentially affect the training speed, thus making it more tedious to manually tune the configuration. -The DeepSpeed Autotuner mitigates this pain point and automatically discovers the optimal DeepSpeed configuration that delivers good training speed. It not only reduces the time and resources users spend on tuning, but also can discover configurations better than hand-tuned methods. In this tutorial, we showcase the usage and benefits of the autotuning feature in DeepSpeed. For more details, please see the [README.md](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning). +The DeepSpeed Autotuner mitigates this pain point and automatically discovers the optimal DeepSpeed configuration that delivers good training speed. It not only reduces the time and resources users spend on tuning, but also can discover configurations better than hand-tuned methods. In this tutorial, we showcase the usage and benefits of the autotuning feature in DeepSpeed. For more details, please see the [README.md](https://github.com/deepspeedai/DeepSpeed/tree/master/deepspeed/autotuning). ## Tuning scope and strategy The DeepSpeed Autotuner uses model information, system information, and heuristics to efficiently tune system knobs that affect compute and memory efficiencies, such as ZeRO optimization stages, micro-batch sizes, and many other ZeRO optimization configurations. Currently, the DeepSpeed Autotuner tunes ZeRO stages, micro-batch size per GPU, and ZeRO configurations (offloading is not yet supported) on top of other configurations such as optimizer, scheduler, fp16 defined by the user in the DeepSpeed configuration file. -Note that ZeRO stages, micro-batch sizes, and other ZeRO configurations to tune are also configurable and can be overwritten by the user through the DeepSpeed configuration file. See [Configuring Tuning Scope](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning#configuring-tuning-scope) for details. +Note that ZeRO stages, micro-batch sizes, and other ZeRO configurations to tune are also configurable and can be overwritten by the user through the DeepSpeed configuration file. See [Configuring Tuning Scope](https://github.com/deepspeedai/DeepSpeed/tree/master/deepspeed/autotuning#configuring-tuning-scope) for details. ## Ease of use DeepSpeed Autotuning is easy to use, requiring no code change from DeepSpeed users. -Compared to the original training script (`deepspeed your_program.py --deepspeed ds_config.json`), invoking the autotuning feature in DeepSpeed only requires setting an `autotuning` flag after the DeepSpeed launcher (see [Usage](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning#usage) for details), and adding `" autotuning": {"enabled": true}` to the DeepSpeed configuration file. Users can further tailor the autotuning process by changing the autotuning configuration in the DeepSpeed configuration JSON file (See [Autotuning Configuration](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning#autotuning-configuration) for details). +Compared to the original training script (`deepspeed your_program.py --deepspeed ds_config.json`), invoking the autotuning feature in DeepSpeed only requires setting an `autotuning` flag after the DeepSpeed launcher (see [Usage](https://github.com/deepspeedai/DeepSpeed/tree/master/deepspeed/autotuning#usage) for details), and adding `" autotuning": {"enabled": true}` to the DeepSpeed configuration file. Users can further tailor the autotuning process by changing the autotuning configuration in the DeepSpeed configuration JSON file (See [Autotuning Configuration](https://github.com/deepspeedai/DeepSpeed/tree/master/deepspeed/autotuning#autotuning-configuration) for details). ## Example -We demonstrate the usage and benefit of autotuning using the training of a 0.77 billion parameter [GPT2-large model](https://huggingface.co/gpt2-large) from Hugging Face on 16 Nvidia V100 GPUs. For more examples, refer to [autotuning](https://github.com/microsoft/DeepSpeedExamples/tree/master/autotuning) in the DeepSpeedExamples repo. Note that autotuning works with any DeepSpeed-accelerated model training, not limited to Hugging Face models. +We demonstrate the usage and benefit of autotuning using the training of a 0.77 billion parameter [GPT2-large model](https://huggingface.co/gpt2-large) from Hugging Face on 16 Nvidia V100 GPUs. For more examples, refer to [autotuning](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/autotuning) in the DeepSpeedExamples repo. Note that autotuning works with any DeepSpeed-accelerated model training, not limited to Hugging Face models. The model has: @@ -119,7 +119,7 @@ Note that the performance metric used in autotuning is calculated using the timi Tuning completed in 0:27:33.988447. Total number of experiments: 13. -As we can see the DeepSpeed Autotuner can select a better than hand-tuned configuration with a reasonable number of experiments. Examples in [Autotuning Hugging Face Examples](https://github.com/microsoft/DeepSpeedExamples/tree/master/autotuning/hf#autotuning-hugging-face-examples) would demonstrate the effectiveness of autotuning across different models. +As we can see the DeepSpeed Autotuner can select a better than hand-tuned configuration with a reasonable number of experiments. Examples in [Autotuning Hugging Face Examples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/autotuning/hf#autotuning-hugging-face-examples) would demonstrate the effectiveness of autotuning across different models. ### DeepSpeed Autotuning with AzureML diff --git a/docs/_tutorials/azure.md b/docs/_tutorials/azure.md index 6c7cded7b27c..1bbfb687d812 100644 --- a/docs/_tutorials/azure.md +++ b/docs/_tutorials/azure.md @@ -13,10 +13,10 @@ The recommended and simplest method to try DeepSpeed on Azure is through [AzureM For AzureML v1 examples, please take a look at easy-to-use examples for Megatron-DeepSpeed, Transformers and CIFAR training [here](https://github.com/Azure/azureml-examples/tree/main/v1/python-sdk/workflows/train/deepspeed). -> Our [Megatron-DeepSpeed](https://github.com/microsoft/megatron-deepspeed) contains the most up to date [recipe](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azureml) for end-to-end training on AzureML. +> Our [Megatron-DeepSpeed](https://github.com/deepspeedai/megatron-deepspeed) contains the most up to date [recipe](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azureml) for end-to-end training on AzureML. # DeepSpeed on Azure VMs If you don't have access to AzureML or if want to build a custom environments using [Azure virtual machines](https://azure.microsoft.com/en-us/services/virtual-machines/) or Azure VM Scale-Sets ([VMSS](https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/overview)), we are working on easy-to-use cluster setup scripts that will be published in the next few weeks. -If you already have a cluster setup, you can use the [azure recipes](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azure) that can easily be modified to train various model configurations. +If you already have a cluster setup, you can use the [azure recipes](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/azure) that can easily be modified to train various model configurations. diff --git a/docs/_tutorials/bert-finetuning.md b/docs/_tutorials/bert-finetuning.md index 3014be18d682..efb8fa268e29 100755 --- a/docs/_tutorials/bert-finetuning.md +++ b/docs/_tutorials/bert-finetuning.md @@ -10,14 +10,14 @@ In this tutorial we will be adding DeepSpeed to the BingBert model for the SQuAD If you don't already have a copy of the DeepSpeed repository, please clone in now and checkout the DeepSpeedExamples submodule the contains the BingBertSquad -example (DeepSpeedExamples/BingBertSquad) we will be going over in the rest of +example (DeepSpeedExamples/training/BingBertSquad) we will be going over in the rest of this tutorial. ```shell -git clone https://github.com/microsoft/DeepSpeed +git clone https://github.com/deepspeedai/DeepSpeed cd DeepSpeed git submodule update --init --recursive -cd DeepSpeedExamples/BingBertSquad +cd DeepSpeedExamples/training/BingBertSquad ``` ### Pre-requisites diff --git a/docs/_tutorials/bert-pretraining.md b/docs/_tutorials/bert-pretraining.md index cef60540b232..342918de958d 100755 --- a/docs/_tutorials/bert-pretraining.md +++ b/docs/_tutorials/bert-pretraining.md @@ -5,7 +5,7 @@ tags: training pre-training --- **Note:** -On 08/15/2022 we have added another BERT pre-training/fine-tuning example at [github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/bert_with_pile](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/bert_with_pile), which includes a README.md that describes how to use it. Compared to the example described below, the new example in Megatron-DeepSpeed adds supports of ZeRO and tensor-slicing model parallelism (thus support larger model scale), uses a public and richer [Pile dataset](https://github.com/EleutherAI/the-pile) (user can also use their own data), together with some changes to the model architecture and training hyperparameters as described in [this paper](https://arxiv.org/abs/1909.08053). As a result, the BERT models trained by the new example is able to provide better MNLI results than original BERT, but with a slightly different model architecture and larger computation requirements. If you want to train a larger-scale or better quality BERT-style model, we recommend to follow the new example in Megatron-DeepSpeed. If your goal is to strictly reproduce the original BERT model, we recommend to follow the example under DeepSpeedExamples/bing_bert as described below. On the other hand, the tutorial below helps explaining how to integrate DeepSpeed into a pre-training codebase, regardless of which BERT example you use. +On 08/15/2022 we have added another BERT pre-training/fine-tuning example at [github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/bert_with_pile](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/bert_with_pile), which includes a README.md that describes how to use it. Compared to the example described below, the new example in Megatron-DeepSpeed adds supports of ZeRO and tensor-slicing model parallelism (thus support larger model scale), uses a public and richer [Pile dataset](https://github.com/EleutherAI/the-pile) (user can also use their own data), together with some changes to the model architecture and training hyperparameters as described in [this paper](https://arxiv.org/abs/1909.08053). As a result, the BERT models trained by the new example is able to provide better MNLI results than original BERT, but with a slightly different model architecture and larger computation requirements. If you want to train a larger-scale or better quality BERT-style model, we recommend to follow the new example in Megatron-DeepSpeed. If your goal is to strictly reproduce the original BERT model, we recommend to follow the example under DeepSpeedExamples/bing_bert as described below. On the other hand, the tutorial below helps explaining how to integrate DeepSpeed into a pre-training codebase, regardless of which BERT example you use. {: .notice--info} In this tutorial we will apply DeepSpeed to pre-train the BERT @@ -26,7 +26,7 @@ We work from adaptations of [huggingface/transformers](https://github.com/huggingface/transformers) and [NVIDIA/DeepLearningExamples](https://github.com/NVIDIA/DeepLearningExamples). We have forked this repo under -[DeepSpeedExamples/bing_bert](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert) +[DeepSpeedExamples/bing_bert](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/bing_bert) and made several modifications in their script: * We adopted the modeling code from NVIDIA's BERT under `bing_bert/nvidia/`. @@ -360,7 +360,7 @@ the scripts/json configs in our DeepSpeedExamples repo. Below is a table contain summary of the configurations. Specifically see the `ds_train_bert_bsz64k_seq128.sh` and `ds_train_bert_bsz32k_seq512.sh` scripts for more details in -[DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert). +[DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/bing_bert). | Parameters | 128 Sequence | 512 Sequence | diff --git a/docs/_tutorials/cifar-10.md b/docs/_tutorials/cifar-10.md index 74ee04502f18..2bd06abf0e89 100644 --- a/docs/_tutorials/cifar-10.md +++ b/docs/_tutorials/cifar-10.md @@ -8,21 +8,21 @@ If you haven't already, we advise you to first read through the [Getting Started](/getting-started/) guide before stepping through this tutorial. -In this tutorial we will be adding DeepSpeed to CIFAR-10 model, which is small image classification model. +In this tutorial we will be adding DeepSpeed to the CIFAR-10 model, which is a small image classification model. -First we will go over how to run original CIFAR-10. Then we will proceed step-by-step in enabling this model to run with DeepSpeed. +First we will go over how to run the original CIFAR-10 model. Then we will proceed step-by-step in enabling this model to run with DeepSpeed. ## Running Original CIFAR-10 -Original model code from [CIFAR-10 Tutorial](https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py), We've copied this repo under [DeepSpeedExamples/cifar/](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) and made it available as a submodule. To download, execute: +Original model code from the [CIFAR-10 Tutorial](https://github.com/pytorch/tutorials/blob/main/beginner_source/blitz/cifar10_tutorial.py), We've copied this repo under [DeepSpeedExamples/training/cifar/](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/cifar) and made it available as a submodule. To download, execute: ```bash git submodule update --init --recursive ``` -To install requirements for CIFAR-10: +To install the requirements for the CIFAR-10 model: ```bash cd DeepSpeedExamples/cifar pip install -r requirements.txt @@ -82,14 +82,14 @@ The first step to apply DeepSpeed is adding DeepSpeed arguments to CIFAR-10 mode parser=argparse.ArgumentParser(description='CIFAR') - #data - # cuda + # Data. + # Cuda. parser.add_argument('--with_cuda', default=False, action='store_true', help='use CPU in case there\'s no GPU support') parser.add_argument('--use_ema', default=False, action='store_true', help='whether use exponential moving average') - # train + # Train. parser.add_argument('-b', '--batch_size', default=32, type=int, help='mini-batch size (default: 32)') parser.add_argument('-e', '--epochs', default=30, type=int, @@ -97,7 +97,7 @@ The first step to apply DeepSpeed is adding DeepSpeed arguments to CIFAR-10 mode parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher') - # Include DeepSpeed configuration arguments + # Include DeepSpeed configuration arguments. parser = deepspeed.add_config_arguments(parser) args=parser.parse_args() @@ -123,16 +123,16 @@ def initialize(args, collate_fn=None): ``` -Here we initialize DeepSpeed with CIFAR-10 model (`net`), `args`, `parameters` and `trainset`: +Here we initialize DeepSpeed with the CIFAR-10 model (`net`), `args`, `parameters` and `trainset`: ```python parameters = filter(lambda p: p.requires_grad, net.parameters()) args=add_argument() # Initialize DeepSpeed to use the following features - # 1) Distributed model - # 2) Distributed data loader - # 3) DeepSpeed optimizer + # 1) Distributed model. + # 2) Distributed data loader. + # 3) DeepSpeed optimizer. model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=args, model=net, model_parameters=parameters, training_data=trainset) ``` @@ -155,7 +155,7 @@ The `model` returned by `deepspeed.initialize` is the _DeepSpeed Model Engine_ t ```python for i, data in enumerate(trainloader): - # get the inputs; data is a list of [inputs, labels] + # Get the inputs; data is a list of [inputs, labels]. inputs = data[0].to(model_engine.device) labels = data[1].to(model_engine.device) @@ -206,13 +206,13 @@ The next step to use DeepSpeed is to create a configuration JSON file (ds_config ### Run CIFAR-10 Model with DeepSpeed Enabled -To start training CIFAR-10 model with DeepSpeed applied, execute the following command, it will use all detected GPUs by default. +To start training the CIFAR-10 model with DeepSpeed applied, execute the following command, it will use all detected GPUs by default. ```bash deepspeed cifar10_deepspeed.py --deepspeed_config ds_config.json ``` -DeepSpeed usually prints more training details for user to monitor, including training settings, performance statistics and loss trends. +DeepSpeed usually prints more training details for the user to monitor, including training settings, performance statistics and loss trends. ``` deepspeed.pt cifar10_deepspeed.py --deepspeed_config ds_config.json Warning: Permanently added '[192.168.0.22]:42227' (ECDSA) to the list of known hosts. diff --git a/docs/_tutorials/comms-logging.md b/docs/_tutorials/comms-logging.md index b6a352b60f68..c4f6141a5b6c 100644 --- a/docs/_tutorials/comms-logging.md +++ b/docs/_tutorials/comms-logging.md @@ -13,7 +13,7 @@ In this tutorial, we introduce DeepSpeed communication logging and provide examp NOTE: All logging communication calls are synchronized in order to provide accurate timing information. This may hamper performance if your model heavily uses asynchronous communication operations. -Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. Each communication operation can all be directly printed to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` in the client code at the completion of training, an epoch, after N training iterations, etc. +Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. Each communication operation can all be directly printed to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` or `deepspeed.com.log_summary(show_straggler=True)` in the client code at the completion of training, an epoch, after N training iterations, etc. ## Usage @@ -64,7 +64,7 @@ The steps to add DeepSpeed communication log summaries are as follows: 2. (Optional) If your application contains `torch.distributed` calls that you wish to log, import `deepspeed.comm` package and modify `torch.distributed` calls to use `deepspeed.comm` (Note: The `deepspeed.comm` collective and pt2pt APIs exactly match `torch.distributed`) 3. Call `deepspeed.comm.log_summary` -For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) example: +For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/cifar) example: ```python # Step 2: (Optional) Import deepspeed.comm @@ -114,3 +114,14 @@ broadcast | [Caller Func: _broadcast_model] reduce_scatter_tensor | [Caller Func: reduce_scatter_fn] 678.86 MB 80 1527.17 13.94 1211.75 1136.01 ``` + +Straggler effect can be shown by supplying optional argument `show_straggler=True` to `deepspeed.comm.log_summary()` call. Straggler effect is defined as the time a rank waits for the slowest rank to start communication. For each collective, `log_summary` would get the minimum collective time among all ranks, compute straggler effect as follows: + +``` +straggler = sum(t_collectives - allreduce(t_collectives, MIN)) +``` + +Print straggler effect with the following `log_summary` call in the example above: +``` + dist.log_summary(show_straggler=True) +``` diff --git a/docs/_tutorials/curriculum-learning.md b/docs/_tutorials/curriculum-learning.md index 817bf622e851..0b74945d3715 100644 --- a/docs/_tutorials/curriculum-learning.md +++ b/docs/_tutorials/curriculum-learning.md @@ -8,7 +8,7 @@ On 12/12/2022, we released DeepSpeed Data Efficiency Library which provides a mo {: .notice--warning} **Note:** -This tutorial was updated on 10/29/2021. Changes include: 1) A more detailed tuning strategy. 2) Pipeline parallelism support. 3) Token-based learning rate decay. 4) A new GPT-2 example at [github.com/microsoft/Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed). See details below. +This tutorial was updated on 10/29/2021. Changes include: 1) A more detailed tuning strategy. 2) Pipeline parallelism support. 3) Token-based learning rate decay. 4) A new GPT-2 example at [github.com/deepspeedai/Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed). See details below. {: .notice--info} In this tutorial, we introduce DeepSpeed's curriculum learning-based data pipeline, which presents easier or simpler examples earlier during training. By enabling stable training with 8x/4x larger batch size/learning rate (whereas the baseline approach struggles with training divergence), we observe that curriculum learning (based on sequence length) provides stable and 3.3x faster GPT-2 pre-training (tested on 117M and 1.5B parameters), together with better token-wise convergence speed and zero-shot WikiText-103/LAMBADA evaluation results. In addition, since curriculum learning only affects the data pipeline, its benefit is complementary to many DeepSpeed features and other system optimization techniques. For example, curriculum learning is compatible with DeepSpeed's [ZeRO Redundancy Optimizer](/tutorials/zero/), [ZeRO-Offload](/tutorials/zero-offload/), and [3D Parallelism](/tutorials/pipeline/). @@ -37,6 +37,7 @@ Curriculum learning can be used by setting the `curriculum_learning` key in the "loss_scale": 0, "loss_scale_window": 1000, "hysteresis": 2, + "consecutive_hysteresis": false, "min_loss_scale": 1 }, "curriculum_learning": { @@ -113,17 +114,17 @@ After the update on 10/29/2021, now there are two curriculum learning examples f We provide two curriculum learning examples for Megatron-LM GPT-2 pre-training: -The first one is at [Megatron-DeepSpeed/tree/main/examples/curriculum_learning](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/curriculum_learning). This integration is based on a newer Megatron-LM fork, and only this curriculum learning example supports pipeline parallelism. However, as of 10/29/2021, we haven't verified ZeRO-2 and ZeRO-3 on this fork. Overall, we highly recommend you to use this example if your model does not require ZeRO-2/3. +The first one is at [Megatron-DeepSpeed/tree/main/examples_deepspeed/curriculum_learning](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/curriculum_learning). This integration is based on a newer Megatron-LM fork, and only this curriculum learning example supports pipeline parallelism. However, as of 10/29/2021, we haven't verified ZeRO-2 and ZeRO-3 on this fork. Overall, we highly recommend you to use this example if your model does not require ZeRO-2/3. -The second one is at [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning). This integration is based on an older Megatron-LM hard copy that we will eventually deprecate and this curriculum learning example does not support pipeline parallelism. We recommend you to ONLY use this example if your model requires ZeRO-2/3. +The second one is at [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning). This integration is based on an older Megatron-LM hard copy that we will eventually deprecate and this curriculum learning example does not support pipeline parallelism. We recommend you to ONLY use this example if your model requires ZeRO-2/3. Besides the DeepSpeed curriculum learning json configurations described above, there are some other necessary changes on the user side to integrate curriculum learning: ### 2.1 Training data truncation -To enable `seqlen`-based curriculum learning, we need to add the functionality of training data truncation based on the given curriculum sequence length. For the case without pipeline parallelism, it is necessary to add a `curriculum_seqlen` argument in the model's forward pass and use it to perform training data sequence length truncation. For Megatron-LM GPT-2 pre-training, we implement this in `forward()` in [megatron/model/gpt2_model.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py) and in `forward_step()` in [pretrain_gpt2.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py). +To enable `seqlen`-based curriculum learning, we need to add the functionality of training data truncation based on the given curriculum sequence length. For the case without pipeline parallelism, it is necessary to add a `curriculum_seqlen` argument in the model's forward pass and use it to perform training data sequence length truncation. For Megatron-LM GPT-2 pre-training, we implement this in `forward()` in [megatron/model/gpt2_model.py](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py) and in `forward_step()` in [pretrain_gpt2.py](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py). -For the case with pipeline parallelism, due to DeepSpeed engine limitations we cannot inject the `curriculum_seqlen` argument in the forward pass. Instead, we create a duplicate of `deepspeed.runtime.data_pipeline.curriculum_scheduler` on the user side, and use it to retrieve the `curriculum_seqlen`. This implementation can be found in [megatron/training.py](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/training.py). +For the case with pipeline parallelism, due to DeepSpeed engine limitations we cannot inject the `curriculum_seqlen` argument in the forward pass. Instead, we create a duplicate of `deepspeed.runtime.data_pipeline.curriculum_scheduler` on the user side, and use it to retrieve the `curriculum_seqlen`. This implementation can be found in [megatron/training.py](https://github.com/deepspeedai/Megatron-DeepSpeed/blob/main/megatron/training.py). ### 2.2 Disable batch size warmup (`--rampup-batch-size`) In our [paper](https://arxiv.org/abs/2108.06084) section 5.4 we demonstrate that curriculum learning (`seqlen`-based) provides much better training stability than the batch size warmup technique introduced by Open AI GPT-3. So when using curriculum learning you need to remove the `--rampup-batch-size` config in your training script. It's not recommended using both curriculum learning and batch size warmup, because both of them reduce the number of tokens in a batch. Another related change you might want is to increase your micro batch size, since without batch size warmup your batch size will be fixed now. diff --git a/docs/_tutorials/data-efficiency.md b/docs/_tutorials/data-efficiency.md index 329e3bb89e2f..b49974f1fa78 100644 --- a/docs/_tutorials/data-efficiency.md +++ b/docs/_tutorials/data-efficiency.md @@ -20,18 +20,18 @@ Curriculum learning has been successfully applied to various training tasks (see ### 1.3 How to use Curriculum Learning #### 1.3.1 GPT-3 and BERT pretraining -The `examples/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/microsoft/Megatron-DeepSpeed) includes our examples of how to apply curriculum learning to GPT-3 and BERT pretraining. There are 3 steps: data analysis, pretraining, and eval/finetuning. +The `examples_deepspeed/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/deepspeedai/Megatron-DeepSpeed) includes our examples of how to apply curriculum learning to GPT-3 and BERT pretraining. There are 3 steps: data analysis, pretraining, and eval/finetuning. **Data analysis:** Curriculum learning requires a data analysis before pretraining that calculate the difficulty of each data sample (based on the metric provided by user), and build an index that map difficulty value to corresponding data samples. (There are exceptions: for example the truncation-based sequence length metric can be achieved by data postprocessing without data analysis.) We provide a data analyzer to perform the offline CPU-only data analysis. -`examples/data_efficiency/gpt/ds_analyze_*.sh` and `examples/data_efficiency/bert/ds_analyze_*.sh` are example scripts for GPT-3 and BERT's data analysis. Our data analyzer employs a simple Map-Reduce scheme. First, at the Map stage the `ds_analyze_*_data_map.sh` is used to split the dataset and compute the difficulty value for each data sample. User would need to provide a function to compute the metric (we implement ours in `examples/data_efficiency/analyze_data.py`), the raw training dataset, and other configurations such as number of CPU nodes and number of threads per node. Then the data analyzer will automatically splits the dataset based on number of workers, compute the difficulty values in a batched fashion, and write the results to two indexes: one index maps each data sample to its difficulty value, and another index maps each distinct difficulty value to the corresponding samples. Second, at the Reduce stage the `ds_analyze_*_data_reduce.sh` is used to merge the index files produced by all workers. One thing to note is that in order to enable speedup by distribution yet still being able to merge all the output, the Map stage will potentially generate a lot of output files, which is proportional to number of CPU nodes, number of threads per node, and number of possible metric values. Thus to avoid generating too much output files, we recommend to start with a smaller number of nodes/threads (in the output log we provide an estimate required time for users to judge if they want to increase number of workers), and we recommend to limit number of possible difficulty values when designing your difficulty metric (our experience shows that a few thousands of distinct values is already sufficient to enjoy the benefit of curriculum learning). +`examples_deepspeed/data_efficiency/gpt/ds_analyze_*.sh` and `examples_deepspeed/data_efficiency/bert/ds_analyze_*.sh` are example scripts for GPT-3 and BERT's data analysis. Our data analyzer employs a simple Map-Reduce scheme. First, at the Map stage the `ds_analyze_*_data_map.sh` is used to split the dataset and compute the difficulty value for each data sample. User would need to provide a function to compute the metric (we implement ours in `examples_deepspeed/data_efficiency/analyze_data.py`), the raw training dataset, and other configurations such as number of CPU nodes and number of threads per node. Then the data analyzer will automatically splits the dataset based on number of workers, compute the difficulty values in a batched fashion, and write the results to two indexes: one index maps each data sample to its difficulty value, and another index maps each distinct difficulty value to the corresponding samples. Second, at the Reduce stage the `ds_analyze_*_data_reduce.sh` is used to merge the index files produced by all workers. One thing to note is that in order to enable speedup by distribution yet still being able to merge all the output, the Map stage will potentially generate a lot of output files, which is proportional to number of CPU nodes, number of threads per node, and number of possible metric values. Thus to avoid generating too much output files, we recommend to start with a smaller number of nodes/threads (in the output log we provide an estimate required time for users to judge if they want to increase number of workers), and we recommend to limit number of possible difficulty values when designing your difficulty metric (our experience shows that a few thousands of distinct values is already sufficient to enjoy the benefit of curriculum learning). -**Pretraining** `examples/data_efficiency/gpt/pretrain` and `examples/data_efficiency/bert/pretrain` include the example pretraining scripts with curriculum learning feature. Several changes are needed to enable curriculum learning during pretraining: (1) User need to provide a DeepSpeed json config file which includes configurations for curriculum learning (see [list of configuration](/docs/config-json/#data-efficiency) for details). We provide tested example configurations in `examples/data_efficiency/gpt/pretrain/ds_pretrain_gpt_1.3B_dense_run.sh` and `examples/data_efficiency/bert/pretrain/ds_pretrain_bert_336M_run.sh`. (2) When initializing the DeepSpeed engine via `deepspeed.initialize`, user needs to provide the train dataset and use the dataloader returned by the initialization (this dataloader includes the curriculum learning capability). We provide an example implementation of this change in `megatron/training.py` function `setup_model_and_optimizer`. (3) If the curriculum learning metric requires data postprocessing (such as truncation-based sequence length), user needs to use the DeepSpeed engine's `set_data_post_process_func` API to provide the postprocessing function. We provide an example implementation of this change in `megatron/training.py`, `pretrain_bert.py`, and `pretrain_gpt.py`. (4) If the curriculum learning metric requires a custom scheduling strategy (the pacing function), user needs to use the DeepSpeed engine's `set_custom_curriculum_learning_schedule` API to provide the function to update the max accepted difficulty during training. DeepSpeed engine will provide a global train step input to this callback function. +**Pretraining** `examples_deepspeed/data_efficiency/gpt/pretrain` and `examples_deepspeed/data_efficiency/bert/pretrain` include the example pretraining scripts with curriculum learning feature. Several changes are needed to enable curriculum learning during pretraining: (1) User need to provide a DeepSpeed json config file which includes configurations for curriculum learning (see [list of configuration](/docs/config-json/#data-efficiency) for details). We provide tested example configurations in `examples_deepspeed/data_efficiency/gpt/pretrain/ds_pretrain_gpt_1.3B_dense_run.sh` and `examples_deepspeed/data_efficiency/bert/pretrain/ds_pretrain_bert_336M_run.sh`. (2) When initializing the DeepSpeed engine via `deepspeed.initialize`, user needs to provide the train dataset and use the dataloader returned by the initialization (this dataloader includes the curriculum learning capability). We provide an example implementation of this change in `megatron/training.py` function `setup_model_and_optimizer`. (3) If the curriculum learning metric requires data postprocessing (such as truncation-based sequence length), user needs to use the DeepSpeed engine's `set_data_post_process_func` API to provide the postprocessing function. We provide an example implementation of this change in `megatron/training.py`, `pretrain_bert.py`, and `pretrain_gpt.py`. (4) If the curriculum learning metric requires a custom scheduling strategy (the pacing function), user needs to use the DeepSpeed engine's `set_custom_curriculum_learning_schedule` API to provide the function to update the max accepted difficulty during training. DeepSpeed engine will provide a global train step input to this callback function. -**Eval/finetuning** `examples/data_efficiency/gpt/eval/` and `examples/data_efficiency/bert/finetune` include the example scripts for GPT-3 model's zero-/few-shot evaluation and BERT model's finetuning. Our [paper](https://arxiv.org/abs/2212.03597) includes the reference eval/finetune results if you follow our example scripts to perform the pretraining/eval/finetuning. +**Eval/finetuning** `examples_deepspeed/data_efficiency/gpt/eval/` and `examples_deepspeed/data_efficiency/bert/finetune` include the example scripts for GPT-3 model's zero-/few-shot evaluation and BERT model's finetuning. Our [paper](https://arxiv.org/abs/2212.03597) includes the reference eval/finetune results if you follow our example scripts to perform the pretraining/eval/finetuning. #### 1.3.2 GPT-2 finetuning -The `data_efficiency/gpt_finetuning` directory in our [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples) includes our examples of how to apply curriculum learning to GPT-2 finetuning. `data_efficiency/gpt_finetuning/finetune/ds_finetune_gpt2_run.sh` is the example finetuning script. For CL metrics that require data analysis (e.g., the vocabulary rarity metric), you need to first use ```data_efficiency/gpt_finetuning/finetune/ds_analyze_gpt_data_*``` to analyze and index the dataset, similar to the GPT-3 pre-training case described above in 1.3.1. +The `data_efficiency/gpt_finetuning` directory in our [DeepSpeedExamples repo](https://github.com/deepspeedai/DeepSpeedExamples) includes our examples of how to apply curriculum learning to GPT-2 finetuning. `data_efficiency/gpt_finetuning/finetune/ds_finetune_gpt2_run.sh` is the example finetuning script. For CL metrics that require data analysis (e.g., the vocabulary rarity metric), you need to first use ```data_efficiency/gpt_finetuning/finetune/ds_analyze_gpt_data_*``` to analyze and index the dataset, similar to the GPT-3 pre-training case described above in 1.3.1. ## 2. Random layerwise token dropping (random-LTD) @@ -44,14 +44,14 @@ When you want to pretrain/fine-tune a transformer-based model, it is always a go ### 2.3 How to use random-LTD #### 2.3.1 GPT-3 and BERT pretraining -The `examples/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/microsoft/Megatron-DeepSpeed) includes our examples of how to apply random-LTD to GPT-3 and BERT pretraining. +The `examples_deepspeed/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/deepspeedai/Megatron-DeepSpeed) includes our examples of how to apply random-LTD to GPT-3 and BERT pretraining. -`examples/data_efficiency/gpt/pretrain` and `examples/data_efficiency/bert/pretrain` include the example pretraining scripts with random-LTD feature. Several changes are needed to enable random-LTD during pretraining: (1) User need to provide a DeepSpeed json config file which includes configurations for random-LTD (see [list of configuration](/docs/config-json/#data-efficiency) for details). We provide tested example configurations in `examples/data_efficiency/gpt/pretrain/ds_pretrain_gpt_1.3B_dense_run.sh` and `examples/data_efficiency/bert/pretrain/ds_pretrain_bert_336M_run.sh`. (2) After initializing the DeepSpeed engine via `deepspeed.initialize`, user needs to use the `convert_to_random_ltd` API to convert and wrap the model layers in order to enable the random-LTD feature. We provide an example implementation of this change in `megatron/training.py` function `setup_model_and_optimizer`. (3) In order for random-LTD to understand the input argument mapping of the forward function, user need to change all the input arguments (except the hidden_states input) into keyword/named argument. For example, in `megatron/model/transformer.py` we changed the forward function from `def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, layer_past=None, get_key_value=False):` to `def forward(self, hidden_states, attention_mask=None, encoder_output=None, enc_dec_attn_mask=None, layer_past=None, get_key_value=False):`. (4) When saving model checkpoints, (especially if the state dictionary has non-traditional structure) user needs to use the `remove_random_ltd_state_dict` API to convert the random-LTD-wrapped layers back to original model layers. We provide an example implementation of this change in `megatron/model/language_model.py`. +`examples_deepspeed/data_efficiency/gpt/pretrain` and `examples_deepspeed/data_efficiency/bert/pretrain` include the example pretraining scripts with random-LTD feature. Several changes are needed to enable random-LTD during pretraining: (1) User need to provide a DeepSpeed json config file which includes configurations for random-LTD (see [list of configuration](/docs/config-json/#data-efficiency) for details). We provide tested example configurations in `examples_deepspeed/data_efficiency/gpt/pretrain/ds_pretrain_gpt_1.3B_dense_run.sh` and `examples_deepspeed/data_efficiency/bert/pretrain/ds_pretrain_bert_336M_run.sh`. (2) After initializing the DeepSpeed engine via `deepspeed.initialize`, user needs to use the `convert_to_random_ltd` API to convert and wrap the model layers in order to enable the random-LTD feature. We provide an example implementation of this change in `megatron/training.py` function `setup_model_and_optimizer`. (3) In order for random-LTD to understand the input argument mapping of the forward function, user need to change all the input arguments (except the hidden_states input) into keyword/named argument. For example, in `megatron/model/transformer.py` we changed the forward function from `def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, layer_past=None, get_key_value=False):` to `def forward(self, hidden_states, attention_mask=None, encoder_output=None, enc_dec_attn_mask=None, layer_past=None, get_key_value=False):`. (4) When saving model checkpoints, (especially if the state dictionary has non-traditional structure) user needs to use the `remove_random_ltd_state_dict` API to convert the random-LTD-wrapped layers back to original model layers. We provide an example implementation of this change in `megatron/model/language_model.py`. For eval/finetuning of the pretrained model, see [previous section](#131-gpt-3-and-bert-pretraining) about how to use our example scripts. #### 2.3.2 GPT-2 and ViT finetuning -The `data_efficiency` directory in our [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples) includes our examples of how to apply random-LTD to GPT-2 and ViT finetuning. +The `data_efficiency` directory in our [DeepSpeedExamples repo](https://github.com/deepspeedai/DeepSpeedExamples) includes our examples of how to apply random-LTD to GPT-2 and ViT finetuning. Just like pretraining case, similar changes are required to enable random-LTD for finetuning: (1) DeepSpeed json config file. (2) Use the `convert_to_random_ltd` API to convert and wrap the model layers. (3) When saving model checkpoints, use the `remove_random_ltd_state_dict` API to convert the random-LTD-wrapped layers back to original model layers. @@ -85,16 +85,16 @@ And the reference final result is: ```shell For run_cifar.sh: -13 epoch at time 480.6546013355255s | researved_length 197 +13 epoch at time 480.6546013355255s | reserved_length 197 iter 5474 | LR [0.0001]| val_acc 97.97000122070312 | layer_token 305784192 ``` ## 3. Composing curriculum learning and random-LTD to achieve more ### 3.1 GPT-3 and BERT pretraining -The `examples/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/microsoft/Megatron-DeepSpeed) includes our examples of how to compose curriculum learning random-LTD, and apply both of them to GPT-3 and BERT pretraining. +The `examples_deepspeed/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/deepspeedai/Megatron-DeepSpeed) includes our examples of how to compose curriculum learning random-LTD, and apply both of them to GPT-3 and BERT pretraining. The changes needed are the same as described in previous two sections, since DeepSpeed Data Efficiency already handles the complexity when composing the two techniques. However, one thing to note is that since both random-LTD and some of the curriculum learning metrics will change the sequence length, it could require some extra code to calculate the effective sequence length at each step. We provide an example implementation of this change in `megatron/training.py` function `train` where we calculate the `actual_seq_length`. #### 3.2 GPT-2 finetuning -The `data_efficiency/gpt_finetuning` directory in our [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples) includes our examples of how to compose curriculum learning random-LTD for GPT-2 finetuning. `data_efficiency/gpt_finetuning/finetune/ds_finetune_gpt2_run.sh` is the example finetuning script. +The `data_efficiency/gpt_finetuning` directory in our [DeepSpeedExamples repo](https://github.com/deepspeedai/DeepSpeedExamples) includes our examples of how to compose curriculum learning random-LTD for GPT-2 finetuning. `data_efficiency/gpt_finetuning/finetune/ds_finetune_gpt2_run.sh` is the example finetuning script. diff --git a/docs/_tutorials/datastates-async-checkpointing.md b/docs/_tutorials/datastates-async-checkpointing.md new file mode 100644 index 000000000000..db4064935687 --- /dev/null +++ b/docs/_tutorials/datastates-async-checkpointing.md @@ -0,0 +1,55 @@ +--- +title: "DataStates-LLM Checkpointing Engine" +tags: asynchronous checkpointing for minimizing I/O overheads. +--- +This tutorial will show how to use [DataStates-LLM](https://github.com/DataStates/datastates-llm) for asynchronous checkpointing. DataStates-LLM introduces a lazy asynchronous checkpointing mechanism tailored for LLMs, aiming to minimize I/O overhead and enhance training efficiency. This tutorial provides a guide on integrating DataStates-LLM with the DeepSpeed framework. + +## Overview of DataStates-LLM + +DataStates-LLM is designed to address the challenges of frequent checkpointing in LLM training by introducing a lazy asynchronous multi-level approach. It leverages the immutability of model parameters and optimizer states during forward and backward passes to perform non-blocking data transfers, thereby reducing interference with the training process. This method has demonstrated up to 48x faster checkpointing and 2.2x faster end-to-end training times compared to traditional approaches as outlined in [DataStates-LLM: Lazy Asynchronous Checkpointing for Large Language Models](https://arxiv.org/abs/2406.10707). + +## Prerequisites + +Before integrating DataStates-LLM with DeepSpeed, ensure the following: + +- **DeepSpeed Installation**: DeepSpeed should be installed in your environment. If not, refer to the [DeepSpeed Getting Started Guide](https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/getting-started.md) for installation instructions. + +- **DataStates-LLM Repository**: Access the DataStates-LLM source code from its [GitHub repository](https://github.com/DataStates/datastates-llm) and follow the installation instructions provided therein. + +## Configuring DeepSpeed for DataStates-LLM + +To enable DataStates-LLM's asynchronous checkpointing within DeepSpeed, please modify the `deepspeed_config.json` file to include specific configurations under the `datastates_ckpt` section. Below is an example configuration: + +```json +{ + // ... other DeepSpeed configuration options + "datastates_ckpt": { + "host_cache_size": 16 + } +} +``` + +### Configuration Parameters + +- **`host_cache_size`**: Specifies the amount of pinned host memory (in gigabytes) reserved for asynchronous checkpoint flushing. Adjust this value based on your system's memory capacity and the size of your model checkpoints. + +## Implementing DataStates-LLM in Your Training Script + +After enabling datastates checkpointing the `deepspeed_config.json`, the frequency of checkpointing can be configured by specifying the number of iterations after which the checkpoints should be captured using command-line parameter ` --save-interval`. + +## Limitations and Ongoing Work + +1. DataStates-LLM currently only supports the CUDA runtime on Nvidia-based GPUs. + + +2. DataStates-LLM has only been tested with ZeRO stage-1 without offloading to any other tiers. + + +3. While the checkpoint layout of datastates matches Huggingface's [safetensor](https://huggingface.co/docs/safetensors/) format, due to pickled objects required by DeepSpeed during restart, it is not fully compatible with safetensor library yet. + +4. DataStates-LLM does not yet support universal or elastic checkpointing. + + +## Questions and Support + +Please use the [DataStates-LLM Github repository](https://github.com/DataStates/datastates-llm) for any questions, issues, or feature requests. diff --git a/docs/_tutorials/deepnvme.md b/docs/_tutorials/deepnvme.md new file mode 100644 index 000000000000..a6d4545815dc --- /dev/null +++ b/docs/_tutorials/deepnvme.md @@ -0,0 +1,297 @@ +--- +title: "DeepNVMe" +tags: training inference IO large-model +--- +This tutorial will show how to use [DeepNVMe](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md) for data transfers between persistent storage and tensors residing in host or device memory. DeepNVMe improves the performance and efficiency of I/O operations in Deep Learning applications through powerful optimizations built on Non-Volatile Memory Express (NVMe) Solid State Drives (SSDs), Linux Asynchronous I/O (`libaio`), and NVIDIA Magnum IOTM GPUDirect® Storage (GDS). + +## Requirements +Ensure your environment is properly configured to use DeepNVMe. First, you need to install DeepSpeed version >= [0.15.0](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.15.0). Next, ensure that the DeepNVMe operators are available in the DeepSpeed installation. The `async_io` operator is required for any DeepNVMe functionality, while the `gds` operator is required only for GDS functionality. You can confirm availability of each operator by inspecting the output of `ds_report` to check that compatible status is [OKAY]. Below is a snippet of `ds_report` output confirming the availability of both `async_io` and `gds` operators. + +![deepnvme_ops_report](/assets/images/deepnvme_ops_report.png) + +If `async_io` operator is unavailable, you will need to install the appropriate `libaio` library binaries for your Linux flavor. For example, Ubuntu users will need to run `apt install libaio-dev`. In general, you should carefully inspect `ds_report` output for helpful tips such as the following: + +```bash +[WARNING] async_io requires the dev libaio .so object and headers but these were not found. +[WARNING] async_io: please install the libaio-dev package with apt +[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found. +``` + +To enable `gds` operator, you will need to install NVIDIA GDS by consulting the appropriate guide for [bare-metal systems](https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/index.html) or Azure VMs (coming soon). + + +## Creating DeepNVMe Handles +DeepNVMe functionality can be accessed through two abstractions: `aio_handle` and `gds_handle`. The `aio_handle` is usable on both host and device tensors. while `gds_handle` works only on CUDA tensors, but is more efficient. The first step to use DeepNVMe is to create a desired handle. `aio_handle` requires `async_io` operator, while `gds_handle` requires both `async_io` and `gds` operators. The following snippets illustrate `aio_handle` and `gds_handle` creation respectively. + +```python +### Create aio_handle +from deepspeed.ops.op_builder import AsyncIOBuilder +aio_handle = AsyncIOBuilder().load().aio_handle() +``` + +```python +### Create gds_handle +from deepspeed.ops.op_builder import GDSBuilder +gds_handle = GDSBuilder().load().gds_handle() +``` + +For simplicity, the above examples illustrate handle creation using default parameters. We expect that handles created with default parameters to provide good performance in most environments. However, you can see [below](#advanced-handle-creation) for advanced handle creation. + +## Using DeepNVMe Handles +`aio_handle` and `gds_handle` provide identical APIs for storing tensors to files or loading tensors from files. A common feature of these APIs is that they take a tensor and a file path as arguments for the desired I/O operation. For best performance, pinned device or host tensors should be used for I/O operations (see [here](#pinned-tensors) for details). For brevity, this tutorial will use `aio_handle` for illustration, but keep in mind that `gds_handle` works similarly. + +You can see the available APIs in a Python shell via tab completion on an `aio_handle` object . This is illustrated using tab completion of `h.`. + +```bash +>python +Python 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] on linux +Type "help", "copyright", "credits" or "license" for more information. +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle() +>>> h. +h.async_pread( h.free_cpu_locked_tensor( h.get_overlap_events( h.get_single_submit( h.new_cpu_locked_tensor( h.pwrite( h.sync_pread( h.wait( +h.async_pwrite( h.get_block_size( h.get_queue_depth( h.get_intra_op_parallelism( h.pread( h.read( h.sync_pwrite( h.write( +``` +The APIs of interest for performing I/O operations are those named with `pread` and `pwrite` substrings. For brevity, we will focus on the file write APIs, namely `sync_pwrite`, `async_pwrite`, and `pwrite`. We will discuss only `sync_pwrite` and `async_pwrite` below because they are specializations of `pwrite`. + +### Blocking File Write +`sync_pwrite` provides the standard blocking semantics of Python file write. The example below illustrates using `sync_pwrite` to store a 1GB CUDA tensor to a local NVMe file. + +```bash +>>> import os +>>> os.path.isfile('/local_nvme/test_1GB.pt') +False +>>> import torch +>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda() +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle() +>>> h.sync_pwrite(t,'/local_nvme/test_1GB.pt') +>>> os.path.isfile('/local_nvme/test_1GB.pt') +True +>>> os.path.getsize('/local_nvme/test_1GB.pt') +1073741824 + +``` + +### Non-Blocking File Write +An important DeepNVMe optimization is the non-blocking I/O semantics which enables Python threads to overlap computations with I/O operations. `async_pwrite` provides the non-blocking semantics for file writes. The Python thread can later use `wait()` to synchronize with the I/O operation. `async_write` can also be used to submit multiple back-to-back non-blocking I/O operations, of which can then be later blocked on using a single `wait()`. The example below illustrates using `async_pwrite` to store a 1GB CUDA tensor to a local NVMe file. + +```bash +>>> import os +>>> os.path.isfile('/local_nvme/test_1GB.pt') +False +>>> import torch +>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda() +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle() +>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt') +>>> h.wait() +1 +>>> os.path.isfile('/local_nvme/test_1GB.pt') +True +>>> os.path.getsize('/local_nvme/test_1GB.pt') +1073741824 +``` + +Warning for non-blocking I/O operations: To avoid data races and corruptions, `.wait()` must be carefully used to serialize the writing of source tensors, and the reading of destination tensors. For example, the following update of `t` during a non-blocking file write is unsafe and could corrupt `/local_nvme/test_1GB.pt`. + +```bash +>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda() +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle() +>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt') +>>> t += 1 # <--- Data race; avoid by preceding with `h.wait()` +``` + +Similar safety problems apply to reading the destination tensor of a non-blocking file read without `.wait()` synchronization. + + +### Parallel File Write +An important DeepNVMe optimization is the ability to parallelize individual I/O operations. This optimization is enabled by specifying the desired parallelism degree when constructing a DeepNVMe handle. Subsequent I/O operations with that handle are automatically parallelized over the requested number of host or device threads, as appropriate. I/O parallelism is composable with either the blocking or non-blocking I/O APIs. The example below illustrates 4-way parallelism of a file write using `async_pwrite`. Note the use of `intra_op_parallelism` argument to specify the desired parallelism degree in handle creation. + +```bash +>>> import os +>>> os.path.isfile('/local_nvme/test_1GB.pt') +False +>>> import torch +>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda() +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle(intra_op_parallelism=4) +>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt') +>>> h.wait() +1 +>>> os.path.isfile('/local_nvme/test_1GB.pt') +True +>>> os.path.getsize('/local_nvme/test_1GB.pt') +1073741824 +``` + +### Pinned Tensors +A key part of DeepNVMe optimizations is using direct memory access (DMA) for I/O operations, which requires that the host or device tensor be pinned. To pin host tensors, you can use mechanisms provided by [Pytorch](https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html) or [DeepSpeed Accelerators](/tutorials/accelerator-abstraction-interface/#tensor-operations). The following example illustrates writing a pinned CPU tensor to a local NVMe file. + +```bash +>>> import os +>>> os.path.isfile('/local_nvme/test_1GB.pt') +False +>>> import torch +>>> t=torch.empty(1024**3, dtype=torch.uint8).pin_memory() +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle() +>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt') +>>> h.wait() +1 +>>> os.path.isfile('/local_nvme/test_1GB.pt') +True +>>> os.path.getsize('/local_nvme/test_1GB.pt') +1073741824 +``` + +On the other hand,`gds_handle` provides `new_pinned_device_tensor()` and `pin_device_tensor()` functions for pinning CUDA tensors. The following example illustrates writing a pinned CUDA tensor to a local NVMe file. + +```bash +>>> import os +>>> os.path.isfile('/local_nvme/test_1GB.pt') +False +>>> import torch +>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda() +>>> from deepspeed.ops.op_builder import GDSBuilder +>>> h = GDSBuilder().load().gds_handle() +>>> h.pin_device_tensor(t) +>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt') +>>> h.wait() +1 +>>> os.path.isfile('/local_nvme/test_1GB.pt') +True +>>> os.path.getsize('/local_nvme/test_1GB.pt') +1073741824 +>>> h.unpin_device_tensor(t) +``` + + +## Putting it together +We hope that the above material helps you to get started with DeepNVMe. You can also use the following links to see DeepNVMe usage in real-world Deep Learning applications. + +1. [Parameter swapper](https://github.com/deepspeedai/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py#L111-L117) in [ZeRO-Inference](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) and [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/). +2. [Optimizer swapper](https://github.com/deepspeedai/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L36-L38) in [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/). +3. [Gradient swapper](https://github.com/deepspeedai/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L41-L43) in [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/). +4. Simple file read and write [operations](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/deepnvme/file_access/README.md). + + + + +## Acknowledgements +This tutorial has been significantly improved by feedback from [Guanhua Wang](https://github.com/GuanhuaWang), [Masahiro Tanaka](https://github.com/tohtana), and [Stas Bekman](https://github.com/stas00). + +## Appendix + +### Advanced Handle Creation +Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, PCIe, and SSD). For convenience we make it possible to create handles using default parameter values which will provide decent performance in most scenarios. However, squeezing out every available performance in your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `intra_op_parallelism`. The `aio_handle` constructor parameters and default values are illustrated below: +```bash +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> help(AsyncIOBuilder().load().aio_handle()) +Help on aio_handle in module async_io object: + +class aio_handle(pybind11_builtins.pybind11_object) + | Method resolution order: + | aio_handle + | pybind11_builtins.pybind11_object + | builtins.object + | + | Methods defined here: + | + | __init__(...) + | __init__(self: async_io.aio_handle, block_size: int = 1048576, queue_depth: int = 128, single_submit: bool = False, overlap_events: bool = False, intra_op_parallelism: int = 1) -> None + | + | AIO handle constructor +``` + +### Performance Tuning +As discussed [earlier](#advanced-handle-creation), achieving peak DeepNVMe performance for a target workload or environment requires using optimally configured `aio_handle` or `gds_handle` handles. For configuration convenience, we provide a utility called `ds_nvme_tune` to automate the discovery of optimal DeepNVMe configurations. `ds_nvme_tune` automatically explores a user-specified or default configuration space and recommends the option that provides the best read and write performance. Below is an example usage of `ds_nvme_tune` to tune `aio_handle` data transfers between GPU memory and a local NVVMe SSD mounted on `/local_nvme`. This example used the default configuration space of `ds_nvme_tune` for tuning. + +```bash +$ ds_nvme_tune --nvme_dir /local_nvme --gpu +Running DeepNVMe performance tuning on ['/local_nvme/'] +Best performance (GB/sec): read = 3.69, write = 3.18 +{ + "aio": { + "single_submit": "false", + "overlap_events": "true", + "intra_op_parallelism": 8, + "queue_depth": 32, + "block_size": 1048576 + } +} +``` +The above tuning was executed on a Lambda workstation equipped with two NVIDIA A6000-48GB GPUs, 252GB of DRAM, and a [CS3040 NVMe 2TB SDD](https://www.pny.com/CS3040-M2-NVMe-SSD?sku=M280CS3040-2TB-RB) with peak read and write speeds of 5.6 GB/s and 4.3 GB/s respectively. The tuning required about four and half minutes. Based on the results, one can expect to achieve read and write transfer speeds of 3.69 GB/sec and 3.18 GB/sec respectively by using an `aio_handle` configured as below. + +```python +>>> from deepspeed.ops.op_builder import AsyncIOBuilder +>>> h = AsyncIOBuilder().load().aio_handle(block_size=1048576, + queue_depth=32, + single_submit=False, + overlap_events=True, + intra_op_parallelism=8) +``` + + +The full command line options of `ds_nvme_tune` can be obtained via the normal `-h` or `--help`. +```bash +usage: ds_nvme_tune [-h] --nvme_dir NVME_DIR [NVME_DIR ...] [--sweep_config SWEEP_CONFIG] [--no_read] [--no_write] [--io_size IO_SIZE] [--gpu] [--gds] [--flush_page_cache] [--log_dir LOG_DIR] [--loops LOOPS] [--verbose] + +options: + -h, --help show this help message and exit + --nvme_dir NVME_DIR [NVME_DIR ...] + Directory in which to perform I/O tests. A writeable directory on a NVMe device. + --sweep_config SWEEP_CONFIG + Performance sweep configuration json file. + --no_read Disable read performance measurements. + --no_write Disable write performance measurements. + --io_size IO_SIZE Number of I/O bytes to read/write for performance measurements. + --gpu Test tensor transfers between GPU device and NVME device. + --gds Run the sweep over NVIDIA GPUDirectStorage operator + --flush_page_cache Page cache will not be flushed and reported read speeds may be higher than actual ***Requires sudo access***. + --log_dir LOG_DIR Output directory for performance log files. Default is ./_aio_bench_logs + --loops LOOPS Count of operation repetitions + --verbose Print debugging information. +``` + +### DeepNVMe APIs +For convenience, we provide listing and brief descriptions of the DeepNVMe APIs. + +#### General I/O APIs +The following functions are used for I/O operations with both `aio_handle` and `gds_handle`. + +Function | Description | +|---|---| +async_pread | Non-blocking file read into tensor | +sync_pread | Blocking file read into tensor | +pread | File read with blocking and non-blocking options | +async_pwrite | Non-blocking file write from tensor | +sync_pwrite | Blocking file write from tensor | +pwrite | File write with blocking and non-blocking options | +wait | Wait for non-blocking I/O operations to complete | + +#### GDS-specific APIs +The following functions are available only for `gds_handle` + +Function | Description +|---|---| +new_pinned_device_tensor | Allocate and pin a device tensor | +free_pinned_device_tensor | Unpin and free a device tensor | +pin_device_tensor | Pin a device tensor | +unpin_device_tensor | unpin a device tensor | + + +#### Handle Settings APIs +The following APIs can be used to probe handle configuration. + +Function | Description +|---|---| +get_queue_depth | Return queue depth setting | +get_single_submit | Return whether single_submit is enabled | +get_intra_op_parallelism | Return I/O parallelism degree | +get_block_size | Return I/O block size setting | +get_overlap_events | Return whether overlap_event is enabled | diff --git a/docs/_tutorials/domino.md b/docs/_tutorials/domino.md new file mode 100644 index 000000000000..e1cb704fc229 --- /dev/null +++ b/docs/_tutorials/domino.md @@ -0,0 +1,6 @@ +--- +title: "Domino" +tags: training +--- + +Domino achieves near-complete communication hiding behind computation for tensor parallel training. Please find our [Domino-tutorial](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/training/DeepSpeed-Domino/README.md) in DeepSpeedExample repo. diff --git a/docs/_tutorials/ds-sequence.md b/docs/_tutorials/ds-sequence.md new file mode 100755 index 000000000000..3f7a44cfb34d --- /dev/null +++ b/docs/_tutorials/ds-sequence.md @@ -0,0 +1,120 @@ +--- +title: "Getting Started with DeepSpeed-Ulysses for Training Transformer Models with Extreme Long Sequences" +tags: training sequence-parallelism +--- + +In this tutorial we describe how to enable DeepSpeed-Ulysses for Megatron-Deepspeed. DeepSpeed-Ulysses is a simple but highly communication and memory efficient mechanism sequence parallelism approach for training of large transformer models with massive sequence lengths. It partitions input tensors along the sequence dimension and uses a communication-efficient all-2-all collective for distributed attention computations. Additionally, DeepSpeed-Ulysses incorporates advanced modeling and system optimizations, such as Flash attention, sparse attention, and ZeRO optimizer, to optimize both computational efficiency and memory usage. Training with DeepSpeed sequence parallelism allows both model size and sequence length to scale near indefinitely unbounded by single GPU memory limitation and at a high fraction of peak compute performance. Currently, DeepSpeed-Ulysses can handle sequences up to 1 million in length (10 times the size of a complete Harry Potter book!) on 64 A100 GPUs. Please read our [DeepSpeed-Ulysses blog](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ulysses) to learn more! + +If you're interested in a newer version that works with HF Transformers, please see https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism + + +## 1. Installation + +You will need to install DeepSpeed v0.10.2 or higher to use the DeepSpeed Sequence feature. Installing DeepSpeed is as simple as `pip install deepspeed`, [see more details](/tutorials/getting-started/). + + +## 2. How to use DeepSpeed-Ulysses in your application? + +Integrating DS-Seq into your training code is easy, and in this section we describe how to integrate DeepSpeed-Ulysses through our [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed) code repo. + + +* **Replace attention module**: First, you need to update your attention module with DeepSpeed-Ulysses DistributedAttention. Here, we use the attention from [Megatron-DeepSpeed ](https://github.com/deepspeedai/Megatron-DeepSpeed/blob/main/megatron/model/transformer.py) which is the causal attention used in GPT-3 like model training. Rewrite the attention block: + +```python +def __init__(): + ... + self.local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type) + self.core_attention = local_attn + ... + +def forward(): + ... + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask) + ... +``` + +with: + +```python +from deepspeed.sequence.layer import DistributedAttention + +def __init__(): + ... + self.local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type) + self.dist_attn = DistributedAttention(self.local_attn, parallel_state.get_sequence_parallel_group()) + ... + +def forward(): + ... + context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask) + ... + +``` + +* **Add sequence parallel communication group**: Note that DistributedAttention takes `local_attn` and `sequence_parallel_group` as the parameters, where local_attn can be your original attention block. You also need to build the sequence parallel communication group and pass that the DistributedAttention. One way to do this is to build the sequence parallel group at the model initialization stage. + + +```python +def initialize_model_parallel( + ... + sequence_parallel_size, + ... +): + ... + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size + ... + global _SEQUENCE_PARALLEL_GROUP + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, + (i + 1) * sequence_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + return _SEQUENCE_PARALLEL_GROUP + +``` + +In the Megatron-DeepSpeed exampele, to enable sequence parallelism, set the degree of parallelism using the --ds-sequence-parallel-size argument. You also need to ensure that the number of attention heads is divisible by this value. +We have prepared scripts for you to quickly get some examples for training GPT-3 like models with very long sequences: + +```shell +Megatron-DeepSpeed/examples_deepspeed/sequence_parallel$ bash ds_pretrain_gpt_1.3B_seq_parallel_32k.sh +Megatron-DeepSpeed/examples_deepspeed/sequence_parallel$ bash ds_pretrain_gpt_30B_seq_parallel_32k.sh +``` + +Please note that our sequence parallelism feature is currently incompatible with Megatron-LM's tensor or pipeline parallelism. + +## 3. Enabling DeepSpeed-Ulysses with FlashAttention? + +DeepSpeed's sequence parallelism can be combined with different types of attention implementations to further improve the memory and compute efficiency of long sequence training: + +`Classic attention`: attention mechanism implemented via PyTorch. + +`FlashAttention`: the implementation from [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135). Enabled by `--use-flash-attn`. + +`FlashAttention + Triton`: FlashAttention in Triton (tested with triton==2.0.0.dev20221202). Enabled by `--use-flash-attn-triton`. + +For the best performance, we recommend using FlashAttention + Triton. Below are the installation steps. Note that FlashAttention is compatible only with NVIDIA Turing, Ampere, Ada, or Hopper GPUs. + +```bash +# install triton +git clone -b legacy-backend https://github.com/openai/triton +cd triton/python/ +pip install cmake +pip install . +``` + +```bash +# install +cd ${WORK_DIR} +git clone -b v1.0.4 https://github.com/HazyResearch/flash-attention +cd flash-attention +python -m pip install . +``` + +You may also want to ensure your model configuration is compliant with FlashAttention's requirements. For instance, to achieve optimal performance, the head size should be divisible by 8. Refer to the FlashAttention documentation for more details. diff --git a/docs/_tutorials/ds4sci_evoformerattention.md b/docs/_tutorials/ds4sci_evoformerattention.md new file mode 100644 index 000000000000..9c3f3e1c6051 --- /dev/null +++ b/docs/_tutorials/ds4sci_evoformerattention.md @@ -0,0 +1,112 @@ +--- +title: "DS4Sci_EvoformerAttention eliminates memory explosion problems for scaling Evoformer-centric structural biology models" +tags: training inference +--- + +## 1. What is DS4Sci_EvoformerAttention +`DS4Sci_EvoformerAttention` is a collection of kernels built to scale the [Evoformer](https://www.nature.com/articles/s41586-021-03819-2) computation to larger number of sequences and residuals by reducing the memory footprint and increasing the training speed. + +## 2. When to use DS4Sci_EvoformerAttention +`DS4Sci_EvoformerAttention` is most beneficial when the number of sequences and residuals is large. The forward kernel is optimized to accelerate computation. It is beneficial to use the forward kernel during inference for various attention mechanisms. The associated backward kernel can be used during training to reduce the memory footprint at the cost of some computation. Therefore, it is beneficial to use `DS4Sci_EvoformerAttention` in training for memory-constrained operations such as MSA row-wise attention and MSA column-wise attention. + +## 3. How to use DS4Sci_EvoformerAttention + +### 3.1 Installation + +`DS4Sci_EvoformerAttention` is released as part of DeepSpeed >= 0.10.3. + +`DS4Sci_EvoformerAttention` is implemented based on [CUTLASS](https://github.com/NVIDIA/cutlass). DeepSpeed automatically looks for CUTLASS in the [nvidia-cutlass](https://pypi.org/project/nvidia-cutlass/) Python package, Python environment and CMake prefixes, compiler include path environment variables, a `cutlass` checkout next to DeepSpeed or in the current working directory, and common system install prefixes such as `/usr/local`. +CUTLASS setup detection can be ignored by setting ```CUTLASS_PATH="DS_IGNORE_CUTLASS_DETECTION"```, which is useful if you have a well setup compiler (e.g., compiling in a conda package with cutlass and the cuda compilers installed). +If automatic detection does not find the intended installation, set `CUTLASS_PATH` to either the CUTLASS checkout root or its `include` directory. + +You can always simply clone cutlass next to DeepSpeed: +```shell +git clone https://github.com/NVIDIA/cutlass +``` +The kernels will be compiled when `DS4Sci_EvoformerAttention` is called for the first time. + +`DS4Sci_EvoformerAttention` requires GPUs with compute capability 7.0 or higher +(NVIDIA V100 or later GPUs) and the minimal CUDA version is 11.3. It is +recommended to use CUDA 11.7 or later for better performance. Besides, the +performance of backward kernel on V100 is not as good as on A100 for now. + +The extension checks both requirements and fails if any is not met. To disable +the check (for example cross-compiling in a system without GPUs), set +`DS_IGNORE_CUDA_DETECTION=TRUE`. + +### Multi-Arch Build Behavior + +Evoformer now supports mixed-architecture packaging directly via +`TORCH_CUDA_ARCH_LIST`. + +Example: + +```shell +TORCH_CUDA_ARCH_LIST='7.0;8.0' \ +DS_BUILD_OPS=0 DS_BUILD_EVOFORMER_ATTN=1 \ +pip install -e . +``` + +- `TORCH_CUDA_ARCH_LIST` controls generated CUDA slices (order-independent). +- Targets below `sm_70` are pruned for Evoformer because Tensor Cores are + required. +- `DS_EVOFORMER_GPU_ARCH` is **deprecated** and ignored for Evoformer builds. + Use `TORCH_CUDA_ARCH_LIST` instead. + +Supported dtype matrix by architecture family: + +| Arch family | fp16 | bf16 | +|-------------|------|------| +| Sm70 (Volta) | Yes | No | +| Sm75 (Turing) | Yes | No | +| Sm80+ (Ampere/Ada/Hopper) | Yes | Yes | + +### 3.2 Unit test and benchmark + +The unit test and benchmark are available in the `tests` folder in DeepSpeed repo. You can use the following command to run the unit test and benchmark. + +```shell +pytest -s tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py +python tests/benchmarks/DS4Sci_EvoformerAttention_bench.py +``` + +### 3.3 Applying DS4Sci_EvoformerAttention to your own model + +To use `DS4Sci_EvoformerAttention` in user's own models, you need to import `DS4Sci_EvoformerAttention` from `deepspeed.ops.deepspeed4science`. + +```python +from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +``` + +`DS4Sci_EvoformerAttention` supports four attention mechanisms in Evoformer (MSA row-wise, MSA column-wise, and 2 kinds of Triangular) by using different inputs as shown in the following examples. In the examples, we denote the number of sequences as `N_seq` and the number of residuals as `N_res`. The dimension of the hidden states `Dim` and head number `Head` are different among different attention. Note that `DS4Sci_EvoformerAttention` requires the input tensors to be in `torch.float16` or `torch.bfloat16` data type. + +(a) **MSA row-wise attention** builds attention weights for residue pairs and integrates the information from the pair representation as an additional bias term. +```python +# Q, K, V: [Batch, N_seq, N_res, Head, Dim] +# res_mask: [Batch, N_seq, 1, 1, N_res] +# pair_bias: [Batch, 1, Head, N_res, N_res] +out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, pair_bias]) +``` + +(b) **MSA column-wise attention** lets the elements that belong to the same target residue exchange information. +```python +# Q, K, V: [Batch, N_res, N_seq, Head, Dim] +# res_mask: [Batch, N_seq, 1, 1, N_res] +out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask]) +``` + +(c) **Triangular self-attention** updates the pair representation. There are two kinds of Triangular self-attention: around starting and around ending node. Below is the example of triangular self-attention around starting node. The triangular self-attention around ending node is similar. +```python +# Q, K, V: [Batch, N_res, N_res, Head, Dim] +# res_mask: [Batch, N_res, 1, 1, N_res] +# right_edges: [Batch, 1, Head, N_res, N_res] +out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, right_edges]) +``` + +## 4. DS4Sci_EvoformerAttention scientific application + +### 4.1 DS4Sci_EvoformerAttention eliminates memory explosion problems for scaling Evoformer-centric structural biology models in OpenFold + +[OpenFold](https://github.com/aqlaboratory/openfold) is a community reproduction of DeepMind's AlphaFold2 that makes it possible to train or finetune AlphaFold2 on new datasets. Training AlphaFold2 incurs a memory explosion problem because it contains several custom Evoformer attention variants that manifest unusually large activations. By leveraging DeepSpeed4Science's DS4Sci_EvoformerAttention kernels, OpenFold team is able to reduce the peak memory requirement by 13x without accuracy loss. Detailed information about the methodology can be found at [our website](https://deepspeed4science.ai/2023/09/18/model-showcase-openfold/). + + diff --git a/docs/_tutorials/flops-profiler.md b/docs/_tutorials/flops-profiler.md index 24efc238615a..d4a7496405b9 100644 --- a/docs/_tutorials/flops-profiler.md +++ b/docs/_tutorials/flops-profiler.md @@ -184,7 +184,7 @@ When using DeepSpeed for model training, the profiler can be configured in the d #### Example: Megatron-LM -For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/megatron/Megatron-LM). +For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/megatron/Megatron-LM). An example output of 12-layer Megatron-LM model (`hidden_size = 8192, num_attention_heads = 32, batch_size = 1024, seq_length = 1024`) is shown below. diff --git a/docs/_tutorials/gan.md b/docs/_tutorials/gan.md index 1389c91617dd..db3734fb3b96 100755 --- a/docs/_tutorials/gan.md +++ b/docs/_tutorials/gan.md @@ -16,7 +16,7 @@ Please go through the [original tutorial](https://pytorch.org/tutorials/beginner ## Enabling DeepSpeed -The codes may be obtained [here](https://github.com/microsoft/DeepSpeedExamples/tree/master/gan). +The codes may be obtained [here](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/gan). ### Argument Parsing @@ -101,7 +101,7 @@ deepspeed gan_deepspeed_train.py --dataset celeba --cuda --deepspeed_config gan_ ## Performance Comparison -We use a total batch size of 64 and perform the training on 16 GPUs for 1 epoch on a DGX-2 node which leads to 3x speed-up. The summary of the the results is given below: +We use a total batch size of 64 and perform the training on 16 GPUs for 1 epoch on a DGX-2 node which leads to 3x speed-up. The summary of the results is given below: - Baseline total wall clock time for 1 epochs is 393 secs diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index eea063171c5c..2c6e27d1319d 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -8,9 +8,10 @@ tags: getting-started ## Installation * Installing is as simple as `pip install deepspeed`, [see more details](/tutorials/advanced-install/). -* To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/python-sdk/workflows/train/deepspeed) -* DeepSpeed has direct integrations with [HuggingFace Transformers](https://github.com/huggingface/transformers) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). HuggingFace Transformers users can now easily accelerate their models with DeepSpeed through a simple ``--deepspeed`` flag + config file [See more details](https://huggingface.co/transformers/main_classes/trainer.html#deepspeed). PyTorch Lightning provides easy access to DeepSpeed through the Lightning Trainer [See more details](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html?highlight=deepspeed#deepspeed). +* To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/cli/jobs/deepspeed) +* DeepSpeed has direct integrations with [HuggingFace Transformers](https://github.com/huggingface/transformers) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). HuggingFace Transformers users can now easily accelerate their models with DeepSpeed through a simple ``--deepspeed`` flag + config file [See more details](https://huggingface.co/docs/transformers/deepspeed). PyTorch Lightning provides easy access to DeepSpeed through the Lightning Trainer [See more details](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html?highlight=deepspeed#deepspeed). * DeepSpeed on AMD can be used via our [ROCm images](https://hub.docker.com/r/deepspeed/rocm501/tags), e.g., `docker pull deepspeed/rocm501:ds060_pytorch110`. +* DeepSpeed also supports Intel Xeon CPU, Intel Data Center Max Series XPU, Intel Gaudi HPU, Huawei Ascend NPU etc, please refer to the [accelerator setup guide](/tutorials/accelerator-setup-guide/) @@ -226,6 +227,36 @@ deepspeed --include="worker-2:0,1" \ \ --deepspeed --deepspeed_config ds_config.json ``` +### Launching without passwordless SSH + +DeepSpeed now supports launching training jobs without the need for passwordless SSH. This mode is +particularly useful in cloud environments such as Kubernetes, where flexible container orchestration +is possible, and setting up a leader-worker architecture with passwordless SSH adds unnecessary +complexity. + +To use this mode, you need to run the DeepSpeed command separately on all nodes. The command should +be structured as follows: + +```bash +deepspeed --hostfile=myhostfile --no_ssh --node_rank= \ + --master_addr= --master_port= \ + \ + --deepspeed --deepspeed_config ds_config.json +``` + +- `--hostfile=myhostfile`: Specifies the hostfile that contains information about the nodes and GPUs. +- `--no_ssh`: Enables the no-SSH mode. +- `--node_rank=`: Specifies the rank of the node. This should be a unique integer from 0 to n - 1. +- `--master_addr=`: The address of the leader node (rank 0). +- `--master_port=`: The port of the leader node. + +In this setup, the hostnames in the hostfile do not need to be reachable via passwordless SSH. +However, the hostfile is still required for the launcher to collect information about the environment, +such as the number of nodes and the number of GPUs per node. + +Each node must be launched with a unique `node_rank`, and all nodes must be provided with the address +and port of the leader node (rank 0). This mode causes the launcher to act similarly to the `torchrun` +launcher, as described in the [PyTorch documentation](https://pytorch.org/docs/stable/elastic/run.html). ## Multi-Node Environment Variables @@ -235,7 +266,11 @@ propagate all NCCL and PYTHON related environment variables that are set. If you would like to propagate additional variables you can specify them in a dot-file named `.deepspeed_env` that contains a new-line separated list of `VAR=VAL` entries. The DeepSpeed launcher will look in the local path you are -executing from and also in your home directory (`~/`). +executing from and also in your home directory (`~/`). If you would like to +override the default name of this file or path and name with your own, you +can specify this with the environment variable, `DS_ENV_FILE`. This is +mostly useful if you are launching multiple jobs that all require different +variables. As a concrete example, some clusters require special NCCL variables to set prior to training. The user can simply add these variables to a @@ -281,10 +316,14 @@ local machine to discover the number of slots available. The `--include` and `--exclude` arguments work as normal, but the user should specify 'localhost' as the hostname. -Also note that `CUDA_VISIBLE_DEVICES` can't be used with DeepSpeed to control -which devices should be used. For example, to use only gpu1 of the current -node, do: +Also note that `CUDA_VISIBLE_DEVICES` can be used with `deepspeed` to control +which devices should be used on a single node. So either of these would work +to launch just on devices 0 and 1 of the current node: + +```bash +deepspeed --include localhost:0,1 ... +``` ```bash -deepspeed --include localhost:1 ... +CUDA_VISIBLE_DEVICES=0,1 deepspeed ... ``` diff --git a/docs/_tutorials/inference-tutorial.md b/docs/_tutorials/inference-tutorial.md index 411d5e756504..ddf287f24b96 100644 --- a/docs/_tutorials/inference-tutorial.md +++ b/docs/_tutorials/inference-tutorial.md @@ -3,15 +3,17 @@ title: "Getting Started with DeepSpeed for Inferencing Transformer based Models" tags: inference --- +>**DeepSpeed-Inference v2 is here and it's called DeepSpeed-FastGen! For the best performance, latest features, and newest model support please see our [DeepSpeed-FastGen release blog](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen)!** + DeepSpeed-Inference introduces several features to efficiently serve transformer-based PyTorch models. It supports model parallelism (MP) to fit large models that would otherwise not fit in GPU memory. Even for smaller models, MP can be used to reduce latency for inference. To further reduce latency and cost, we introduce inference-customized kernels. Finally, we propose a novel approach to quantize models, called MoQ, to both shrink the model and reduce the inference cost at production. For more details on the inference related optimizations in DeepSpeed, please refer to our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/). -DeepSpeed provides a seamless inference mode for compatible transformer based models trained using DeepSpeed, Megatron, and HuggingFace, meaning that we don’t require any change on the modeling side such as exporting the model or creating a different checkpoint from your trained checkpoints. To run inference on multi-GPU for compatible models, provide the model parallelism degree and the checkpoint information or the model which is already loaded from a checkpoint, and DeepSpeed will do the rest. It will automatically partition the model as necessary, inject compatible high performance kernels into your model and manage the inter-gpu communication. For list of compatible models please see [here](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py). +DeepSpeed provides a seamless inference mode for compatible transformer based models trained using DeepSpeed, Megatron, and HuggingFace, meaning that we don’t require any change on the modeling side such as exporting the model or creating a different checkpoint from your trained checkpoints. To run inference on multi-GPU for compatible models, provide the model parallelism degree and the checkpoint information or the model which is already loaded from a checkpoint, and DeepSpeed will do the rest. It will automatically partition the model as necessary, inject compatible high performance kernels into your model and manage the inter-gpu communication. For list of compatible models please see [here](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py). ## Initializing for Inference For inference with DeepSpeed, use `init_inference` API to load the model for inference. Here, you can specify the MP degree, and if the model has not been loaded with the appropriate checkpoint, you can also provide the checkpoint description using a `json` file or the checkpoint path. -To inject the high-performance kernels, you need to set the `replace_with_kernel_inject` to True for the compatible models. For models not supported by DeepSpeed, the users can submit a PR that defines a new policy in [replace_policy class](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py) that specifies the different parameters of a Transformer layer, such as attention and feed-forward parts. The policy classes in DeepSpeed create a mapping between the parameters of the original user-supplied layer implementation with DeepSpeed's inference-optimized Transformer layer. +To inject the high-performance kernels, you need to set the `replace_with_kernel_inject` to True for the compatible models. For models not supported by DeepSpeed, the users can submit a PR that defines a new policy in [replace_policy class](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py) that specifies the different parameters of a Transformer layer, such as attention and feed-forward parts. The policy classes in DeepSpeed create a mapping between the parameters of the original user-supplied layer implementation with DeepSpeed's inference-optimized Transformer layer. ```python # create the model @@ -19,18 +21,22 @@ if args.pre_load_checkpoint: model = model_class.from_pretrained(args.model_name_or_path) else: model = model_class() + +# create the tokenizer +tokenizer = model_class.from_pretrained(args.model_name_or_path) ... import deepspeed # Initialize the DeepSpeed-Inference engine ds_engine = deepspeed.init_inference(model, - mp_size=2, - dtype=torch.half, - checkpoint=None if args.pre_load_checkpoint else args.checkpoint_json, - replace_with_kernel_inject=True) + tensor_parallel={"tp_size": world_size}, + dtype=torch.half, + checkpoint=None if args.pre_load_checkpoint else args.checkpoint_json, + replace_with_kernel_inject=True) model = ds_engine.module -output = model('Input String') +pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) +output = pipe('Input String') ``` To run inference with only model-parallelism for the models that we don't support kernels, you can pass an injection policy that shows the two specific linear layers on a Transformer Encoder/Decoder layer: 1) the attention output GeMM and 2) layer output GeMM. We need these part of the layer to add the required all-reduce communication between GPUs to merge the partial results across model-parallel ranks. Below, we bring an example that shows how you can use deepspeed-inference with a T5 model: @@ -47,7 +53,7 @@ pipe = pipeline("text2text-generation", model="google/t5-v1_1-small", device=loc # Initialize the DeepSpeed-Inference engine pipe.model = deepspeed.init_inference( pipe.model, - mp_size=world_size, + tensor_parallel={"tp_size": world_size}, dtype=torch.float, injection_policy={T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')} ) @@ -61,7 +67,7 @@ For the models trained using HuggingFace, the model checkpoint can be pre-loaded ```json "checkpoint.json": { - "type": "Megatron", + "type": "Megatron", "version": 0.0, "checkpoints": [ "mp_rank_00/model_optim_rng.pt", @@ -73,9 +79,9 @@ For models that are trained with DeepSpeed, the checkpoint `json` file only requ ```json "checkpoint.json": { - "type": "DeepSpeed", - "version": 0.3, - "checkpoint_path": "path_to_checkpoints", + "type": "ds_model", + "version": 0.0, + "checkpoints": "path_to_checkpoints", } ``` @@ -108,7 +114,7 @@ generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B', generator.model = deepspeed.init_inference(generator.model, - mp_size=world_size, + tensor_parallel={"tp_size": world_size}, dtype=torch.float, replace_with_kernel_inject=True) @@ -140,7 +146,7 @@ model = deepspeed.init_inference(model, checkpoint='./checkpoint.json', dtype=torch.int8, quantization_setting=(quantize_groups, - mlp_exra_grouping) + mlp_extra_grouping) ) ``` diff --git a/docs/_tutorials/large-models-w-deepspeed.md b/docs/_tutorials/large-models-w-deepspeed.md index 21b9956decc2..3d0bae0144b4 100644 --- a/docs/_tutorials/large-models-w-deepspeed.md +++ b/docs/_tutorials/large-models-w-deepspeed.md @@ -28,7 +28,7 @@ Since, ZeRO is a replacement to data parallelism, it offers a seamless integrati ## Deciding which technology to use -**3D Parallelism for GPT-2/GPT-3 like models**: If you are attempting to train a model whose architecture resembles very closely with GPT-2 or GPT-3, then we have already done the hard work of porting 3D parallelism to a GPT-2/GPT-3 architecture-based model and have created a training pipeline that you can use to efficiently train models with hundreds of billion or even trillions of parameters. Both Megatron-Turing NLG 530B and Big Science use a variation of this code base to scale the model training. You can find the code and tutorial to get started in the [DeepSpeed-Megatron GPT-3](https://github.com/microsoft/megatron-deepspeed) repo. For more information on 3D parallelism please chekcout the resources below: +**3D Parallelism for GPT-2/GPT-3 like models**: If you are attempting to train a model whose architecture resembles very closely with GPT-2 or GPT-3, then we have already done the hard work of porting 3D parallelism to a GPT-2/GPT-3 architecture-based model and have created a training pipeline that you can use to efficiently train models with hundreds of billion or even trillions of parameters. Both Megatron-Turing NLG 530B and Big Science use a variation of this code base to scale the model training. You can find the code and tutorial to get started in the [DeepSpeed-Megatron GPT-3](https://github.com/deepspeedai/megatron-deepspeed) repo. For more information on 3D parallelism please checkout the resources below: [3D Parallelism Tutorial](https://www.deepspeed.ai/tutorials/pipeline/) A generic tutorial on how to port your model to use DeepSpeed 3D parallelism @@ -36,7 +36,7 @@ Since, ZeRO is a replacement to data parallelism, it offers a seamless integrati **ZeRO based technologies**: For most training scenarios, ZeRO offer training efficiency that is on par with 3D parallelism without requiring model code refactoring. Therefore, if you do not already have your code ported to use 3D parallelism, we suggest first trying ZeRO lines of technology to see if it fits your need. Adding ZeRO to your training pipeline with DeepSpeed is simple and does not require you to make changes to your model. Given the trivial cost of trying out ZeRO with DeepSpeed, it is the fastest way to evaluate and decide if you should further invest in porting your model to use 3D parallelism. Enabling ZeRO with DeepSpeed also gives you access to ZeRO-Offload and ZeRO-Infinity that can enable fine tuning large models on limited GPU resources. To get started, please checkout our [ZeRO Tutorial](https://www.deepspeed.ai/tutorials/zero/). -For more indepth information on ZeRO lines of technologies, please checkout our papers: +For more in-depth information on ZeRO lines of technologies, please checkout our papers: [ZeRO (SC20)](https://arxiv.org/pdf/1910.02054.pdf), [ZeRO Offload (ATC21) ](https://www.usenix.org/system/files/atc21-ren-jie.pdf), and [ZeRO-Infinity (SC21)](https://arxiv.org/pdf/2104.07857.pdf), diff --git a/docs/_tutorials/lrrt.md b/docs/_tutorials/lrrt.md index 1659ab5bbd4d..80fcfa3c78db 100644 --- a/docs/_tutorials/lrrt.md +++ b/docs/_tutorials/lrrt.md @@ -137,7 +137,7 @@ In our experience these are four most critical parameters of 1Cycle schedules. 1. We chose to use the slower LRRT schedule (`lr_range_test_step_rate=5`) to set `cycle_min_lr` because it achieves the best loss and the faster schedule diverges fairly quickly. -2. We set `cycle_min_lr` to 0.005 even though the plot shows that performance +2. We set `cycle_max_lr` to 0.005 even though the plot shows that performance was still improving at slightly higher learning rate. This is because we observed that if we wait till the maximum learning rate, the model could be at the point of divergence and impossible to recover. diff --git a/docs/_tutorials/megatron.md b/docs/_tutorials/megatron.md index 0ccfd3ec02f1..490abdb60122 100644 --- a/docs/_tutorials/megatron.md +++ b/docs/_tutorials/megatron.md @@ -19,7 +19,7 @@ reduction_** from using DeepSpeed. ## Training GPT-2 with the Original Megatron-LM -We've copied the original model code from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) into DeepSpeed [Megatron-LM](https://github.com/microsoft/Megatron-DeepSpeed) and made it available as a submodule. To download, execute: +We've copied the original model code from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) into DeepSpeed [Megatron-LM](https://github.com/deepspeedai/Megatron-DeepSpeed) and made it available as a submodule. To download, execute: ```bash git submodule update --init --recursive ``` @@ -31,31 +31,31 @@ git submodule update --init --recursive ### Running Unmodified Megatron-LM GPT2 model * For a single GPU run: - - change `scripts/pretrain_gpt2.sh`, set its `--train-data` argument as `"webtext"`. - - run `bash scripts/pretrain_gpt2.sh` + - change `examples/pretrain_gpt.sh`, set its `--train-data` argument as `"webtext"`. + - run `bash examples/pretrain_gpt.sh` * For multiple GPUs and/or nodes run: - - change `scripts/pretrain_gpt2_model_parallel.sh` + - change `examples/pretrain_gpt_distributed_with_mp.sh` - set its `--train-data` argument as `"webtext"` - `GPUS_PER_NODE` indicates how many GPUs per node involved in the testing - `NNODES` indicates how many nodes involved in the testing - - run `bash scripts/pretrain_gpt2_model_parallel.sh` + - run `bash examples/pretrain_gpt_distributed_with_mp.sh` ## Enabling DeepSpeed To use DeepSpeed we will modify three files : -* `arguments.py` : Arguments configurations -* `pretrain_gpt2.py` : Main entry point for training -* `utils.py` : Checkpoint saving and loading utilities +* `megatron/arguments.py` : Arguments configurations +* `pretrain_gpt.py` : Main entry point for training +* `megatron/utils.py` : Checkpoint saving and loading utilities ### Argument Parsing -The first step is to apply DeepSpeed is adding DeepSpeed arguments to +The first step is adding DeepSpeed arguments to Megatron-LM GPT2 model, using `deepspeed.add_config_arguments()` in -`arguments.py`. +`megatron/arguments.py`. ```python def get_args(): @@ -360,9 +360,9 @@ Megatron-LM GPT2 model with DeepSpeed applied, execute the following command to start training. - Single GPU run - - run `bash scripts/ds_pretrain_gpt2.sh` + - run `bash scripts/ds_pretrain_gpt.sh` - Multiple GPUs/Nodes run - - run `bash scripts/ds_zero2_pretrain_gpt2_model_parallel.sh` + - run `bash scripts/ds_zero2_pretrain_gpt_model_parallel.sh` ## DeepSpeed Evaluation using GPT-2 diff --git a/docs/_tutorials/mixed_precision_zeropp.md b/docs/_tutorials/mixed_precision_zeropp.md new file mode 100644 index 000000000000..1cb62d0d17d5 --- /dev/null +++ b/docs/_tutorials/mixed_precision_zeropp.md @@ -0,0 +1,55 @@ +--- +title: "Mixed Precision ZeRO++" +tags: training ZeRO communication-efficiency large-model +--- + +Mixed Precision ZeRO++ (MixZ++) is a set of optimization strategies based on [ZeRO](/tutorials/zero/) and [ZeRO++](/tutorials/zeropp/) to improve the efficiency and reduce memory usage for large model training and inference when users use [Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) training. MixZ++ partitions model parameters across GPUs to reduce footprint and gathers them with quantized communication only when needed similar to its ZeRO and ZeRO++ siblings. Our evaluation indicates MixZ++ increases the training throughput by up to [3.3x](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31) for the Llama-2-70B model running on 128 V100 GPUs. Read our [DeepSpeed Chat Blog](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31), [ZeRO++ blog](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/) and [paper](https://arxiv.org/pdf/2306.10209.pdf) to learn more! + +We recommend that you read the tutorials on [Getting Started](/getting-started/), [ZeRO](/tutorials/zero/) and [Megatron-DeepSpeed](/tutorials/megatron/) before stepping through this tutorial. + +## Key Designs +Mixed Precision ZeRO++ (MixZ++) inherits key designs from [ZeRO++](/tutorials/zeropp/), namely quantized weights (*qwZ*), hierarchical partitioning ZeRO (*hpZ*) but has different applicability: + - *qwZ* applies block-based quantization on frozen weights to reduce memory usage and all-gather communication volume. Compared with ZeRO++, *qwZ* in Mixed Precision ZeRO++ keeps the frozen weights quantized so there is no quantization overhead during runtime and memory usage is reduced. + - *hpZ* eliminates inter-node parameter all-gather communication through data remapping and recomputation. Compared with ZeRO++, *hpZ* in Mixed Precision ZeRO++ applies to both backward and generation passes. + +Collectively, the optimizations bring better scalability and efficiency to LoRA training. Each of the components can be enabled independent of each other and collectively as a group. + +## Enabling Mixed Precision ZeRO++ (MixZ++) + +A ready to go MixZ++ example has been prepared at [MixZ++ example script](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_mixz.sh). If you prefer to manually enable MixZ++ in your pipeline, please refer to the instructions below. + +### DeepSpeed Configuration Changes +An example snippet of deepspeed configurations with all MixZ++ optimization enabled is shown below: +```json +{ + "zero_optimization": { + "stage": 3, + "..." + "zero_quantized_nontrainable_weights": true, + "zero_hpz_partition_size": 16, + "..." + } +} +``` +Note that for multi-node training, the `"zero_hpz_partition_size"` should be set to the number of GPUs per node. For example, if you have 8 GPUs per node, then `"zero_hpz_partition_size"` should be set to 8. For single-node training, the `"zero_hpz_partition_size"` should not be set. + +### Training Script Changes +DeepSpeed engine will identify the LoRA frozen parameters if the LoRA model is passed when DeepSpeed initializes. However, the popular implementation is to initialize a base model and then convert to LoRA model later. In such cases, users need to explicitly call DeepSpeed engine after LoRA model is converted. This is only a 1-line effort. An example snippet of training script is shown below: + +```python +model, optimizer, _, lr_scheduler = deepspeed.initialize( + model=model, + optimizer=optimizer, + args=args, + config=ds_config, + lr_scheduler=lr_scheduler, + dist_init_required=True) +# ... +# (the custom code to convert base model to LoRA model) +# ... +# call DeepSpeed engine again to identify LoRA frozen parameters +model.optimizer.quantize_nontrainable_params() +# ... +``` + +Congratulations! You have completed the Mixed Precision ZeRO++ tutorial. diff --git a/docs/_tutorials/mixture-of-experts-inference.md b/docs/_tutorials/mixture-of-experts-inference.md index 7a75c84935d7..675815dd5d57 100644 --- a/docs/_tutorials/mixture-of-experts-inference.md +++ b/docs/_tutorials/mixture-of-experts-inference.md @@ -23,7 +23,7 @@ In this part, we elaborate the usage of MoE inference support in the DeepSpeed l ### Initializing for Inference -For inference with DeepSpeed-MoE, use `init_inference` API to load the DeepSpeed MoE model for inference. Here, you can specify the model-parallelism/tensor-slicing degree (mp_size), expert parallelism degree (ep_size), and number of experts (moe_exeperts). We create various process groups based on minimum of the world\_size (total number of GPUs) and expert parallel size. By using this group, we can partition the experts among expert-parallel GPUs. If number of experts is lower than total number of GPUs, DeepSpeed-MoE leverages expert-slicing for partitioning the expert parameters between the expert-parallel GPUs. Furthermore, if the model has not been loaded with the appropriate checkpoint, you can also provide the checkpoint description using a `json` file or simply pass the `'checkpoint'` path to load the model. To inject the high-performance inference kernels, you can set `replace_with_kernel_inject` to True. +For inference with DeepSpeed-MoE, use `init_inference` API to load the DeepSpeed MoE model for inference. Here, you can specify the model-parallelism/tensor-slicing degree (mp_size), expert parallelism degree (ep_size), and number of experts (moe_experts). We create various process groups based on minimum of the world\_size (total number of GPUs) and expert parallel size. By using this group, we can partition the experts among expert-parallel GPUs. If number of experts is lower than total number of GPUs, DeepSpeed-MoE leverages expert-slicing for partitioning the expert parameters between the expert-parallel GPUs. Furthermore, if the model has not been loaded with the appropriate checkpoint, you can also provide the checkpoint description using a `json` file or simply pass the `'checkpoint'` path to load the model. To inject the high-performance inference kernels, you can set `replace_with_kernel_inject` to True. ```python @@ -54,7 +54,7 @@ output = model('Input String') Here, we show a text-generation example using an MoE model for which we can specify the model-parallel size and number of experts. DeepSpeed inference-engine takes care of creating the different parallelism groups using the tensor-slicing degree, number of experts, and the total number of GPUs used for running the MoE model. Regarding the expert parameters, we first use the expert-parallelism to assign each group of experts to one GPU. If number of GPUs is higher than number of experts, we use expert-slicing to partition each expert vertically/horizontally across the GPUs. -Let's take a look at some of the parameters passed to run our example. Please refer to [DeepSpeed-Example](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples/generate_text.sh) for a complete generate-text inference example. +Let's take a look at some of the parameters passed to run our example. Please refer to [DeepSpeed-Example](https://github.com/deepspeedai/Megatron-DeepSpeed/blob/main/examples_deepspeed/generate_text.sh) for a complete generate-text inference example. ```bash @@ -66,7 +66,7 @@ generate_samples_gpt.py \ --num-attention-heads 32 \ --max-position-embeddings 1024 \ --tokenizer-type GPT2BPETokenizer \ - --load $checpoint_path \ + --load $checkpoint_path \ --fp16 \ --ds-inference \ ``` diff --git a/docs/_tutorials/mixture-of-experts-nlg.md b/docs/_tutorials/mixture-of-experts-nlg.md index c88df2df75e0..c4fb072dd82d 100755 --- a/docs/_tutorials/mixture-of-experts-nlg.md +++ b/docs/_tutorials/mixture-of-experts-nlg.md @@ -7,7 +7,7 @@ In this tutorial, we introduce how to apply DeepSpeed Mixture of Experts (MoE) t ## 1. Installation -You would need to install DeepSpeed v0.6.0 or higher to use the MoE feature. The MoE for NLG model examples are in the [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed) repo under the MoE folder. +You would need to install DeepSpeed v0.6.0 or higher to use the MoE feature. The MoE for NLG model examples are in the [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed) repo under the MoE folder. ## 2. Training NLG+MoE models @@ -15,7 +15,7 @@ You would need to install DeepSpeed v0.6.0 or higher to use the MoE feature. The To apply MoE to the GPT-style model, we made several changes in Megatron framework, mostly in `megatron/model/` where we add the MoE layers into the model. ### 2.2. Pre-training the Standard MoE model -We provide example training scripts under [examples/MoE](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/MoE) which we used to perform the experiments in our [Blog]({{ site.press_release_v6 }}). There are a few new hyperparameters for standard MoE model: +We provide example training scripts under [examples_deepspeed/MoE](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/MoE) which we used to perform the experiments in our [Blog]({{ site.press_release_v6 }}). There are a few new hyperparameters for standard MoE model: `--num-experts`: the number of experts per MoE layer. In our experiments we set it to 128. Larger number of experts tend to provide better convergence, but it's a diminishing return. @@ -30,7 +30,7 @@ We provide example training scripts under [examples/MoE](https://github.com/micr ### 2.3. Pre-training the PR-MoE model -PR-MoE is a new designed MoE models, standing for Pyramid-Residual-MoE, which improves the parameter efficiency up to 3x as compared to standard MoE. Please see our [Blog]({{ site.press_release_v6 }}) for more details. We provide example training scripts under [examples/MoE](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/MoE). There are a few different hyperparameters for PR-MoE model compared to standard MoE: +PR-MoE is a new designed MoE models, standing for Pyramid-Residual-MoE, which improves the parameter efficiency up to 3x as compared to standard MoE. Please see our [Blog]({{ site.press_release_v6 }}) for more details. We provide example training scripts under [examples_deepspeed/MoE](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/MoE). There are a few different hyperparameters for PR-MoE model compared to standard MoE: `--num-experts`: Instead of providing a single number, to enable Pyramid-MoE, you need to provide a list, whose length is the same as the number of MoE layers. We suggest to use more experts in the latter stage (close to output) of the model. @@ -57,14 +57,14 @@ Regarding training data, we are not able to release our internal data but any pu Table 1: Zero-shot evaluation results (last six columns) for different dense and MoE NLG models. All zero-shot evaluation results use the accuracy metric. ### 2.4. Training MoS with reduced model size -MoS, standing for Mixture-of-Students, is a staged distillation-based technique for compressing large MoE models. MoS further reduces the model size by 12.5%, leading to up 3.7x model size reduction when combined with PR-MoE over the standard MoE. The reduced model size helps reduce the latecy and cost during inference. To train an MoS model, one needs to specify a few additional parameters. We will use PR-MoE as an example: +MoS, standing for Mixture-of-Students, is a staged distillation-based technique for compressing large MoE models. MoS further reduces the model size by 12.5%, leading to up 3.7x model size reduction when combined with PR-MoE over the standard MoE. The reduced model size helps reduce the latency and cost during inference. To train an MoS model, one needs to specify a few additional parameters. We will use PR-MoE as an example: `--mos`: This would enable Mixture-of-Students via knowledge distillation. -`--load-teacher`: This specifies the path to the teacher model checkpoint. This is a mandatory argumentment for using MoS and the teacher model checkpoint can be obtained by either training a standard MoE or the PR-MoE. +`--load-teacher`: This specifies the path to the teacher model checkpoint. This is a mandatory argument for using MoS and the teacher model checkpoint can be obtained by either training a standard MoE or the PR-MoE. `num-layers-teacher`, `--hidden-size-teacher`, `--hidden-size-teacher`, `--num-experts-teacher`: In addition to the teacher model checkpoint path, we also need to specify the model architecture of the teacher model such as its number of layers, hidden dimension size, and the number of experts per MoE layer. In the case of PR-MoE, we need to also provide a list of experts for the teacher model, where we remove a few expert layers from the teacher model. In addition to the new parameters above, we observe that using the teacher PR-MoE during the entire training process may adversely impact the final student model accuracy. In our experiments, we use a staged distillation method by stopping distillation early in the training process (e.g., after 400K steps) and perform optimization only against the standard language modeling loss for the rest of the training. -We provide example training scripts under [examples/MoE](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/MoE). Details of our parameter settings can be found in the example training scripts. The performance results of MoS can be seen from our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/) and our [paper](https://arxiv.org/abs/2201.05596). +We provide example training scripts under [examples_deepspeed/MoE](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/MoE). Details of our parameter settings can be found in the example training scripts. The performance results of MoS can be seen from our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-powers-8x-larger-moe-model-training-with-high-performance/) and our [paper](https://arxiv.org/abs/2201.05596). diff --git a/docs/_tutorials/mixture-of-experts.md b/docs/_tutorials/mixture-of-experts.md index e7739a6a5051..b4a1c2f86d6a 100644 --- a/docs/_tutorials/mixture-of-experts.md +++ b/docs/_tutorials/mixture-of-experts.md @@ -13,7 +13,7 @@ For more details on results and further discussion, please see our press release {: .notice--info} As a simple starting point we will show how to apply DeepSpeed MoE to a cifar10 example. Please refer to -our [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) going forward. +our [cifar10 example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/cifar) going forward. If you are adding MoE to an existing model you can use the snippet below to help guide you: @@ -28,10 +28,10 @@ DeepSpeed MoE supports five different forms of parallelism, and it exploits both | E + D | Expert + Data | Accelerates training throughput by scaling to multiple data parallel groups | | E + Z | Expert + ZeRO-powered data | Partitions the nonexpert parameters to support larger base models | | E + D + M | Expert + Data + Model | Supports massive hidden sizes and even larger base models than E+Z | -| E + D + Z | Expert + Data + ZeRO-powered data | Supports massive hidden sizes and even larger base models than E+Z | +| E + D + Z | Expert + Data + ZeRO-powered data | Supports massive hidden sizes and even larger base models than E+Z+M | | E + Z-Off + M | Expert + ZeRO-Offload + Model | Leverages both GPU and CPU memory for large MoE models on limited # of GPUs | -To support different forms of parallelism, we create various process groups inside DeepSpeed. The helper functions that DeepSpeed uses reside in ```deepspeed.utils.groups.py``` +To support different forms of parallelism, we create various process groups inside DeepSpeed. The helper functions that DeepSpeed uses reside in ```deepspeed/utils/groups.py``` Note: The following function has been deprecated now and model training code does not need to call this anymore. @@ -45,7 +45,7 @@ The GPUs (or ranks) participating in an expert-parallel group of size ```ep_size ### MoE layer API -The hidden_size is the input dimension of a particular layer and the output dimension is the same as that. This could lead to some changes to your model definition, especially for vision/convolutional models because the input/output dimensions don't match in certain cases. E.g. in the CIFAR-10 example, we modify the third fully connected layer to add the MoE layer. To cater for this, we need to add an additional fully-connected layer, whose input dimension is equal to the output dimension of the MoE layer. +The `hidden_size` is the input dimension of a particular layer and the output dimension is the same as that. This could lead to some changes to your model definition, especially for vision/convolutional models because the input/output dimensions don't match in certain cases. E.g. in the CIFAR-10 example, we modify the third fully connected layer to add the MoE layer. To cater for this, we need to add an additional fully-connected layer, whose input dimension is equal to the output dimension of the MoE layer. Original model config @@ -63,7 +63,10 @@ Updated with MoE Layers ### Pyramid-Residual MoE -Recently, we proposed a novel [Pyramid-Residual MoE](https://arxiv.org/abs/2201.05596) (PR-MoE) model architecture. To create such an MoE model, the users need to do two additional things: 1) To make a pyramid structure, pass num_experts as a list e.g. [4, 8] and 2) Use the ```use_residual``` flag to indicate that the MoE layer is now a Residual MoE layer. +Recently, we proposed a novel [Pyramid-Residual MoE](https://arxiv.org/abs/2201.05596) (PR-MoE) model architecture. To create such an MoE model, the users need to do two additional things: + +1. To make a pyramid structure, pass `num_experts` as a list e.g. `[4, 8]`. +2. Use the ```use_residual``` flag to indicate that the MoE layer is now a Residual MoE layer. ```python self.experts = deepspeed.moe.layer.MoE(hidden_size=input_dim, expert=ExpertModule(), num_experts=[..], ep_size=ep_size, use_residual=True) @@ -79,13 +82,13 @@ EP_WORLD_SIZE = 2 EXPERTS = [8] ``` -The model code needs to use the deepspeed.moe.layer.MoE API as follows. +The model code needs to use the `deepspeed.moe.layer.MoE` API as follows. ```python self.experts = deepspeed.moe.layer.MoE(hidden_size=input_dim, expert=ExpertModule(), num_experts=EXPERTS, ep_size=EP_WORLD_SIZE) ``` -With the above two commands, the DeepSpeed runtime will be set to train an MoE model with a total of 8 experts on 4 GPUs in 4 experts/GPU mode. We call this the E + D mode as described earlier in the table. +With the above code, the DeepSpeed runtime will be set to train an MoE model with a total of 8 experts on 4 GPUs in 4 experts/GPU mode. We call this the E + D mode as described earlier in the table. ```python @@ -104,11 +107,11 @@ fc4 = torch.nn.Linear(84, 10) ``` -For a runnable end-to-end example that covers both the standard MoE architecture as well as the PR-MoE model , please look at the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar). In addition, see the advanced usage section of this tutorial that links to a more comprehensive example for NLG models. +For a runnable end-to-end example that covers both the standard MoE architecture, as well as the PR-MoE model, please look at the [cifar10 example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/cifar). In addition, see the advanced usage section of this tutorial that links to a more comprehensive example for NLG models. ### Combining ZeRO-Offload and DeepSpeed MoE for very large models -To use MoE Layers in DeepSpeed, we rely on two parameter groups that are passed to an optimizer. A concrete example to create such groups is available from the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar). +To use MoE Layers in DeepSpeed, we rely on two parameter groups that are passed to an optimizer. A concrete example to create such groups is available from the [cifar10 example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/cifar). The relevant function that creates these param groups is as follows. @@ -124,7 +127,6 @@ def create_moe_param_groups(model): The above param groups can then be fed to the ZeRO stage-2 optimizer as follows. ```python - net = Net() parameters = create_moe_param_groups(net) @@ -135,7 +137,7 @@ model_engine, optimizer, trainloader, __ = deepspeed.initialize( We are working on automating this functionality in the DeepSpeed ZeRO optimizer so the model training code can be simplified further. -To run the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) with ZeRO-Offload (stage 2) and MoE, please set the ds_config flags +To run the [cifar10 example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/cifar) with ZeRO-Offload (stage 2) and MoE, please set the `ds_config` flags ```json "zero_optimization": { @@ -150,7 +152,7 @@ To run the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree } ``` -An additional optimization to save memory for extremely large model training on limited number of GPUs has also been introduced. Please enable that using the following config flag to the fp16 optimizer in ds_config. +An additional optimization to save memory for extremely large model training on limited number of GPUs has also been introduced. Please enable that using the following config flag to the fp16 optimizer in `ds_config`. ```json "fp16": { diff --git a/docs/_tutorials/model-compression.md b/docs/_tutorials/model-compression.md index 20f2e6a6b25b..d11eadc3d726 100644 --- a/docs/_tutorials/model-compression.md +++ b/docs/_tutorials/model-compression.md @@ -25,7 +25,7 @@ If the model is very deep, you may consider using this method. It works much bet Layer reduction can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#layer-reduction)). Users have the freedom to select any depth by `keep_number_layer` and any subset of the network layers by `teacher_layer`. In addition, users also can choose whether to reinitialize the input/output layers from the given model (teacher model) by `other_module_name`. -To apply layer reduction for task-specific compression, we provide an example on how to do so for BERT fine-tuning. Layer reduction is about resetting the depth of network architecture and reinitialization of weight parameters, which happens before the training process. The example includes the following changes to the client code (`model_compression/bert/run_glue_no_trainer.py` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)): +To apply layer reduction for task-specific compression, we provide an example on how to do so for BERT fine-tuning. Layer reduction is about resetting the depth of network architecture and reinitialization of weight parameters, which happens before the training process. The example includes the following changes to the client code (`compression/bert/run_glue_no_trainer.py` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples)): (1) When initial the model, the number of layers in the model config should be the same as `keep_number_layer` in DeepSpeed config JSON file. For Hugging Face BERT example, set `config.num_hidden_layers = ds_config["compression_training"]["layer_reduction"]["keep_number_layer"]`. @@ -33,11 +33,11 @@ To apply layer reduction for task-specific compression, we provide an example on (3) During training, if KD is not used, nothing needs to be done. Otherwise, one needs to consider applying KD with the `teacher_layer` JSON configuration when calculating the difference between teacher’s and student’s output. -One can run our layer reduction example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our layer reduction example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/layer_reduction.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/layer_reduction.sh ``` And the final result is: @@ -49,9 +49,9 @@ Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.834029 To apply layer reduction for task-agnostic compression, we provide an example on how to do so in the GPT pre-training stage. -Step 1: Obtain the latest version of the [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed). +Step 1: Obtain the latest version of the [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed). -Step 2: Enter `Megatron-DeepSpeed/examples/compression` directory. +Step 2: Enter `Megatron-DeepSpeed/examples_deepspeed/compression` directory. Step 3: Run the example bash script such as `ds_pretrain_gpt_125M_dense_cl_kd.sh`. The args related to the pre-training distillation are: @@ -97,17 +97,17 @@ Weight quantization can be enabled and configured using the DeepSpeed config JSO (4)`start_bit` and `target_bit`, to simplify the first experiment we suggest to set them the same such that we apply quantization to the target bit once the iteration reaches `schedule_offset`. -There are two changes to the client code (`model_compression/bert/run_glue_no_trainer.py` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)): +There are two changes to the client code (`compression/bert/run_glue_no_trainer.py` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples)): (1) After initialization of the model, apply `init_compression` function to the model with DeepSpeed JSON configurations. (2) After training, apply `redundancy_clean` function to save the quantized weight. -One can run our weight quantization example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our weight quantization example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/quant_weight.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/quant_weight.sh ``` And the final result is: @@ -130,17 +130,17 @@ It can improve computation efficiency similar to [weight quantization](#12-weigh Activation quantization can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#activation-quantization)). Some of the components are same as weight quantization, such as `schedule_offset` and `quantization_type`. The key configurations we would like to point out are: -(1)`range_calibration`, user has option to set dynamic or static. When using “dynamic”, the activation quantization groups will be automatically set to be token-wise (for Transformer-based models) and image-wise (for CNN-based models). See more in [our ZeroQuant paper](https://arxiv.org/abs/2206.01861) and the code (`deepspeed/compression/basic_layer.py` in [DeepSpeed](https://github.com/microsoft/DeepSpeed)). +(1)`range_calibration`, user has option to set dynamic or static. When using “dynamic”, the activation quantization groups will be automatically set to be token-wise (for Transformer-based models) and image-wise (for CNN-based models). See more in [our ZeroQuant paper](https://arxiv.org/abs/2206.01861) and the code (`deepspeed/compression/basic_layer.py` in [DeepSpeed](https://github.com/deepspeedai/DeepSpeed)). (2)`aq1`/`aq2`, users can expand more groups such as `aq3`, `aq4`, etc. The client code change is the same as [weight quantization](#12-weight-quantization). -One can run our activation quantization example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our activation quantization example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/quant_activation.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/quant_activation.sh ``` And the final result is: @@ -158,7 +158,7 @@ Pruning aims to reduce the number of parameters and operations involved in gener | **Method** | **Type** | | --------------------- | ------------ | -| [Sparse pruning](#141-sparse-pruning) | Unstructured | +| [Sparse pruning](#141-sparse-pruning) | Unstructured and Structured | | [Row pruning](#142-row-pruning) | Structured | | [Head pruning](#143-head-pruning) | Structured | | [Channel pruning](#144-channel-pruning) | Structured | @@ -166,7 +166,7 @@ Pruning aims to reduce the number of parameters and operations involved in gener #### 1.4.1 Sparse Pruning **What is sparse pruning** -Sparse pruning means we set some of the elements in each weight matrix with zero values. There is no structure pattern in the zero values. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626). +Sparse pruning means we set some of the elements in each weight matrix with zero values. Relying on the pruning method user chosen, the zero values may have structured pattern or unstructured pattern. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626). Another way to perform pruning is based on the weights' effect to the loss function when they are masked, see for instance [this paper](https://arxiv.org/abs/1810.02340). **When to use sparse pruning** @@ -178,19 +178,21 @@ Sparse pruning can be enabled and configured using the DeepSpeed config JSON fil (1)`schedule_offset`, we empirically find that when using `method: topk`, it’s better to set the `schedule_offset` to a large value such as 10% of the total training steps. -(2)`method`, we support L1 norm and topk methods. Users are welcome to contribute more methods. +(2)`method`, we support L1 norm, topk and snip_momentum methods. Users are welcome to contribute more methods. -(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc. +(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc. Note this is not needed for snip_momentum method. -(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet. +(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet. for structured sparse pruning like snip_momentum, the dense ratio should be specified in shared_parameters and is used to calculate the global sparsity ratio. + +(5)`frequency`, `block_pattern` and `schedule_offset_end`, they are used to specify the pruning frequency on steps, the block-wise pruning pattern (NxM and N in M), and the end steps for pruning. For snip_momentum method, these configurations are mandatory. The client code change is the same as [weight quantization](#12-weight-quantization). -One can run our sparse pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our sparse pruning example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/pruning_sparse.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/pruning_sparse.sh ``` And the final result is: @@ -221,11 +223,11 @@ Row pruning can be enabled and configured using the DeepSpeed config JSON file ( The client code change is the same as [weight quantization](#12-weight-quantization). -One can run our row pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our row pruning example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/pruning_row.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/pruning_row.sh ``` And the final result is: @@ -258,11 +260,11 @@ Head pruning can be enabled and configured using the DeepSpeed config JSON file The client code change is the same as [weight quantization](#12-weight-quantization). -One can run our head pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our head pruning example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/pruning_head.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/pruning_head.sh ``` And the final result is: @@ -284,11 +286,11 @@ Channel pruning is a feature designed for two back-to-back CONV2d layers (e.g., Channel pruning can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#channel-pruning)). -One can run our channel pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our channel pruning example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell pip install torch torchvision -DeepSpeedExamples/model_compression/cifar$ bash run_compress.sh +DeepSpeedExamples/compression/cifar$ bash run_compress.sh ``` And the final result is: @@ -313,11 +315,11 @@ When you want to quantize the transformer-based model to INT8 or INT4/INT8 forma **How to use ZeroQuant** -One can run our BERT example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by: +One can run our BERT example in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/bert$ bash bash_script/ZeroQuant/zero_quant.sh +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ bash bash_script/ZeroQuant/zero_quant.sh ``` And the final result is: @@ -329,8 +331,8 @@ Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.842791 One can run our GPT example by: ```shell -DeepSpeedExamples/model_compression/gpt2$ pip install -r requirements.txt -DeepSpeedExamples/model_compression/gpt2$ bash bash_script/run_zero_quant.sh +DeepSpeedExamples/compression/gpt2$ pip install -r requirements.txt +DeepSpeedExamples/compression/gpt2$ bash bash_script/run_zero_quant.sh ``` And the final result is: @@ -361,22 +363,22 @@ If you want to significantly compress your models while retaining competitive pe **How to use XTC** -**Installation:** Examples of XTC extreme compression for BERT models are at `model_compression/bert/bash_script/XTC` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples). You will need to install the requirements by: +**Installation:** Examples of XTC extreme compression for BERT models are at `compression/bert/bash_script/XTC` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples). You will need to install the requirements by: ```shell -DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt +DeepSpeedExamples/compression/bert$ pip install -r requirements.txt ``` **Implementation of XTC methods:** To accommodate users who do not have a fine-tuned model or task-specific model for compression, with the arg `--model_name_or_path yoshitomo-matsubara/bert-base-uncased-${TASK_NAME}` our python script `run_glue_no_trainer.py` automatically downloads the models from Hugging Face. Users can also use their own models with better accuracy as the teacher and the student model initialization. ### 3.1 One-bit or Two-bit BERT-base (12-layer) with 8-bit activation quantization -For the configurations, see `model_compression/bert/config/XTC/ds_config_W1A8_Qgroup1_fp32.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples). In our paper, we used FP32 (`"fp16": {"enabled": false}`) to perform training, while directly applying 8-bit quantization (`"bits": 8`) to the activations and 1-bit quantization (`"start_bits": 1, "target_bits": 1`) to the attention (query, key, val) and feedforward weight matrices (`"modules": ["attention.self", "intermediate", "output.dense"]`) at the beginning of the training (`"schedule_offset": 0`). In addition, we also apply 1-bit quantization to `word_embeddings` as weight quantization. +For the configurations, see `compression/bert/config/XTC/ds_config_W1A8_Qgroup1_fp32.json` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples). In our paper, we used FP32 (`"fp16": {"enabled": false}`) to perform training, while directly applying 8-bit quantization (`"bits": 8`) to the activations and 1-bit quantization (`"start_bits": 1, "target_bits": 1`) to the attention (query, key, val) and feedforward weight matrices (`"modules": ["attention.self", "intermediate", "output.dense"]`) at the beginning of the training (`"schedule_offset": 0`). In addition, we also apply 1-bit quantization to `word_embeddings` as weight quantization. One can run this example by: ```shell -DeepSpeedExamples/model_compression/bert$ bash bash_script/XTC/quant_1bit.sh +DeepSpeedExamples/compression/bert$ bash bash_script/XTC/quant_1bit.sh ``` And the final result is: @@ -385,7 +387,7 @@ And the final result is: Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8293428425878757/0.8396053702196908 ``` -The other important feature we would like to mention is the `quantize_groups` inside `weight_quantization`, which is set to be 1 here to match our XTC paper's FP32 training setup. We find that under FP16 training, smaller number of quantization group (e.g., 1 or 2) could lead to unstable training. Thus, we recommend using larger number of groups (e.g., 64) under FP16. `model_compression/bert/config/ds_config_W1A8_Qgroup64_fp16.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) is the FP16 example configurations, where `"fp16": {"enabled": true}` and `"weight_quantization": {"shared_parameters": {"quantize_weight_in_forward": false}}` are different from FP32 case. +The other important feature we would like to mention is the `quantize_groups` inside `weight_quantization`, which is set to be 1 here to match our XTC paper's FP32 training setup. We find that under FP16 training, smaller number of quantization group (e.g., 1 or 2) could lead to unstable training. Thus, we recommend using larger number of groups (e.g., 64) under FP16. `compression/bert/config/ds_config_W1A8_Qgroup64_fp16.json` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) is the FP16 example configurations, where `"fp16": {"enabled": true}` and `"weight_quantization": {"shared_parameters": {"quantize_weight_in_forward": false}}` are different from FP32 case. With this config, we quantize the existing fined-tuned models downloaded from Hugging Face. For 2-bit weight quantization, user needs to update the ds_config JSON file. To give a sense of the compression performance of downloaded models compared to our paper, we collect the results (1/2-bit BERT on MNLI and QQP with 18 training epochs) in table below. The difference between this tutorial and paper is because they use different checkpoints. Data augmentation introduces in [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT) will help significantly for smaller tasks (such as mrpc, rte, sst-b and cola). See more details in [our paper](https://arxiv.org/abs/2206.01859). @@ -397,12 +399,12 @@ This section consists of two parts: (a) we first perform a light-weight layer re **3.2.1 Light-weight Layer Reduction** -`model_compression/bert/config/XTC/ds_config_layer_reduction_fp16.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) is the example configuration for reducing the 12-layer BERT-base to a 6-layer one. The student’s layers are initialized from i-layer of the teacher with i= [1, 3 ,5 ,7 ,9 ,11] (note that the layer starts from 0), which is called `Skip-BERT_5` in our XTC paper. In addition, student’s modules including embedding, pooler and classifier are also initialized from teacher. For 5-layer layer reduction, one needs to change the configs in `ds_config_layer_reduction_fp16.json` to `"keep_number_layer": 5`, `"teacher_layer": [2, 4 ,6, 8, 10]`(like in `model_compression/bert/config/ds_config_TEMPLATE.json`). +`compression/bert/config/XTC/ds_config_layer_reduction_fp16.json` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) is the example configuration for reducing the 12-layer BERT-base to a 6-layer one. The student’s layers are initialized from i-layer of the teacher with i= [1, 3 ,5 ,7 ,9 ,11] (note that the layer starts from 0), which is called `Skip-BERT_5` in our XTC paper. In addition, student’s modules including embedding, pooler and classifier are also initialized from teacher. For 5-layer layer reduction, one needs to change the configs in `ds_config_layer_reduction_fp16.json` to `"keep_number_layer": 5`, `"teacher_layer": [2, 4 ,6, 8, 10]`(like in `compression/bert/config/ds_config_TEMPLATE.json`). One can run this example by: ```shell -DeepSpeedExamples/model_compression/bert$ bash bash_script/XTC/layer_reduction.sh +DeepSpeedExamples/compression/bert$ bash bash_script/XTC/layer_reduction.sh ``` And the final result is: @@ -411,7 +413,7 @@ And the final result is: Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8377992868059093/0.8365541090317331 ``` -Notably, when using one-stage knowledge distillation (`--distill_method one_stage`), the difference between the outputs of teacher and student models (att_loss and rep_loss) also need to be consistent with the initialization. See the function `_kd_function` under `forward_loss` in `model_compression/bert/util.py`. +Notably, when using one-stage knowledge distillation (`--distill_method one_stage`), the difference between the outputs of teacher and student models (att_loss and rep_loss) also need to be consistent with the initialization. See the function `_kd_function` under `forward_loss` in `compression/bert/util.py`. For mnli/qqp, we set `--num_train_epochs 36`, `--learning_rate 5e-5`, and with the JSON config above. The results are given below (we also include the fp16 training results). Using fp32 clearly results in more stable performance than fp16, although fp16 can speed up the training time. @@ -419,12 +421,12 @@ For mnli/qqp, we set `--num_train_epochs 36`, `--learning_rate 5e-5`, and with t **3.2.2 One-bit or Two-bit quantization for 6-layer (5-layer) BERT** -Given the above layer-reduced models ready, we now continue to compress the model with 1/2-bit quantization. `model_compression/bert/config/XTC/ds_config_layer_reduction_W1Q8_fp32.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) is the example configuration where we set the layer reduction to be true on top of `model_compression/bert/config/XTC/ds_config_W1A8_Qgroup1_fp32.json`. In addition to the configuration, we need to update the path for the student model using `--pretrained_dir_student` in the script `model_compression/bert/bash_script/XTC/layer_reduction_1bit.sh`. User can train with a different teacher model by adding `--pretrained_dir_teacher`. +Given the above layer-reduced models ready, we now continue to compress the model with 1/2-bit quantization. `compression/bert/config/XTC/ds_config_layer_reduction_W1Q8_fp32.json` in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) is the example configuration where we set the layer reduction to be true on top of `compression/bert/config/XTC/ds_config_W1A8_Qgroup1_fp32.json`. In addition to the configuration, we need to update the path for the student model using `--pretrained_dir_student` in the script `compression/bert/bash_script/XTC/layer_reduction_1bit.sh`. User can train with a different teacher model by adding `--pretrained_dir_teacher`. One can run this example by: ```shell -DeepSpeedExamples/model_compression/bert$ bash bash_script/XTC/layer_reduction_1bit.sh +DeepSpeedExamples/compression/bert$ bash bash_script/XTC/layer_reduction_1bit.sh ``` And the final result is: diff --git a/docs/_tutorials/monitor.md b/docs/_tutorials/monitor.md index a9c111f8eeec..5e7a6fc4e834 100644 --- a/docs/_tutorials/monitor.md +++ b/docs/_tutorials/monitor.md @@ -11,7 +11,7 @@ In this tutorial, we introduce the DeepSpeed Monitor and provide examples of its ## Overview -Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), and simple CSV files. +Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), [Comet](https://www.comet.com/site/?utm_source=deepseed&utm_medium=docs&utm_content=tutorial) and simple CSV files. Below is a live monitoring view for TensorBoard: @@ -21,16 +21,20 @@ Below is a live monitoring view for WandB: ![WandB Example Output](/assets/images/wandb_monitor.PNG){: .align-center} +Below is a live monitoring view for Comet: + +![CometML Example Output](/assets/images/comet_monitor.png){: .align-center} + ## Usage -The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. +The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. - [Automatic Monitoring](#automatic-monitoring) - [Custom Monitoring](#custom-monitoring) ### Automatic Monitoring -When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module-tensorboard-wandb-csv) for details. +When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module) for details. ```json { @@ -45,6 +49,11 @@ When using DeepSpeed for model training, the Monitor can be configured in the De "group": "my_group", "project": "my_project" } + "comet": { + "enabled": true, + "project": "my_project", + "experiment_name": "my_experiment" + } "csv_monitor": { "enabled": true, "output_path": "output/ds_logs/", @@ -72,7 +81,7 @@ The steps to create a custom monitor are as follows: \* Note - Some Monitor backends don't support mixed sample values. Be sure to use your DeepSpeed engine object's `global_samples` attribute in each 3-tuple -For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) example: +For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/cifar) example: ```python # Step 1: Import monitor (and DeepSpeed config, if needed) diff --git a/docs/_tutorials/one-cycle.md b/docs/_tutorials/one-cycle.md index 12967ad56ad5..0b3c8ff0bcf0 100644 --- a/docs/_tutorials/one-cycle.md +++ b/docs/_tutorials/one-cycle.md @@ -42,33 +42,33 @@ of learning rate and momentum because they are correlated hyperparameters. We have leveraged this recommendation to reduce configuration burden by organizing the 1-cycle parameters into two groups: -1. Global parameters for configuring the cycle and decay phase -2. Local parameters for configuring learning rate and momentum +1. Global parameters for configuring the cycle and decay phase. +2. Local parameters for configuring learning rate and momentum. The global parameters for configuring the 1-cycle phases are: -1. `cycle_first_step_size`: The count of training steps to complete first step of cycle phase -2. `cycle_first_stair_count`: The count of updates (or stairs) in first step of cycle phase -3. `cycle_second_step_size`: The count of training steps to complete second step of cycle phase -4. `cycle_second_stair_count`: The count of updates (or stairs) in the second step of cycle phase -5. `post_cycle_decay_step_size`: The interval, in training steps, to decay hyperparameter in decay phase +1. `cycle_first_step_size`: The count of training steps to complete first step of cycle phase. +2. `cycle_first_stair_count`: The count of updates (or stairs) in first step of cycle phase. +3. `cycle_second_step_size`: The count of training steps to complete second step of cycle phase. +4. `cycle_second_stair_count`: The count of updates (or stairs) in the second step of cycle phase. +5. `post_cycle_decay_step_size`: The interval, in training steps, to decay hyperparameter in decay phase. The local parameters for the hyperparameters are: **Learning rate**: -1. `cycle_min_lr`: minimum learning rate in cycle phase -2. `cycle_max_lr`: maximum learning rate in cycle phase -3. `decay_lr_rate`: decay rate for learning rate in decay phase +1. `cycle_min_lr`: Minimum learning rate in cycle phase. +2. `cycle_max_lr`: Maximum learning rate in cycle phase. +3. `decay_lr_rate`: Decay rate for learning rate in decay phase. Although appropriate values `cycle_min_lr` and `cycle_max_lr` values can be selected based on experience or expertise, we recommend using [learning rate range test](/tutorials/lrrt/) feature of DeepSpeed to configure them. **Momentum** -1. `cycle_min_mom`: minimum momentum in cycle phase -2. `cycle_max_mom`: maximum momentum in cycle phase -3. `decay_mom_rate`: decay rate for momentum in decay phase +1. `cycle_min_mom`: Minimum momentum in cycle phase. +2. `cycle_max_mom`: Maximum momentum in cycle phase. +3. `decay_mom_rate`: Decay rate for momentum in decay phase. ## Required Model Configuration Changes @@ -122,9 +122,9 @@ GPU, but was converging slowly to target performance (AUC) when training on 8 GPUs (8X batch size). The plot below shows model convergence with 8 GPUs for these learning rate schedules: -1. **Fixed**: using an optimal fixed learning rate for 1-GPU training. -2. **LinearScale**: using a fixed learning rate that is 8X of **Fixed**. -3. **1Cycle**: using 1-Cycle schedule. +1. **Fixed**: Using an optimal fixed learning rate for 1-GPU training. +2. **LinearScale**: Using a fixed learning rate that is 8X of **Fixed**. +3. **1Cycle**: Using 1-Cycle schedule. ![model_convergence](/assets/images/model_convergence.png) diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index a64439018db4..e24dc8f86554 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -33,7 +33,7 @@ If you don't already have a copy of the DeepSpeed repository, please clone it now and checkout the DeepSpeedExamples submodule that contains the BingBertSQuAD and BERT Pre-training examples. ```shell -git clone https://github.com/microsoft/DeepSpeed +git clone https://github.com/deepspeedai/DeepSpeed cd DeepSpeed git submodule update --init --recursive cd DeepSpeedExamples/ @@ -46,7 +46,7 @@ cd DeepSpeedExamples/ In 1-bit Adam v2, we introduce a new system implementation for compressed communication using the NCCL backend of PyTorch distributed. This significantly improves the usability due to NCCL’s integration with PyTorch distributed. The performance of our new NCCL-based implementation is also better than our earlier MPI-based implementation for Ethernet-based systems and on-par for InfiniBand-based systems. Thus we highly recommend users to choose this implementation. **Watch out!** -This NCCL-based implementation requires PyTorch >= 1.8. It also requires NCCL >= 2.8.3 when you have 64 or more GPUs to avoid certain NCCL runtime bugs. Currently (2021/03/16) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is by hacking in NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm `LD_PRELOAD` is working you can see the version it uses in the NCCL logs if you have `NCCL_DEBUG=INFO`, it should say: NCCL version 2.8.3+cuda11.0. +This NCCL-based implementation requires PyTorch >= 1.8. It also requires NCCL >= 2.8.3 when you have 64 or more GPUs to avoid certain NCCL runtime bugs. Currently (2021/03/16) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is by hacking in NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm `LD_PRELOAD` is working you can see the version it uses in the NCCL logs if you have `NCCL_DEBUG=INFO`, it should say: NCCL version 2.8.3+cuda11.0. {: .notice--warning} #### 1.2.2 MPI-based implementation @@ -75,6 +75,12 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation + +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. + +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 1-bit Algorithm The detailed description of the 1-bit Algorithm can be seen from our [blog post](https://www.deepspeed.ai/2020/09/08/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888). @@ -106,10 +112,10 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_ `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" and "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. #### 1.4.1 (New in v2) Momentum masks for parameters with constant zero gradients -Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. +Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. **Watch out!** 1-bit Adam relies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence. @@ -130,7 +136,7 @@ You can also use a pre-trained BERT model checkpoint from either DeepSpeed, [Hug ### 2.1 Running BingBertSQuAD with DeepSpeed and 1-bit Adam -We provide example scripts under [DeepSpeedExamples/BingBertSquad/1-bit_adam/](https://github.com/microsoft/DeepSpeedExamples/tree/master/BingBertSquad/1-bit_adam). There are 3 sets of scripts corresponding to NCCL-based implementation, MPI-based implementation on Ethernet systems, and MPI-based implementation on InfiniBand systems. For MPI-based implementation, we provide both example scripts when launching with deepspeed or mpirun. +We provide example scripts under [DeepSpeedExamples/training/BingBertSquad/1-bit_adam/](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/BingBertSquad/1-bit_adam). There are 3 sets of scripts corresponding to NCCL-based implementation, MPI-based implementation on Ethernet systems, and MPI-based implementation on InfiniBand systems. For MPI-based implementation, we provide both example scripts when launching with deepspeed or mpirun. +* [2026/05] [Using Muon Optimizer with DeepSpeed](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/muon-optimizer/README.md) -# Extreme Speed and Scale for DL Training and Inference +* [2026/05] [System DMA (SDMA) for ZeRO-3: offload collectives off compute units on AMD GPUs for better overlap](https://github.com/deepspeedai/DeepSpeed/blob/master/examples/sdma_allgather/README.md) - ***[DeepSpeed](https://www.deepspeed.ai/) enables world's most powerful language models like [MT-530B](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/) and [BLOOM](https://huggingface.co/blog/bloom-megatron-deepspeed)***. It is an easy-to-use deep learning optimization software suite that powers unprecedented scale and speed for both training and inference. With DeepSpeed you can: +* [2025/12] [DeepSpeed Core API updates: PyTorch-style backward and low-precision master states](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/core_api_update/README.md) -* Train/Inference dense or sparse models with billions or trillions of parameters -* Achieve excellent system throughput and efficiently scale to thousands of GPUs -* Train/Inference on resource constrained GPU systems -* Achieve unprecedented low latency and high thoughput for inference -* Achieve extreme compression for an unparalleled inference latency and model size reduction with low costs +* [2025/10] [SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips](https://pytorch.org/blog/superoffload-unleashing-the-power-of-large-scale-llm-training-on-superchips/) +* [2025/10] [Study of ZenFlow and ZeRO offload performance with DeepSpeed CPU core binding](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/zenflow-corebinding/README.md) -# DeepSpeed has three innovation pillars: -![Three innovation pillars](/assets/images/3pillars.png){: .align-center} + +
+ + More news + +
-DeepSpeed brings together innovations in parallelism technology such as tensor, pipeline, expert and ZeRO-parallelism, and combines them with high performance custom inference kernels, communication optimizations and heterogeneous memory technologies to enable inference at an unprecedented scale, while achieving unparalleled latency, thoughput and cost reduction. This systematic composition of system technologies for inference falls under the DeepSpeed-Inference. Learn more: [DeepSpeed-Inference](https://www.deepspeed.ai/inference) +# Extreme Speed and Scale for DL Training -## DeepSpeed-Compression - -To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the DeepSpeed-Compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression) - -# DeepSpeed Software Suite - -## DeepSpeed Library - - The [DeepSpeed](https://github.com/microsoft/deepspeed) library implements and packages the innovations and technologies in DeepSpeed Training, Inference and Compression Pillars into a single easy-to-use, open-sourced repository. It allows for easy composition of multitude of features within a single training, infernece or compression pipeline. The DeepSpeed Library is heavily adopted by the DL community, and has been used to enable some of the most powerful models (see [DeepSpeed Adoption](#deepspeed-adoption)). - -## Model Implementations for Inference (MII) - - [Model Implementations for Inference (MII)](https://github.com/microsoft/deepspeed-mii) is an open-sourced repository for making low-latency and high-throughput inference accessible to all data scientists by alleviating the need to apply complex system optimization techniques themselves. Out-of-box, MII offers support for thousands of widely used DL models, optimized using DeepSpeed-Inference, that can be deployed with a few lines of code, while achieving significant latency reduction compared to their vanilla open-sourced versions. - -## DeepSpeed on Azure - - DeepSpeed users are diverse and have access to different environments. We recommend to try DeepSpeed on Azure as it is the simplest and easiest method. The recommended method to try DeepSpeed on Azure is through AzureML [recipes](https://github.com/Azure/azureml-examples/tree/main/python-sdk/workflows/train/deepspeed). The job submission and data preparation scripts have been made available [here](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azureml). For more details on how to use DeepSpeed on Azure, please follow the [Azure tutorial](https://www.deepspeed.ai/tutorials/azure/). +***[DeepSpeed](https://www.deepspeed.ai/) enabled the world's most powerful language models (at the time of this writing) such as [MT-530B](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/) and [BLOOM](https://huggingface.co/blog/bloom-megatron-deepspeed)***. DeepSpeed offers a confluence of [system innovations](https://www.deepspeed.ai/training/), that has made large scale DL training effective, and efficient, greatly improved ease of use, and redefined the DL training landscape in terms of scale that is possible. These innovations include ZeRO, 3D-Parallelism, DeepSpeed-MoE, ZeRO-Infinity, etc. # DeepSpeed Adoption -DeepSpeed has been used to train many different large-scale models, below is a list of several examples that we are aware of (if you'd like to include your model please submit a PR): +DeepSpeed has been used to train many different large-scale models. Below is a list of several examples that we are aware of (if you'd like to include your model please submit a PR): * [Megatron-Turing NLG (530B)](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/) * [Jurassic-1 (178B)](https://uploads-ssl.webflow.com/60fd4503684b466578c0d307/61138924626a6981ee09caf6_jurassic_tech_paper.pdf) @@ -77,8 +56,8 @@ DeepSpeed has been integrated with several different popular open-source DL fram | | Documentation | | ---------------------------------------------------------------------------------------------- | -------------------------------------------- | -| | [Transformers with DeepSpeed](https://huggingface.co/docs/transformers/main/main_classes/deepspeed) | -| | [Accelerate with DeepSpeed](https://huggingface.co/docs/accelerate/main/en/deepspeed) | +| | [Transformers with DeepSpeed](https://huggingface.co/docs/transformers/deepspeed) | +| | [Accelerate with DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) | | | [Lightning with DeepSpeed](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html) | | | [MosaicML with DeepSpeed](https://docs.mosaicml.com/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration) | @@ -92,14 +71,13 @@ etc. ## Contributor License Agreement This project welcomes contributions and suggestions. Most contributions require you to -agree to a Contributor License Agreement (CLA) declaring that you have the right to, and -actually do, grant us the rights to use your contribution. For details, visit -https://cla.opensource.microsoft.com. +agree to a Developer Certificate of Origin (DCO)[https://wiki.linuxfoundation.org/dco] +stating that they agree to the terms published at https://developercertificate.org for +that *particular* contribution. -When you submit a pull request, a CLA bot will automatically determine whether you need -to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply -follow the instructions provided by the bot. You will only need to do this once across -all repos using our CLA. +DCOs are per-commit, so each commit needs to be signed off. These can be signed in +the commit by adding the `-s` flag. DCO enforcement can also be signed off in the PR +itself by clicking on the DCO enforcement check. ## Code of Conduct This project has adopted the [Microsoft Open Source Code of @@ -112,23 +90,40 @@ comments. 1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: memory optimizations toward training trillion parameter models. [arXiv:1910.02054](https://arxiv.org/abs/1910.02054) and [In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis (SC '20)](https://dl.acm.org/doi/10.5555/3433701.3433727). 2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703). 3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html). -4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840) and [USENIX ATC 2021](https://www.usenix.org/conference/atc21/presentation/ren-jie). +4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840) and [USENIX ATC 2021](https://www.usenix.org/conference/atc21/presentation/ren-jie). [[paper]](https://arxiv.org/abs/2101.06840) [[slides]](https://www.usenix.org/system/files/atc21_slides_ren-jie.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/) 5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888) and [ICML 2021](http://proceedings.mlr.press/v139/tang21a.html). -6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857) and [SC 2021](https://dl.acm.org/doi/abs/10.1145/3458817.3476205). +6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857) and [SC 2021](https://dl.acm.org/doi/abs/10.1145/3458817.3476205). [[paper]](https://arxiv.org/abs/2104.07857) [[slides]](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/assets/files/SC21-ZeRO-Infinity.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) 7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069) and [HiPC 2022](https://hipc.org/advance-program/). 8. Conglong Li, Minjia Zhang, Yuxiong He. (2021) The Stability-Efficiency Dilemma: Investigating Sequence Length Warmup for Training GPT Models. [arXiv:2108.06084](https://arxiv.org/abs/2108.06084) and [NeurIPS 2022](https://openreview.net/forum?id=JpZ5du_Kdh). 9. Yucheng Lu, Conglong Li, Minjia Zhang, Christopher De Sa, Yuxiong He. (2022) Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam. [arXiv:2202.06009](https://arxiv.org/abs/2202.06009). -10. Samyam Rajbhandari, Conglong Li, Zhewei Yao, Minjia Zhang, Reza Yazdani Aminabadi, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He. (2022) DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale [arXiv:2201.05596](https://arxiv.org/abs/2201.05596) and [ICML 2022](https://proceedings.mlr.press/v162/rajbhandari22a.html). +10. Samyam Rajbhandari, Conglong Li, Zhewei Yao, Minjia Zhang, Reza Yazdani Aminabadi, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He. (2022) DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale [arXiv:2201.05596](https://arxiv.org/abs/2201.05596) and [ICML 2022](https://proceedings.mlr.press/v162/rajbhandari22a.html). [[pdf]](https://arxiv.org/abs/2201.05596) [[slides]](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/assets/files/ICML-5mins.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/) 11. Shaden Smith, Mostofa Patwary, Brandon Norick, Patrick LeGresley, Samyam Rajbhandari, Jared Casper, Zhun Liu, Shrimai Prabhumoye, George Zerveas, Vijay Korthikanti, Elton Zhang, Rewon Child, Reza Yazdani Aminabadi, Julie Bernauer, Xia Song, Mohammad Shoeybi, Yuxiong He, Michael Houston, Saurabh Tiwary, Bryan Catanzaro. (2022) Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model [arXiv:2201.11990](https://arxiv.org/abs/2201.11990). 12. Xiaoxia Wu, Zhewei Yao, Minjia Zhang, Conglong Li, Yuxiong He. (2022) Extreme Compression for Pre-trained Transformers Made Simple and Efficient. [arXiv:2206.01859](https://arxiv.org/abs/2206.01859) and [NeurIPS 2022](https://openreview.net/forum?id=xNeAhc2CNAl). -13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861) and [NeurIPS 2022](https://openreview.net/forum?id=f-fVCElZ-G1). -14. Reza Yazdani Aminabadi, Samyam Rajbhandari, Minjia Zhang, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Jeff Rasley, Shaden Smith, Olatunji Ruwase, Yuxiong He. (2022) DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale. [arXiv:2207.00032](https://arxiv.org/abs/2207.00032) and [SC 2022](https://dl.acm.org/doi/abs/10.5555/3571885.3571946). +13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861) and [NeurIPS 2022](https://openreview.net/forum?id=f-fVCElZ-G1) [[slides]](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/assets/files/zeroquant_series.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-compression-a-composable-library-for-extreme-compression-and-zero-cost-quantization/) +14. Reza Yazdani Aminabadi, Samyam Rajbhandari, Minjia Zhang, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Jeff Rasley, Shaden Smith, Olatunji Ruwase, Yuxiong He. (2022) DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale. [arXiv:2207.00032](https://arxiv.org/abs/2207.00032) and [SC 2022](https://dl.acm.org/doi/abs/10.5555/3571885.3571946). [[paper]](https://arxiv.org/abs/2207.00032) [[slides]](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/assets/files/sc22-ds-inference.pdf) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/) 15. Zhewei Yao, Xiaoxia Wu, Conglong Li, Connor Holmes, Minjia Zhang, Cheng Li, Yuxiong He. (2022) Random-LTD: Random and Layerwise Token Dropping Brings Efficient Training for Large-scale Transformers. [arXiv:2211.11586](https://arxiv.org/abs/2211.11586). -16. Conglong Li, Zhewei Yao, Xiaoxia Wu, Minjia Zhang, Yuxiong He. (2022) DeepSpeed Data Efficiency: Improving Deep Learning Model Quality and Training Efficiency via Efficient Data Sampling and Routing. [arXiv:2212.03597](https://arxiv.org/abs/2212.03597). -17. Xiaoxia Wu, Cheng Li, Reza Yazdani Aminabadi, Zhewei Yao, Yuxiong He. (2023) Understanding INT4 Quantization for Transformer Models: Latency Speedup, Composability, and Failure Cases. [arXiv:2301.12017](https://arxiv.org/abs/2301.12017). +16. Conglong Li, Zhewei Yao, Xiaoxia Wu, Minjia Zhang, Yuxiong He. (2022) DeepSpeed Data Efficiency: Improving Deep Learning Model Quality and Training Efficiency via Efficient Data Sampling and Routing. [arXiv:2212.03597](https://arxiv.org/abs/2212.03597) [ENLSP2023 Workshop at NeurIPS2023](https://neurips2023-enlsp.github.io/) +17. Xiaoxia Wu, Cheng Li, Reza Yazdani Aminabadi, Zhewei Yao, Yuxiong He. (2023) Understanding INT4 Quantization for Transformer Models: Latency Speedup, Composability, and Failure Cases. [arXiv:2301.12017](https://arxiv.org/abs/2301.12017) and [ICML2023](https://icml.cc/Conferences/2023). 18. Syed Zawad, Cheng Li, Zhewei Yao, Elton Zheng, Yuxiong He, Feng Yan. (2023) DySR: Adaptive Super-Resolution via Algorithm and System Co-design. [ICLR:2023](https://openreview.net/forum?id=Pgtn4l6eKjv). -19. Sheng Shen, Zhewei Yao, Chunyuan Li, Trevor Darrell, Kurt Keutzer, Yuxiong He. (2023) Scaling Vision-Language Models with Sparse Mixture of Experts. [arXiv:2303.07226](https://arxiv.org/abs/2303.07226). +19. Sheng Shen, Zhewei Yao, Chunyuan Li, Trevor Darrell, Kurt Keutzer, Yuxiong He. (2023) Scaling Vision-Language Models with Sparse Mixture of Experts. [arXiv:2303.07226](https://arxiv.org/abs/2303.07226) and [Finding at EMNLP2023](https://2023.emnlp.org/). 20. Quentin Anthony, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He, Aamir Shafi, Mustafa Abduljabbar, Hari Subramoni, Dhabaleswar Panda. (2023) MCR-DL: Mix-and-Match Communication Runtime for Deep Learning [arXiv:2303.08374](https://arxiv.org/abs/2303.08374) and will appear at IPDPS 2023. +21. Siddharth Singh, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He, Abhinav Bhatele. (2023) A Hybrid Tensor-Expert-Data Parallelism Approach to Optimize Mixture-of-Experts Training [arXiv:2303.06318](https://arxiv.org/abs/2303.06318) and will appear at ICS 2023. +22. Guanhua Wang, Heyang Qin, Sam Ade Jacobs, Xiaoxia Wu, Connor Holmes, Zhewei Yao, Samyam Rajbhandari, Olatunji Ruwase, Feng Yan, Lei Yang, Yuxiong He. (2023) ZeRO++: Extremely Efficient Collective Communication for Giant Model Training [arXiv:2306.10209](https://arxiv.org/abs/2306.10209) and [ML for Sys Workshop at NeurIPS2023](http://mlforsystems.org/) [[blog]](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/) +23. Zhewei Yao, Xiaoxia Wu, Cheng Li, Stephen Youn, Yuxiong He. (2023) ZeroQuant-V2: Exploring Post-training Quantization in LLMs from Comprehensive Study to Low Rank Compensation [arXiv:2303.08302](https://arxiv.org/abs/2303.08302) and [ENLSP2023 Workshop at NeurIPS2023](https://neurips2023-enlsp.github.io/) [[slides]](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/assets/files/zeroquant_series.pdf) +24. Pareesa Ameneh Golnari, Zhewei Yao, Yuxiong He. (2023) Selective Guidance: Are All the Denoising Steps of Guided Diffusion Important? [arXiv:2305.09847](https://arxiv.org/abs/2305.09847) +25. Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, Zhongzhu Zhou, Michael Wyatt, Molly Smith, Lev Kurilenko, Heyang Qin, Masahiro Tanaka, Shuai Che, Shuaiwen Leon Song, Yuxiong He. (2023) DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales [arXiv:2308.01320](https://arxiv.org/abs/2308.01320). +26. Xiaoxia Wu, Zhewei Yao, Yuxiong He. (2023) ZeroQuant-FP: A Leap Forward in LLMs Post-Training W4A8 Quantization Using Floating-Point Formats [arXiv:2307.09782](https://arxiv.org/abs/2307.09782) and [ENLSP2023 Workshop at NeurIPS2023](https://neurips2023-enlsp.github.io/) [[slides]](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/assets/files/zeroquant_series.pdf) +27. Zhewei Yao, Xiaoxia Wu, Conglong Li, Minjia Zhang, Heyang Qin, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He. (2023) DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention [arXiv:2309.14327](https://arxiv.org/pdf/2309.14327.pdf) +28. Shuaiwen Leon Song, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, Xiaoxia Wu, Jeff Rasley, Ammar Ahmad Awan, Connor Holmes, Martin Cai, Adam Ghanem, Zhongzhu Zhou, Yuxiong He, et al. (2023) DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies [arXiv:2310.04610](https://arxiv.org/abs/2310.04610) [[blog]](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/) +29. Zhewei Yao, Reza Yazdani Aminabadi, Stephen Youn, Xiaoxia Wu, Elton Zheng, Yuxiong He. (2023) ZeroQuant-HERO: Hardware-Enhanced Robust Optimized Post-Training Quantization Framework for W8A8 Transformers [arXiv:2310.17723](https://arxiv.org/abs/2310.17723) +30. Sam Ade Jacobs, Masahiro Tanaka, Chengming Zhang, Minjia Zhang, Reza Yazdani Aminadabi, Shuaiwen Leon Song, Samyam Rajbhandari, Yuxiong He. (2024) [System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://dl.acm.org/doi/10.1145/3662158.3662806) +31. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820) + +32. Sam Ade Jacobs, Masahiro Tanaka, Chengming Zhang, Minjia Zhang, Reza Yazdani Aminadabi, Shuaiwen Leon Song, Samyam Rajbhandari, Yuxiong He. (2024) [System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://dl.acm.org/doi/10.1145/3662158.3662806) +33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820) +34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996) +35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242) +36. Xinyu Lian, Masahiro Tanaka, Olatunji Ruwase, Minjia Zhang. (2026) SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips [arxiv](https://arxiv.org/abs/2509.21271), [ASPLOS 2026](https://www.asplos-conference.org/asplos2026) # Videos 1. DeepSpeed KDD 2020 Tutorial @@ -142,7 +137,8 @@ comments. * Registration is free and all videos are available on-demand. * [ZeRO & Fastest BERT: Increasing the scale and speed of deep learning training in DeepSpeed](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html). 3. [DeepSpeed on AzureML](https://youtu.be/yBVXR8G8Bg8) -4. Community Tutorials +4. [Large Model Training and Inference with DeepSpeed // Samyam Rajbhandari // LLMs in Prod Conference](https://www.youtube.com/watch?v=cntxC3g22oU) [[slides]](docs/assets/files/presentation-mlops.pdf) +5. Community Tutorials * [DeepSpeed: All the tricks to scale to gigantic models (Mark Saroufim)](https://www.youtube.com/watch?v=pDGI668pNg0) * [Turing-NLG, DeepSpeed and the ZeRO optimizer (Yannic Kilcher)](https://www.youtube.com/watch?v=tC01FRB0M7w) * [Ultimate Guide To Scaling ML Models (The AI Epiphany)](https://www.youtube.com/watch?v=hc0u4avAkuM) diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000000..28c298717d80 --- /dev/null +++ b/environment.yml @@ -0,0 +1,21 @@ +channels: + - nvidia/label/cuda-11.8.0 + - pytorch # or pytorch-nightly + - conda-forge +dependencies: + - pytorch + - torchvision + - torchaudio + - cuda + - pytorch-cuda=11.8 + - compilers + - sysroot_linux-64==2.17 + - gcc==11.4 + - ninja + - py-cpuinfo + - libaio + - ca-certificates + - certifi + - openssl + - python=3.10 + - pydantic diff --git a/examples/README.md b/examples/README.md index c110395cb95b..c7ff01dcd2d4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,8 +2,8 @@ If you are looking for examples using DeepSpeed please see the following resources: -1. [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) -2. [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed) +1. [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples) +2. [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed) 3. [DeepSpeed + AzureML](https://github.com/Azure/azureml-examples/tree/main/v1/python-sdk/workflows/train/deepspeed) -4. [DeepSpeed + Hugging Face Transformers Integration](https://huggingface.co/docs/transformers/main_classes/deepspeed) -5. [DeepSpeed + PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.deepspeed.html) +4. [DeepSpeed + Hugging Face Transformers Integration](https://huggingface.co/docs/transformers/deepspeed) +5. [DeepSpeed + PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.utilities.deepspeed.html) diff --git a/examples/sdma_allgather/README.md b/examples/sdma_allgather/README.md new file mode 100644 index 000000000000..b80f669f33b1 --- /dev/null +++ b/examples/sdma_allgather/README.md @@ -0,0 +1,136 @@ +# SDMA-Accelerated ZeRO-3 on AMD GPUs + +## Motivation + +ZeRO-3 reconstructs each layer with an AllGather right before its forward / backward +pass, and DeepSpeed's `PartitionedParameterCoordinator` prefetches these AllGathers +on a separate stream so that collective and compute can overlap *in time*. In +practice the time-overlap is already quite good for typical ZeRO-3 workloads. + +What's left is **resource** overlap. On AMD GPUs, RCCL AllGather kernels execute +on the same compute units (CUs) that GEMM and attention run on, and tend to share +wavefront slots, LDS, register file and HBM bandwidth with concurrent compute. +Even when the time-overlap schedule looks near-perfect, the effective compute +throughput during the overlap window can stay below peak because the two workloads +sit on the same physical hardware. + +AMD MI300X / MI325X / MI355X contain dedicated **System DMA (SDMA)** copy engines — +independent hardware queues that move data between GPUs over XGMI without using +the CU array. Routing ZeRO-3's AllGather through SDMA instead of CU-based RCCL +kernels lets collective traffic and compute run on physically separate engines, +leaving CUs largely free for GEMM / attention during the overlap window. In +workloads where overlap is a meaningful bottleneck this can translate into +end-to-end step-time gains (workload-dependent; see the verified results table +below). + +See [RFC #7884](https://github.com/deepspeedai/DeepSpeed/issues/7884) for the +longer design rationale and discussion. + +## Overview + +End-to-end example for the SDMA fast-path inside +`TorchBackend.all_gather_into_tensor`. When the runtime is AMD/ROCm, +the [`mori`](https://github.com/ROCm/mori) package is importable, and the user opts in via +`DS_SDMA_ALLGATHER=1`, `deepspeed.comm` acquires the SDMA backend at +`init_distributed()` time and routes WORLD-group +`all_gather_into_tensor` calls through `mori_cpp.AllGatherIntoTensor` +(intra-node SDMA copy on MI300). RCCL/NCCL is used as the fallback on +any condition that makes the SDMA path unsafe (user did not opt in, +non-WORLD process group, shard larger than the transit buffer, +unsupported dtype, init failure). + +This means: + +- No `ds_config` knob — control is a single env var. Works out of the + box for ZeRO-3 (sequential and coalesced prefetch paths both benefit). +- No source modifications in `partition_parameters.py`: ZeRO-3 just calls + `dist.allgather_fn`, which lands on the backend's + `all_gather_into_tensor`. +- Sub-group allgathers (e.g. when ZeRO is initialised with a non-WORLD + data-parallel group, or with a secondary zero-param group) are routed + through RCCL/NCCL automatically, since the SDMA backend is bound to + WORLD. +- Even when mori is installed, the SDMA path stays off unless the user + sets `DS_SDMA_ALLGATHER=1`, so users keep explicit control over a + hardware-specific fast-path. + +## Environment variables + +| Var | Purpose | +|---|---| +| `DS_SDMA_ALLGATHER=1` | **Opt-in switch.** Required to enable the SDMA fast-path; default is off even when mori is installed. When set, `MORI_ENABLE_SDMA=1` is auto-exported on your behalf so mori allocates the uncached transit buffers the SDMA kernel needs. | +| `DS_SDMA_ALLGATHER_MAX_NUMEL=N` | Transit buffer size in elements (default 64M = 256 MiB per-rank input, ~2 GiB output on 8 ranks). Calls larger than this fall back to RCCL/NCCL. | +| `MORI_ENABLE_SDMA=1` | mori's own knob for uncached transit buffers; normally set automatically by DeepSpeed when `DS_SDMA_ALLGATHER=1`. Export it explicitly only if you want to override or pre-set it. | + +The `run_*_sdma_on.sh` scripts export `DS_SDMA_ALLGATHER=1`; the +`run_*_sdma_off.sh` scripts leave it unset (default). Both variants +share the same `ds_config_zero3.json` — the SDMA decision is made +entirely by env vars. + +## Verified results on 8x MI300X + +| | GPT-7B-ish | Qwen3-32B | +|---|---|---| +| trainer | `train_zero3.py` | `train_qwen3_zero3.py` | +| seq / micro batch | 2048 / 1 | 1024 / 1 | +| dataset | wikitext-2-raw-v1 | wikitext-103-raw-v1 (10 %) | +| measured / warmup steps | 100 / 10 | 100 / 10 | +| **SDMA off (RCCL)** | 697.7 ms / step (11.6 samples/s) | 1402.5 ms / step (5841 tok/s) | +| **SDMA on (this PR)** | **622.0 ms / step (13.0 samples/s)** | **1263.2 ms / step (6486 tok/s)** | +| **gain** | **+10.85 %** | **+9.93 %** | +| peak mem (rank 0) | 12.12 GB, unchanged off ↔ on | 96.45 GB, unchanged off ↔ on | + +The Qwen3-32B number is averaged over two fresh rounds; per-round delta +was +10.85 % and +9.92 %, with 0.29 % run-to-run variance on the off +baseline, so the gap is well outside per-step jitter (~1.5–2.7 %). + +Speedup is workload-dependent — gains shrink (or invert) when allgather +cannot be overlapped with compute (e.g. very small payloads, or +`overlap_comm=false`). + +### Loss curves match across off ↔ on (2000-step runs) + +A long-horizon sanity check on each demo confirms the SDMA path +introduces no numerical drift: 2000 training steps on the same wikitext +shuffle, off vs on traces overlap throughout. Both trainers use the +standard "concat the corpus + slice into fixed `seq_length` chunks" +pattern, so every sample has the same number of real tokens and per-step +loss has no variance from padding fraction. Bucketed mean |off − on| +over the full 2000 steps is ≤ **0.026** on GPT and ≤ **0.048** on Qwen3, +well below natural per-step jitter. + +![GPT-7B-ish — training loss vs step, SDMA off vs on, 2000 steps](images/loss_gpt_2k.png) + +![Qwen3-32B — training loss vs step, SDMA off vs on, 2000 steps](images/loss_qwen3_2k.png) + +## Reproduction + +```bash +cd examples/sdma_allgather + +# Demo 1 — GPT-7B-ish, ~minute run, no HF download +bash run_gpt_sdma_off.sh # default (DS_SDMA_ALLGATHER unset), RCCL baseline +bash run_gpt_sdma_on.sh # DS_SDMA_ALLGATHER=1, SDMA fast-path -> +10.85 % + +# Demo 2 — Qwen3-32B, ~few-minute run, weight-free (random init via from_config) +bash run_qwen3_sdma_off.sh # ~1402 ms / step +bash run_qwen3_sdma_on.sh # ~1263 ms / step -> +9.93 % +``` + +Override knobs via env vars: `SEQ_LEN`, `BATCH_SIZE`, `NUM_STEPS`, +`WARMUP_STEPS`, `NUM_GPUS`, `MODEL`, `DS_CONFIG`. + +## Files + +``` +ds_config_zero3.json single shared ZeRO-3 + bf16 + DS-default buckets config +run_gpt_sdma_off.sh GPT-7B-ish + ZeRO-3, SDMA disabled via env var +run_gpt_sdma_on.sh GPT-7B-ish + ZeRO-3, SDMA enabled via env var +run_qwen3_sdma_off.sh Qwen3-32B + ZeRO-3, SDMA disabled via env var +run_qwen3_sdma_on.sh Qwen3-32B + ZeRO-3, SDMA enabled via env var +test_sdma_allgather_zero3.py unit test exercising the transparent SDMA path +train_qwen3_zero3.py Qwen3 trainer (self-contained, wikitext) +train_zero3.py GPT trainer +images/loss_gpt_2k.png GPT loss curve, off vs on, 2000 steps +images/loss_qwen3_2k.png Qwen3-32B loss curve, off vs on, 2000 steps +``` diff --git a/examples/sdma_allgather/ds_config_zero3.json b/examples/sdma_allgather/ds_config_zero3.json new file mode 100644 index 000000000000..8cae9d5648b3 --- /dev/null +++ b/examples/sdma_allgather/ds_config_zero3.json @@ -0,0 +1,42 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 10, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-5, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-5, + "warmup_num_steps": 10 + } + }, + "gradient_clipping": 1.0, + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "allgather_partitions": true, + "allgather_bucket_size": 5e7, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e7, + "contiguous_gradients": true, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_prefetch_bucket_size": 5e7, + "stage3_param_persistence_threshold": 1e5, + "stage3_gather_16bit_weights_on_model_save": true, + "sub_group_size": 1e12 + }, + "wall_clock_breakdown": false +} diff --git a/examples/sdma_allgather/images/loss_gpt_2k.png b/examples/sdma_allgather/images/loss_gpt_2k.png new file mode 100644 index 000000000000..de3d18d9f014 Binary files /dev/null and b/examples/sdma_allgather/images/loss_gpt_2k.png differ diff --git a/examples/sdma_allgather/images/loss_qwen3_2k.png b/examples/sdma_allgather/images/loss_qwen3_2k.png new file mode 100644 index 000000000000..1843e506a092 Binary files /dev/null and b/examples/sdma_allgather/images/loss_qwen3_2k.png differ diff --git a/examples/sdma_allgather/run_gpt_sdma_off.sh b/examples/sdma_allgather/run_gpt_sdma_off.sh new file mode 100755 index 000000000000..a8876e1b5bec --- /dev/null +++ b/examples/sdma_allgather/run_gpt_sdma_off.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# GPT-7B-ish + ZeRO-3 baseline (RCCL allgather). +# Default: the SDMA fast-path inside deepspeed.comm stays off unless the +# user explicitly sets DS_SDMA_ALLGATHER=1, so this script simply doesn't +# export it. + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +deepspeed --num_gpus 8 "${SCRIPT_DIR}/train_zero3.py" \ + --deepspeed \ + --deepspeed_config "${SCRIPT_DIR}/ds_config_zero3.json" \ + --data_mode wikitext2 \ + --train_steps 100 diff --git a/examples/sdma_allgather/run_gpt_sdma_on.sh b/examples/sdma_allgather/run_gpt_sdma_on.sh new file mode 100755 index 000000000000..48e9fed2bb80 --- /dev/null +++ b/examples/sdma_allgather/run_gpt_sdma_on.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# GPT-7B-ish + ZeRO-3 with the SDMA fast-path opted in. +# +# DS_SDMA_ALLGATHER=1 is the single opt-in switch. When set, +# deepspeed.comm's TorchBackend tries to bring up the mori SDMA backend +# at init time and routes WORLD-group all_gather_into_tensor through it. +# Mori's MORI_ENABLE_SDMA=1 is auto-exported on the user's behalf when +# DS_SDMA_ALLGATHER=1 is set, so users normally don't need to touch it. +# Without DS_SDMA_ALLGATHER=1, even an mori-installed run stays on RCCL. +export DS_SDMA_ALLGATHER=1 + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +deepspeed --num_gpus 8 "${SCRIPT_DIR}/train_zero3.py" \ + --deepspeed \ + --deepspeed_config "${SCRIPT_DIR}/ds_config_zero3.json" \ + --data_mode wikitext2 \ + --train_steps 100 diff --git a/examples/sdma_allgather/run_qwen3_sdma_off.sh b/examples/sdma_allgather/run_qwen3_sdma_off.sh new file mode 100755 index 000000000000..ddfdf2fa650e --- /dev/null +++ b/examples/sdma_allgather/run_qwen3_sdma_off.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Qwen3-32B + DeepSpeed ZeRO-3 baseline (RCCL allgather). +# +# Default: deepspeed.comm's SDMA fast-path stays off unless the user +# explicitly sets DS_SDMA_ALLGATHER=1, so this script simply doesn't +# export it and pairs cleanly with run_qwen3_sdma_on.sh (same ds_config; +# only env vars differ) for the A/B. +# +# model : Qwen/Qwen3-32B (full 64 layers, BF16, eager attention) +# data : wikitext-103-raw-v1, 10% split, model's own tokenizer +# parallel : ZeRO-3, DP=8 (single node, MI300X x 8) +# bucket : DeepSpeed defaults (stage3_prefetch_bucket_size = 5e7) +# seq/bs : seq_length=1024, micro_batch=1 +# steps : 100 measured + 10 warmup +# +# Override via env vars: SEQ_LEN, BATCH_SIZE, NUM_STEPS, WARMUP_STEPS, +# NUM_GPUS, MODEL, DS_CONFIG. +set -eu + +# Reduce HIP allocator fragmentation — the 32B model has long-lived tensors +# that benefit from segment expansion under heavy ZeRO-3 churn. +export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TORCH_NCCL_ENABLE_MONITORING=0 # quiets harmless TCPStore shutdown trace + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +deepspeed --num_gpus "${NUM_GPUS:-8}" "${SCRIPT_DIR}/train_qwen3_zero3.py" \ + --model_name "${MODEL:-Qwen/Qwen3-32B}" \ + --num_layers "${NUM_LAYERS:-0}" \ + --seq_length "${SEQ_LEN:-1024}" \ + --batch_size "${BATCH_SIZE:-1}" \ + --num_steps "${NUM_STEPS:-100}" \ + --warmup_steps "${WARMUP_STEPS:-10}" \ + --ds_config "${DS_CONFIG:-${SCRIPT_DIR}/ds_config_zero3.json}" diff --git a/examples/sdma_allgather/run_qwen3_sdma_on.sh b/examples/sdma_allgather/run_qwen3_sdma_on.sh new file mode 100755 index 000000000000..0c22aee308b5 --- /dev/null +++ b/examples/sdma_allgather/run_qwen3_sdma_on.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Qwen3-32B + DeepSpeed ZeRO-3 with the SDMA fast-path opted in. +# +# DS_SDMA_ALLGATHER=1 is the single opt-in switch. When set, +# deepspeed.comm's TorchBackend tries to bring up the mori SDMA backend +# at init time and routes WORLD-group all_gather_into_tensor through it. +# Mori's MORI_ENABLE_SDMA=1 is auto-exported on the user's behalf when +# DS_SDMA_ALLGATHER=1 is set, so users normally don't need to touch it. +# This script otherwise uses the same ds_config as run_qwen3_sdma_off.sh; +# the only difference is this env var. +set -eu + +export DS_SDMA_ALLGATHER=1 + +export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TORCH_NCCL_ENABLE_MONITORING=0 + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +deepspeed --num_gpus "${NUM_GPUS:-8}" "${SCRIPT_DIR}/train_qwen3_zero3.py" \ + --model_name "${MODEL:-Qwen/Qwen3-32B}" \ + --num_layers "${NUM_LAYERS:-0}" \ + --seq_length "${SEQ_LEN:-1024}" \ + --batch_size "${BATCH_SIZE:-1}" \ + --num_steps "${NUM_STEPS:-100}" \ + --warmup_steps "${WARMUP_STEPS:-10}" \ + --ds_config "${DS_CONFIG:-${SCRIPT_DIR}/ds_config_zero3.json}" diff --git a/examples/sdma_allgather/test_sdma_allgather_zero3.py b/examples/sdma_allgather/test_sdma_allgather_zero3.py new file mode 100644 index 000000000000..1dd3a3172d41 --- /dev/null +++ b/examples/sdma_allgather/test_sdma_allgather_zero3.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Unit test for the transparent SDMA allgather path in deepspeed.comm. + +After ``deepspeed.init_distributed()`` returns, ``dist.all_gather_into_tensor`` +on the WORLD process group is transparently routed through +``mori_cpp.AllGatherIntoTensor`` on AMD MI300 when mori is available, with +RCCL/NCCL as a fallback. This test exercises that path the same way +ZeRO-3's ``_all_gather_dtype`` does (flat output / per-rank shard input +with ``async_op=True``) and verifies correctness and algorithm bandwidth +for the common dtypes. + +Usage: + cd examples/sdma_allgather + deepspeed --num_gpus 8 test_sdma_allgather_zero3.py + deepspeed --num_gpus 8 test_sdma_allgather_zero3.py --partition_sz 4194304 --iterations 50 +""" + +import argparse +import os + +import numpy as np +import torch + +import deepspeed +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.comm import mori as _mori + + +def verify_allgather(partitions, world_size, partition_sz, rank, dtype): + """Verify that each rank's partition contains the expected fill pattern.""" + passed = True + for r in range(world_size): + chunk = partitions[r].narrow(0, 0, partition_sz).float().cpu() + expected_val = float(r + 1) + if not torch.allclose(chunk, torch.full_like(chunk, expected_val)): + unique_vals = chunk.unique() + print(f" [rank {rank}] FAIL: partition[{r}] expected all {expected_val}, " + f"got unique values: {unique_vals[:10]}") + passed = False + return passed + + +def run_single_allgather(rank, world_size, dtype, partition_sz, ag_stream): + """Execute one allgather call following the ZeRO-3 ``_all_gather_dtype`` path.""" + device = get_accelerator().current_device_name() + + flat_tensor = torch.empty(partition_sz * world_size, dtype=dtype, device=device, requires_grad=False) + partitions = [flat_tensor.narrow(0, partition_sz * i, partition_sz) for i in range(world_size)] + partitions[rank].fill_(float(rank + 1)) + + with get_accelerator().stream(ag_stream): + handle = dist.allgather_fn(flat_tensor, partitions[rank], async_op=True) + + with get_accelerator().stream(ag_stream): + handle.wait() + get_accelerator().current_stream().wait_stream(ag_stream) + + return partitions + + +def run_bandwidth_test(rank, world_size, dtype, partition_sz, ag_stream, iterations, warmup): + """Measure allgather bandwidth following the ZeRO-3 overlap pattern.""" + device = get_accelerator().current_device_name() + elem_size = torch.tensor([], dtype=dtype).element_size() + total_bytes = partition_sz * elem_size * world_size + + ev_start = get_accelerator().Event(enable_timing=True) + ev_end = get_accelerator().Event(enable_timing=True) + times_ms = [] + + for i in range(warmup + iterations): + flat_tensor = torch.empty(partition_sz * world_size, dtype=dtype, device=device, requires_grad=False) + partitions = [flat_tensor.narrow(0, partition_sz * r, partition_sz) for r in range(world_size)] + partitions[rank].fill_(float(rank + 1)) + + dist.barrier() + + ev_start.record(ag_stream) + with get_accelerator().stream(ag_stream): + handle = dist.allgather_fn(flat_tensor, partitions[rank], async_op=True) + with get_accelerator().stream(ag_stream): + handle.wait() + ev_end.record(ag_stream) + + ag_stream.synchronize() + + elapsed_ms = ev_start.elapsed_time(ev_end) + if i >= warmup: + times_ms.append(elapsed_ms) + + return times_ms, total_bytes + + +def main(): + parser = argparse.ArgumentParser(description="Transparent SDMA allgather unit test") + parser.add_argument("--partition_sz", type=int, default=1024 * 1024, help="Elements per rank per allgather call") + parser.add_argument("--iterations", type=int, default=20, help="Number of measurement iterations") + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--local_rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0))) + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + + deepspeed.init_distributed(dist_backend="cpu:gloo,cuda:nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + get_accelerator().set_device(args.local_rank) + + if rank == 0: + backend = "SDMA (mori)" if _mori.is_enabled() else "RCCL/NCCL (mori unavailable or disabled)" + print(f"\n{'=' * 65}") + print(f" Transparent SDMA Allgather Unit Test") + print(f" world_size : {world_size}") + print(f" partition_sz : {args.partition_sz:,} elements") + print(f" iterations : {args.iterations} (warmup {args.warmup})") + print(f" backend : {backend}") + print(f"{'=' * 65}\n") + + ag_stream = get_accelerator().Stream() + + test_dtypes = [ + ("bfloat16", torch.bfloat16), + ("float16", torch.float16), + ("float32", torch.float32), + ] + + if rank == 0: + print("--- Correctness ---") + + all_correct = True + for dtype_name, dtype in test_dtypes: + dist.barrier() + partitions = run_single_allgather(rank, world_size, dtype, args.partition_sz, ag_stream) + passed = verify_allgather(partitions, world_size, args.partition_sz, rank, dtype) + + passed_t = torch.tensor([1 if passed else 0], dtype=torch.int32) + dist.all_reduce(passed_t, op=dist.ReduceOp.MIN) + ok = passed_t.item() == 1 + + if rank == 0: + elem_bytes = torch.tensor([], dtype=dtype).element_size() + data_mb = args.partition_sz * elem_bytes * world_size / (1024**2) + status = "PASSED" if ok else "FAILED" + print(f" {dtype_name:10s} data={data_mb:8.2f} MB {status}") + if not ok: + all_correct = False + + if rank == 0: + print(f"\n--- Bandwidth (iterations={args.iterations}, warmup={args.warmup}) ---") + print(f" {'dtype':10s} {'data_MB':>10s} {'avg_ms':>9s} " + f"{'min_ms':>9s} {'max_ms':>9s} {'algo_BW':>12s}") + print(f" {'-'*10} {'-'*10} {'-'*9} {'-'*9} {'-'*9} {'-'*12}") + + for dtype_name, dtype in test_dtypes: + dist.barrier() + times_ms, total_bytes = run_bandwidth_test(rank, world_size, dtype, args.partition_sz, ag_stream, + args.iterations, args.warmup) + + avg_ms = np.mean(times_ms) + min_ms = np.min(times_ms) + max_ms = np.max(times_ms) + + avg_t = torch.tensor([avg_ms], dtype=torch.float64) + min_t = torch.tensor([min_ms], dtype=torch.float64) + max_t = torch.tensor([max_ms], dtype=torch.float64) + dist.all_reduce(avg_t, op=dist.ReduceOp.SUM) + dist.all_reduce(min_t, op=dist.ReduceOp.MIN) + dist.all_reduce(max_t, op=dist.ReduceOp.MAX) + + if rank == 0: + g_avg_ms = avg_t.item() / world_size + g_min_ms = min_t.item() + g_max_ms = max_t.item() + data_mb = total_bytes / (1024**2) + algo_bw_gbs = total_bytes / (g_avg_ms / 1000) / (1024**3) + print(f" {dtype_name:10s} {data_mb:10.2f} {g_avg_ms:9.3f} " + f"{g_min_ms:9.3f} {g_max_ms:9.3f} {algo_bw_gbs:9.2f} GB/s") + + dist.barrier() + if rank == 0: + print() + print(f"Result: {'All correctness tests PASSED' if all_correct else 'Some correctness tests FAILED'}") + print(f"{'=' * 65}\n") + + get_accelerator().synchronize() + dist.barrier() + if _mori.is_enabled(): + import mori.shmem as shmem + shmem.shmem_finalize() + + +if __name__ == "__main__": + main() diff --git a/examples/sdma_allgather/train_qwen3_zero3.py b/examples/sdma_allgather/train_qwen3_zero3.py new file mode 100644 index 000000000000..673f09b83d54 --- /dev/null +++ b/examples/sdma_allgather/train_qwen3_zero3.py @@ -0,0 +1,283 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Qwen3 + DeepSpeed ZeRO-3 benchmark for the SDMA allgather feature. + +Loads a Qwen3 model with random initialisation under `deepspeed.zero.Init` +so each rank only allocates its 1/world_size shard, then runs a small number +of training steps on either real wikitext or synthetic random tokens. Step +time is measured rank-0 side and reported with peak memory and the average +loss. The same trainer is used for the SDMA-on and SDMA-off comparison runs +in run_qwen3_sdma_{on,off}.sh. + +The SDMA fast-path is opt-in via a single env var: ``deepspeed.comm`` +brings up the mori SDMA backend at init time when ``DS_SDMA_ALLGATHER=1`` +and routes WORLD-group ``all_gather_into_tensor`` through +``mori_cpp.AllGatherIntoTensor`` on AMD MI300. No ``ds_config`` flag is +required. Leaving ``DS_SDMA_ALLGATHER`` unset (the default) reproduces +the RCCL/NCCL baseline for A/B comparisons even with mori installed. + +Real-data path uses HuggingFace `datasets` to stream wikitext-103 and the +model's own tokenizer to pad/truncate to seq_length. No external benchmark +repo is required. +""" + +import argparse +import os +import time + +import deepspeed +import torch +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--model_name", default="Qwen/Qwen3-32B") + p.add_argument("--num_layers", + type=int, + default=0, + help="0 = use model default; smaller values for quick smoke runs") + p.add_argument("--seq_length", type=int, default=1024) + p.add_argument("--batch_size", type=int, default=1) + p.add_argument("--num_steps", type=int, default=50) + p.add_argument("--warmup_steps", type=int, default=10) + p.add_argument("--log_interval", type=int, default=10) + p.add_argument("--ds_config", required=True) + p.add_argument("--dataset", + default="wikitext", + choices=["wikitext", "synthetic"], + help="Real text (wikitext-103) or pre-generated random ids") + p.add_argument("--dataset_percentage", + type=float, + default=10.0, + help="Percentage of wikitext train split to load (1.0-100.0)") + p.add_argument("--local_rank", type=int, default=-1) + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Self-contained data pipeline (no external benchmark repo dependency). +# --------------------------------------------------------------------------- +class _SyntheticDataset(Dataset): + """Pre-generated random token ids for deterministic timing runs.""" + + def __init__(self, vocab_size, seq_length, num_samples=10000, seed=42): + gen = torch.Generator().manual_seed(seed) + self.input_ids = torch.randint(0, vocab_size, (num_samples, seq_length), generator=gen, dtype=torch.long) + self.seq_length = seq_length + + def __len__(self): + return self.input_ids.shape[0] + + def __getitem__(self, idx): + ids = self.input_ids[idx] + return { + "input_ids": ids, + "labels": ids.clone(), + "attention_mask": torch.ones(self.seq_length, dtype=torch.long), + } + + +def _build_wikitext_loader(model_name, seq_length, batch_size, dataset_percentage, rank, world_size, is_main): + """Stream wikitext-103-raw-v1 as a concatenated token stream sliced into + fixed `seq_length` chunks. + + This is the standard "group_texts" / GPT-style chunking pattern: every + sample is exactly seq_length REAL tokens with no padding and no per-row + boundaries. Result is uniform-difficulty samples, so the per-step loss + has no variance from "this row was 90 % padding" effects — which is what + made the per-row+padding variant of this loader jittery on bs=1 demos. + """ + from datasets import DownloadConfig, load_dataset + from datasets.utils.logging import disable_progress_bar + if not is_main: + disable_progress_bar() + + fraction = max(1, int(dataset_percentage)) + split = "train" if dataset_percentage >= 100 else f"train[:{fraction}%]" + + if is_main: + print(f"[trainer] loading wikitext-103-raw-v1 split={split}") + raw = load_dataset("wikitext", + "wikitext-103-raw-v1", + split=split, + download_config=DownloadConfig(disable_tqdm=True)) + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token or tokenizer.convert_ids_to_tokens(2) + + if is_main: + print(f"[trainer] encoding {len(raw)} rows as a single stream ...") + text = "\n\n".join(t for t in raw["text"] if t.strip()) + all_ids = tokenizer.encode(text, add_special_tokens=False) + + # Optional cap on number of chunks (env var) so the per-epoch length can + # be tuned for short demos. 0 = use all available chunks. + max_chunks = int(os.environ.get("QWEN3_MAX_CHUNKS", "0")) + n_full = len(all_ids) // seq_length + if max_chunks > 0: + n_full = min(n_full, max_chunks) + chunks = torch.tensor(all_ids[:n_full * seq_length], dtype=torch.long).view(n_full, seq_length) + if is_main: + print(f"[trainer] chunked: {len(all_ids)} tokens -> {n_full} " + f"sequences of {seq_length} (no padding)", + flush=True) + + class _ChunkDataset(Dataset): + + def __init__(self, t): + self.t = t + + def __len__(self): + return self.t.shape[0] + + def __getitem__(self, idx): + ids = self.t[idx] + return { + "input_ids": ids, + "labels": ids.clone(), + "attention_mask": torch.ones(ids.shape[0], dtype=torch.long), + } + + ds = _ChunkDataset(chunks) + sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank) + return DataLoader(ds, batch_size=batch_size, sampler=sampler, num_workers=0, drop_last=True, pin_memory=True) + + +def _build_loader(args, vocab_size, rank, world_size, is_main): + if args.dataset == "wikitext": + return _build_wikitext_loader(args.model_name, args.seq_length, args.batch_size, args.dataset_percentage, rank, + world_size, is_main) + ds = _SyntheticDataset(vocab_size, args.seq_length) + return DataLoader(ds, batch_size=args.batch_size, shuffle=False, drop_last=True, num_workers=0, pin_memory=True) + + +# --------------------------------------------------------------------------- +# Model construction under deepspeed.zero.Init so each rank only materialises +# its shard. +# --------------------------------------------------------------------------- +def build_model(model_name, num_layers, ds_config_path): + cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + if num_layers > 0: + cfg.num_hidden_layers = num_layers + cfg.torch_dtype = torch.bfloat16 + cfg.use_cache = False + cfg.attn_implementation = "eager" # FA2 not always available on AMD; eager is safe. + if dist.is_initialized() and dist.get_rank() == 0: + print(f"[trainer] {model_name}: layers={cfg.num_hidden_layers} " + f"hidden={cfg.hidden_size} heads={cfg.num_attention_heads} " + f"kv_heads={cfg.num_key_value_heads} vocab={cfg.vocab_size}") + with deepspeed.zero.Init(config_dict_or_path=ds_config_path): + model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True) + return model, cfg + + +def main(): + args = parse_args() + deepspeed.init_distributed() + rank = dist.get_rank() + world = dist.get_world_size() + accel = get_accelerator() + device_idx = args.local_rank if args.local_rank >= 0 else rank % accel.device_count() + device = torch.device(accel.device_name(device_idx)) + accel.set_device(device_idx) + + if rank == 0: + print(f"[trainer] world={world} device={device} ds_config={args.ds_config}") + + model, cfg = build_model(args.model_name, args.num_layers, args.ds_config) + + engine, _, _, _ = deepspeed.initialize( + args=args, + model=model, + model_parameters=[p for p in model.parameters() if p.requires_grad], + config=args.ds_config, + ) + + if rank == 0: + from deepspeed.comm import mori as _mori + print(f"[trainer] SDMA handle is_enabled={_mori.is_enabled()}", flush=True) + + loader = _build_loader(args, cfg.vocab_size, rank, world, rank == 0) + if rank == 0: + print(f"[trainer] dataloader: {len(loader)} batches/epoch, " + f"running {args.num_steps} steps", flush=True) + + step_times, losses = [], [] + get_accelerator().reset_peak_memory_stats() + t_train_start = time.perf_counter() + step, epoch = 0, 0 + data_iter = iter(loader) + skipped_empty = 0 + while step < args.num_steps: + try: + batch = next(data_iter) + except StopIteration: + epoch += 1 + if hasattr(loader.sampler, "set_epoch"): + loader.sampler.set_epoch(epoch) + data_iter = iter(loader) + batch = next(data_iter) + ids = batch["input_ids"].to(device, non_blocking=True) + labels = batch["labels"].to(device, non_blocking=True) + attn = batch["attention_mask"].to(device, non_blocking=True) + # Defensive: on the chunked wikitext loader every chunk is full of + # real tokens so these guards are no-ops, but they keep the trainer + # safe for the synthetic mode and any future padded variants. + if int(attn.sum().item()) == 0: + skipped_empty += 1 + continue + labels = labels.masked_fill(attn == 0, -100) + get_accelerator().synchronize() + t0 = time.perf_counter() + out = engine(input_ids=ids, labels=labels, attention_mask=attn) + engine.backward(out.loss) + engine.step() + get_accelerator().synchronize() + dt = time.perf_counter() - t0 + + if step >= args.warmup_steps: + step_times.append(dt) + losses.append(out.loss.detach().item()) + + if rank == 0 and step % args.log_interval == 0: + tag = "warmup" if step < args.warmup_steps else "measured" + tps = args.batch_size * args.seq_length * world / dt + print( + f"[trainer] step {step:4d} ({tag:7s}) | loss {out.loss.item():8.4f} | " + f"step {dt*1000:7.1f} ms | {tps:8.0f} tok/s", + flush=True) + step += 1 + + t_train_end = time.perf_counter() + + if rank == 0: + n = len(step_times) + avg_dt = sum(step_times) / n + tokens_per_step = args.batch_size * args.seq_length * world + tps = tokens_per_step / avg_dt + peak_gb = get_accelerator().max_memory_allocated() / 1e9 + avg_loss = sum(losses) / n + print() + print("=" * 70) + print("Qwen3 ZeRO-3 benchmark complete") + print(f" measured steps : {n} (warmup={args.warmup_steps} skipped)") + print(f" total wall (s) : {t_train_end - t_train_start:.1f}") + print(f" avg step (ms) : {avg_dt * 1000:.1f}") + print(f" tokens/sec (ws) : {tps:.1f}") + print(f" peak mem (GB,r0) : {peak_gb:.2f}") + print(f" avg loss : {avg_loss:.4f}") + print(f" final loss : {losses[-1]:.4f}") + print(f" empty batches : {skipped_empty}") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/sdma_allgather/train_zero3.py b/examples/sdma_allgather/train_zero3.py new file mode 100644 index 000000000000..69440c5f43f6 --- /dev/null +++ b/examples/sdma_allgather/train_zero3.py @@ -0,0 +1,335 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +DeepSpeed ZeRO-3 training example with allgather overlap. +Trains a GPT-2-style transformer on synthetic data for demonstration. +Designed for single-node 8x AMD GPU setup. +""" + +import argparse +import math +import os +import time + +import torch +import torch.nn as nn +import deepspeed +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from torch.utils.data import Dataset, DataLoader + + +# --------------------------------------------------------------------------- +# Model: minimal GPT-2-style transformer +# --------------------------------------------------------------------------- +class CausalSelfAttention(nn.Module): + + def __init__(self, hidden_size, num_heads, max_seq_len, dropout=0.1): + super().__init__() + assert hidden_size % num_heads == 0 + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.qkv = nn.Linear(hidden_size, 3 * hidden_size) + self.proj = nn.Linear(hidden_size, hidden_size) + self.attn_drop = nn.Dropout(dropout) + self.proj_drop = nn.Dropout(dropout) + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(max_seq_len, max_seq_len)).view(1, 1, max_seq_len, max_seq_len), + ) + + def forward(self, x): + B, T, C = x.size() + q, k, v = self.qkv(x).split(C, dim=-1) + q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) + + scale = 1.0 / math.sqrt(self.head_dim) + attn = (q @ k.transpose(-2, -1)) * scale + attn = attn.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float("-inf")) + attn = torch.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C) + return self.proj_drop(self.proj(out)) + + +class TransformerBlock(nn.Module): + + def __init__(self, hidden_size, num_heads, max_seq_len, dropout=0.1): + super().__init__() + self.ln1 = nn.LayerNorm(hidden_size) + self.attn = CausalSelfAttention(hidden_size, num_heads, max_seq_len, dropout) + self.ln2 = nn.LayerNorm(hidden_size) + self.mlp = nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.GELU(), + nn.Linear(4 * hidden_size, hidden_size), + nn.Dropout(dropout), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + +class GPT2Model(nn.Module): + + def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len, dropout=0.1): + super().__init__() + self.tok_emb = nn.Embedding(vocab_size, hidden_size) + self.pos_emb = nn.Embedding(max_seq_len, hidden_size) + self.drop = nn.Dropout(dropout) + self.blocks = nn.Sequential( + *[TransformerBlock(hidden_size, num_heads, max_seq_len, dropout) for _ in range(num_layers)]) + self.ln_f = nn.LayerNorm(hidden_size) + self.head = nn.Linear(hidden_size, vocab_size, bias=False) + + def forward(self, input_ids, labels=None): + B, T = input_ids.size() + pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0) + x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos)) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + loss = None + if labels is not None: + loss = nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + ) + return loss, logits + + +# --------------------------------------------------------------------------- +# Synthetic dataset +# --------------------------------------------------------------------------- +class SyntheticTextDataset(Dataset): + """Generates synthetic token sequences for perf/correctness testing.""" + + def __init__(self, vocab_size, seq_len, num_samples, seed=42, mode="random"): + self.vocab_size = vocab_size + self.seq_len = seq_len + self.num_samples = num_samples + self.seed = seed + self.mode = mode + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + if self.mode == "random": + g = torch.Generator() + g.manual_seed(self.seed + idx) + tokens = torch.randint(0, self.vocab_size, (self.seq_len + 1, ), generator=g) + elif self.mode == "arange": + start = (self.seed + idx) % self.vocab_size + tokens = (torch.arange(self.seq_len + 1, dtype=torch.long) + start) % self.vocab_size + elif self.mode == "repeat": + v = (self.seed + idx) % self.vocab_size + tokens = torch.full((self.seq_len + 1, ), v, dtype=torch.long) + else: + raise ValueError(f"Unsupported data mode: {self.mode}") + return tokens[:-1], tokens[1:] + + +class WikitextDataset(Dataset): + """Real text dataset from HuggingFace wikitext-2 / wikitext-103.""" + + def __init__(self, vocab_size, seq_len, num_samples, split="train", dataset_name="wikitext-2-raw-v1"): + from datasets import load_dataset + from transformers import GPT2TokenizerFast + + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + raw = load_dataset("wikitext", dataset_name, split=split) + text = "\n\n".join([t for t in raw["text"] if t.strip()]) + all_ids = tokenizer.encode(text) + + self.seq_len = seq_len + self.samples = [] + for i in range(0, len(all_ids) - seq_len - 1, seq_len): + self.samples.append(torch.tensor(all_ids[i:i + seq_len + 1], dtype=torch.long)) + if len(self.samples) >= num_samples: + break + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + tokens = self.samples[idx] + return tokens[:-1], tokens[1:] + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def parse_args(): + parser = argparse.ArgumentParser(description="DeepSpeed ZeRO-3 training with allgather overlap") + parser.add_argument("--vocab_size", type=int, default=50257) + parser.add_argument("--hidden_size", type=int, default=4096) + parser.add_argument("--num_layers", type=int, default=48) + parser.add_argument("--num_heads", type=int, default=32) + parser.add_argument("--max_seq_len", type=int, default=2048) + parser.add_argument("--dropout", type=float, default=0.1) + parser.add_argument("--num_samples", type=int, default=10000) + parser.add_argument("--train_steps", type=int, default=2000) + parser.add_argument("--data_mode", + type=str, + default="random", + choices=["random", "arange", "repeat", "wikitext2", "wikitext103"], + help="Data mode. random/arange/repeat are synthetic; wikitext2/wikitext103 use real text.") + parser.add_argument("--local_rank", type=int, default=-1) + parser = deepspeed.add_config_arguments(parser) + return parser.parse_args() + + +def main(): + args = parse_args() + + ds_config_path = args.deepspeed_config + if ds_config_path and not os.path.isfile(ds_config_path): + script_dir = os.path.dirname(os.path.abspath(__file__)) + ds_config_path = os.path.join(script_dir, ds_config_path) + args.deepspeed_config = ds_config_path + + deepspeed.init_distributed(dist_backend="cpu:gloo,cuda:nccl") + local_rank = args.local_rank + get_accelerator().set_device(local_rank) + + torch.manual_seed(42) + get_accelerator().manual_seed_all(42) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + with deepspeed.zero.Init(config_dict_or_path=ds_config_path): + model = GPT2Model( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_layers=args.num_layers, + num_heads=args.num_heads, + max_seq_len=args.max_seq_len, + dropout=0.0, + ) + + total_params = sum(p.numel() for p in model.parameters()) + num_gpus = dist.get_world_size() + if local_rank == 0: + print(f"Model parameters: {total_params / 1e6:.1f}M") + print(f"GPUs: {num_gpus}") + + # FLOPs per token (forward + backward): 6*params + 12*L*H*S + # Reference: "Efficient Large-Scale Language Model Training on GPU Clusters + # Using Megatron-LM" (Narayanan et al., 2021) + flops_per_token = 6 * total_params + 12 * args.num_layers * args.hidden_size * args.max_seq_len + + if args.data_mode in ("wikitext2", "wikitext103"): + ds_name = "wikitext-2-raw-v1" if args.data_mode == "wikitext2" else "wikitext-103-raw-v1" + dataset = WikitextDataset(args.vocab_size, args.max_seq_len, args.num_samples, dataset_name=ds_name) + else: + dataset = SyntheticTextDataset(args.vocab_size, args.max_seq_len, args.num_samples, mode=args.data_mode) + if local_rank == 0: + if args.data_mode == "random": + print(f"Data mode: random (expected CE floor ~ ln(vocab) = {math.log(args.vocab_size):.4f})") + elif args.data_mode in ("wikitext2", "wikitext103"): + print(f"Data mode: {args.data_mode} (real text, {len(dataset)} samples)") + else: + print(f"Data mode: {args.data_mode} (learnable pattern, loss should decrease)") + + model_engine, optimizer, _, lr_scheduler = deepspeed.initialize( + args=args, + model=model, + ) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + shuffle=False, + seed=42, + ) + train_loader = DataLoader( + dataset, + batch_size=model_engine.train_micro_batch_size_per_gpu(), + sampler=sampler, + num_workers=0, + pin_memory=True, + ) + + device = model_engine.device + global_batch = model_engine.train_batch_size() + tokens_per_step = global_batch * args.max_seq_len + warmup_steps = min(50, args.train_steps // 10) + + step = 0 + step_times = [] + t_start = time.time() + t_steady = None + while step < args.train_steps: + for batch in train_loader: + if step >= args.train_steps: + break + + get_accelerator().synchronize() + t_step_start = time.time() + + input_ids = batch[0].to(device) + labels = batch[1].to(device) + + loss, _ = model_engine(input_ids, labels=labels) + model_engine.backward(loss) + model_engine.step() + + get_accelerator().synchronize() + step_time_ms = (time.time() - t_step_start) * 1000 + + if step == warmup_steps: + t_steady = time.time() + if step >= warmup_steps: + step_times.append(step_time_ms) + + if step % 10 == 0 and local_rank == 0: + if step_times: + import numpy as np + recent = np.array(step_times[-20:]) + avg_ms = recent.mean() + cur_samples_per_sec = global_batch / (avg_ms / 1000) + cur_tokens_per_sec = cur_samples_per_sec * args.max_seq_len + cur_tflops_per_gpu = cur_tokens_per_sec * flops_per_token / 1e12 / num_gpus + else: + avg_ms = step_time_ms + cur_tflops_per_gpu = 0.0 + cur_samples_per_sec = 0.0 + print(f"step {step:5d} | loss {loss.item():.4f} | " + f"lr {lr_scheduler.get_last_lr()[0]:.6f} | " + f"{cur_samples_per_sec:.1f} samples/s | " + f"{cur_tflops_per_gpu:.2f} TFLOPS/GPU | " + f"step {avg_ms:.1f} ms") + step += 1 + + if local_rank == 0: + import numpy as np + total_time = time.time() - t_start + st = np.array(step_times) + steady_steps = len(st) + steady_time = time.time() - t_steady if t_steady else total_time + + steady_samples_per_sec = steady_steps * global_batch / steady_time + steady_tokens_per_sec = steady_samples_per_sec * args.max_seq_len + steady_tflops = steady_tokens_per_sec * flops_per_token / 1e12 + steady_tflops_per_gpu = steady_tflops / num_gpus + + print(f"\n{'=' * 70}") + print(f"Training complete: {args.train_steps} steps in {total_time:.1f}s") + print(f" (warmup={warmup_steps} steps skipped, measured {steady_steps} steps)") + print(f"{'=' * 70}") + print(f" Throughput : {steady_samples_per_sec:.1f} samples/s") + print(f" TFLOPS : {steady_tflops:.1f} (total) | {steady_tflops_per_gpu:.2f} (per GPU)") + print(f" Step time (ms) : avg {st.mean():.1f} | p50 {np.median(st):.1f} | " + f"p99 {np.percentile(st, 99):.1f} | min {st.min():.1f} | max {st.max():.1f}") + print(f"{'=' * 70}") + + +if __name__ == "__main__": + main() diff --git a/install.sh b/install.sh index 6770924d1ef8..8be574c6ec1f 100755 --- a/install.sh +++ b/install.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e err_report() { @@ -121,7 +121,7 @@ rm_if_exist() { if [ -f $1 ]; then rm $VERBOSE $1 elif [ -d $1 ]; then - rm -r $VERBOSE $1 + rm -rf $VERBOSE $1 fi } @@ -152,7 +152,7 @@ if [ ! -f $hostfile ]; then fi echo "Building deepspeed wheel" -python setup.py $VERBOSE bdist_wheel +python -m build $VERBOSE --wheel --no-isolation if [ "$local_only" == "1" ]; then echo "Installing deepspeed" diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 2c55662df8ce..afe48159933c 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -15,7 +15,7 @@ # List of all available op builders from deepspeed op_builder try: - import deepspeed.ops.op_builder # noqa: F401 + import deepspeed.ops.op_builder # noqa: F401 # type: ignore op_builder_dir = "deepspeed.ops.op_builder" except ImportError: op_builder_dir = "op_builder" diff --git a/op_builder/all_ops.py b/op_builder/all_ops.py index 9c41f35eaf1b..ff11ca180072 100644 --- a/op_builder/all_ops.py +++ b/op_builder/all_ops.py @@ -30,3 +30,4 @@ __op_builders__.append(builder) ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} +accelerator_name = get_accelerator()._name diff --git a/op_builder/async_io.py b/op_builder/async_io.py index 084cb10864cf..f59cc6810c6f 100644 --- a/op_builder/async_io.py +++ b/op_builder/async_io.py @@ -3,13 +3,14 @@ # DeepSpeed Team -import distutils.spawn +import os +import shutil import subprocess -from .builder import OpBuilder +from .builder import TorchCPUOpBuilder -class AsyncIOBuilder(OpBuilder): +class AsyncIOBuilder(TorchCPUOpBuilder): BUILD_VAR = "DS_BUILD_AIO" NAME = "async_io" @@ -19,38 +20,57 @@ def __init__(self): def absolute_name(self): return f'deepspeed.ops.aio.{self.NAME}_op' - def sources(self): - return [ - 'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/py_ds_aio.cpp', - 'csrc/aio/py_lib/deepspeed_py_aio.cpp', 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', - 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', 'csrc/aio/common/deepspeed_aio_utils.cpp', - 'csrc/aio/common/deepspeed_aio_common.cpp', 'csrc/aio/common/deepspeed_aio_types.cpp', + def lib_sources(self): + src_list = [ + 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', 'csrc/aio/py_lib/deepspeed_py_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', + 'csrc/aio/common/deepspeed_aio_utils.cpp', 'csrc/aio/common/deepspeed_aio_common.cpp', + 'csrc/aio/common/deepspeed_aio_types.cpp', 'csrc/aio/py_lib/deepspeed_cpu_op.cpp', + 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', 'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/deepspeed_pin_tensor.cpp' ] + return src_list + + def sources(self): + return self.lib_sources() + ['csrc/aio/py_lib/py_ds_aio.cpp'] def include_paths(self): - return ['csrc/aio/py_lib', 'csrc/aio/common'] + import torch + if self.build_for_cpu: + CUDA_INCLUDE = [] + elif not self.is_rocm_pytorch(): + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + else: + CUDA_INCLUDE = [ + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), + ] + return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE def cxx_args(self): # -O0 for improved debugging, since performance is bound by I/O - CPU_ARCH = self.cpu_arch() - SIMD_WIDTH = self.simd_width() - return [ - '-g', - '-Wall', - '-O0', - '-std=c++14', - '-shared', - '-fPIC', - '-Wno-reorder', - CPU_ARCH, - '-fopenmp', - SIMD_WIDTH, - '-laio', - ] + args = super().cxx_args() + import torch + TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) + if not (TORCH_MAJOR >= 2 and TORCH_MINOR >= 1): + args.remove('-std=c++17') + args.append('-std=c++14') + args += ['-Wall', '-O0', '-shared', '-fPIC', '-Wno-reorder'] + return args def extra_ldflags(self): - return ['-laio'] + if self.build_for_cpu: + return ['-fopenmp'] + + import torch.utils.cpp_extension + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME + if CUDA_HOME is None: + ldflags = ['-laio'] # the ROCM case + else: + CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") + ldflags = [f'-L{CUDA_HOME}', f'-L{CUDA_LIB64}', '-laio', '-lcuda', '-lcudart'] + return ldflags def check_for_libaio_pkg(self): libs = dict( @@ -62,10 +82,10 @@ def check_for_libaio_pkg(self): found = False for pkgmgr, data in libs.items(): flag, lib, tool = data - path = distutils.spawn.find_executable(pkgmgr) + path = shutil.which(pkgmgr) if path is not None: - cmd = f"{pkgmgr} {flag} {lib}" - result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + cmd = [pkgmgr, flag, lib] + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if result.wait() == 0: found = True else: @@ -73,7 +93,7 @@ def check_for_libaio_pkg(self): break return found - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): # Check for the existence of libaio by using distutils # to compile and link a test program that calls io_submit, # which is a function provided by libaio that is used in the async_io op. diff --git a/op_builder/builder.py b/op_builder/builder.py index 44d6a440c056..0fda11caeca1 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import re import sys import time import importlib @@ -35,13 +36,19 @@ TORCH_MINOR = int(torch.__version__.split('.')[1]) +class MissingCUDAException(Exception): + pass + + +class CUDAMismatchException(Exception): + pass + + def installed_cuda_version(name=""): - import torch.cuda - if not torch.cuda.is_available(): - return 0, 0 import torch.utils.cpp_extension cuda_home = torch.utils.cpp_extension.CUDA_HOME - assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" + if cuda_home is None: + raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)") # Ensure there is not a cuda version mismatch between torch and nvcc compiler output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) output_split = output.split() @@ -54,32 +61,34 @@ def installed_cuda_version(name=""): def get_default_compute_capabilities(): compute_caps = DEFAULT_COMPUTE_CAPABILITIES + # Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported import torch.utils.cpp_extension - if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11: - if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0: - # Special treatment of CUDA 11.0 because compute_86 is not supported. - compute_caps += ";8.0" - else: - compute_caps += ";8.0;8.6" + if torch.utils.cpp_extension.CUDA_HOME is not None: + if installed_cuda_version()[0] == 11: + if installed_cuda_version()[1] >= 0: + compute_caps += ";8.0" + if installed_cuda_version()[1] >= 1: + compute_caps += ";8.6" + if installed_cuda_version()[1] >= 8: + compute_caps += ";9.0" + elif installed_cuda_version()[0] == 12: + compute_caps += ";8.0;8.6;9.0" + if installed_cuda_version()[1] >= 8: + compute_caps += ";10.0;12.0" return compute_caps # list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used # to build deepspeed and system-wide installed cuda 11.2 cuda_minor_mismatch_ok = { - 10: [ - "10.0", - "10.1", - "10.2", - ], + 10: ["10.0", "10.1", "10.2"], 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], + 12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6", "12.8", "12.9"], # There is no CUDATk 12.7 } def assert_no_cuda_mismatch(name=""): cuda_major, cuda_minor = installed_cuda_version(name) - if cuda_minor == 0 and cuda_major == 0: - return False sys_cuda_version = f'{cuda_major}.{cuda_minor}' torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) # This is a show-stopping error, should probably not proceed past this @@ -90,20 +99,33 @@ def assert_no_cuda_mismatch(name=""): f"version torch was compiled with {torch.version.cuda} " "but since the APIs are compatible, accepting this combination") return True - raise Exception(f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " - f"version torch was compiled with {torch.version.cuda}, unable to compile " - "cuda/cpp extensions without a matching cuda version.") + elif os.getenv("DS_SKIP_CUDA_CHECK", "0") == "1": + print( + f"{WARNING} DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}." + "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior." + ) + return True + raise CUDAMismatchException( + f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") return True class OpBuilder(ABC): _rocm_version = None + _rocm_gpu_arch = None + _rocm_wavefront_size = None _is_rocm_pytorch = None + _is_sycl_enabled = None + _loaded_ops = {} def __init__(self, name): self.name = name self.jit_mode = False self.build_for_cpu = False + self.enable_bf16 = False self.error_log = None @abstractmethod @@ -124,6 +146,9 @@ def sources(self): def hipify_extension(self): pass + def sycl_extension(self): + pass + @staticmethod def validate_torch_version(torch_info): install_torch_version = torch_info['version'] @@ -175,6 +200,22 @@ def is_rocm_pytorch(): OpBuilder._is_rocm_pytorch = _is_rocm_pytorch return OpBuilder._is_rocm_pytorch + @staticmethod + def is_sycl_enabled(): + if OpBuilder._is_sycl_enabled is not None: + return OpBuilder._is_sycl_enabled + + _is_sycl_enabled = False + try: + result = subprocess.run(["c2s", "--version"], capture_output=True) + except Exception: + pass + else: + _is_sycl_enabled = True + + OpBuilder._is_sycl_enabled = _is_sycl_enabled + return OpBuilder._is_sycl_enabled + @staticmethod def installed_rocm_version(): if OpBuilder._rocm_version: @@ -182,22 +223,69 @@ def installed_rocm_version(): ROCM_MAJOR = '0' ROCM_MINOR = '0' + ROCM_VERSION_DEV_RAW = "" if OpBuilder.is_rocm_pytorch(): from torch.utils.cpp_extension import ROCM_HOME - rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev") + rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version") if rocm_ver_file.is_file(): with open(rocm_ver_file, 'r') as file: ROCM_VERSION_DEV_RAW = file.read() elif "rocm" in torch.__version__: ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] + if ROCM_VERSION_DEV_RAW != "": + ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] + ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] else: + # Look in /usr/include/rocm-version.h + rocm_ver_file = Path("/usr/include/rocm_version.h") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + for ln in file.readlines(): + if "#define ROCM_VERSION_MAJOR" in ln: + ROCM_MAJOR = re.findall(r'\S+', ln)[2] + elif "#define ROCM_VERSION_MINOR" in ln: + ROCM_MINOR = re.findall(r'\S+', ln)[2] + if ROCM_MAJOR == '0': assert False, "Could not detect ROCm version" - assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version" - ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] - ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] + OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) return OpBuilder._rocm_version + @staticmethod + def get_rocm_gpu_arch(): + if OpBuilder._rocm_gpu_arch: + return OpBuilder._rocm_gpu_arch + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + try: + result = subprocess.check_output([str(rocm_info)], stderr=subprocess.DEVNULL) + output = result.decode('utf-8') + match = re.search(r'gfx\S+', output) + rocm_gpu_arch = match.group(0).strip() if match else "" + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + rocm_gpu_arch = "" + OpBuilder._rocm_gpu_arch = rocm_gpu_arch + return OpBuilder._rocm_gpu_arch + + @staticmethod + def get_rocm_wavefront_size(): + if OpBuilder._rocm_wavefront_size: + return OpBuilder._rocm_wavefront_size + + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + try: + result = subprocess.check_output([str(rocm_info)], stderr=subprocess.DEVNULL) + output = result.decode('utf-8') + match = re.search(r'Wavefront Size:\s+(\d+)', output) + rocm_wavefront_size = match.group(1) if match else "32" + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + rocm_wavefront_size = "32" + OpBuilder._rocm_wavefront_size = rocm_wavefront_size + return OpBuilder._rocm_wavefront_size + def include_paths(self): ''' Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) @@ -216,7 +304,7 @@ def cxx_args(self): ''' return [] - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): ''' Check if all non-python dependencies are satisfied to build this op ''' @@ -225,15 +313,7 @@ def is_compatible(self, verbose=True): def extra_ldflags(self): return [] - def libraries_installed(self, libraries): - valid = False - check_cmd = 'dpkg -l' - for lib in libraries: - result = subprocess.Popen(f'dpkg -l {lib}', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) - valid = valid or result.wait() == 0 - return valid - - def has_function(self, funcname, libraries, verbose=False): + def has_function(self, funcname, libraries, library_dirs=None, verbose=False): ''' Test for existence of a function within a tuple of libraries. @@ -289,7 +369,8 @@ def has_function(self, funcname, libraries, verbose=False): compiler.link_executable(objs, os.path.join(tempdir, 'a.out'), extra_preargs=self.strip_empty_entries(ldflags), - libraries=libraries) + libraries=libraries, + library_dirs=library_dirs) # Compile and link succeeded return True @@ -300,7 +381,7 @@ def has_function(self, funcname, libraries, verbose=False): except LinkError: return False - except: + except Exception: return False finally: @@ -331,8 +412,8 @@ def cpu_arch(self): try: cpu_info = get_cpu_info() except Exception as e: - self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " - "falling back to `lscpu` to get this information.") + self.warning(f"{self.name} attempted to use py-cpuinfo but failed (exception type: {type(e)}, {e}), " + "falling back to lscpu to get this information.") cpu_info = self._backup_cpuinfo() if cpu_info is None: return "-march=native" @@ -340,15 +421,18 @@ def cpu_arch(self): if cpu_info['arch'].startswith('PPC_'): # gcc does not provide -march on PowerPC, use -mcpu instead return '-mcpu=native' + elif cpu_info['arch'].startswith('riscv64'): + return '-march=rv64gc' return '-march=native' - def is_cuda_enable(self): + def get_cuda_compile_flag(self): try: - if torch.cuda.is_available(): - return '-D__ENABLE_CUDA__' - except: - print(f"{WARNING} {self.name} torch.cuda is missing, only cpu ops can be compiled!") - return '-D__DISABLE_CUDA__' + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + return "-D__ENABLE_CUDA__" + except MissingCUDAException: + print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " + "only cpu ops can be compiled!") return '-D__DISABLE_CUDA__' def _backup_cpuinfo(self): @@ -358,7 +442,7 @@ def _backup_cpuinfo(self): "to detect the CPU architecture. 'lscpu' does not appear to exist on " "your system, will fall back to use -march=native and non-vectorized execution.") return None - result = subprocess.check_output('lscpu', shell=True) + result = subprocess.check_output(['lscpu']) result = result.decode('utf-8').strip().lower() cpu_info = {} @@ -374,6 +458,8 @@ def _backup_cpuinfo(self): cpu_info['flags'] += 'avx2' elif 'ppc64le' in result: cpu_info['arch'] = "PPC_" + elif 'riscv64' in result: + cpu_info['arch'] = "riscv64" return cpu_info @@ -388,8 +474,8 @@ def simd_width(self): try: cpu_info = get_cpu_info() except Exception as e: - self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " - "falling back to `lscpu` to get this information.") + self.warning(f"{self.name} attempted to use py-cpuinfo but failed (exception type: {type(e)}, {e}), " + "falling back to lscpu to get this information.") cpu_info = self._backup_cpuinfo() if cpu_info is None: return '-D__SCALAR__' @@ -408,7 +494,8 @@ def command_exists(self, cmd): cmds = [cmd] valid = False for cmd in cmds: - result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + safe_cmd = ["bash", "-c", f"type {cmd}"] + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) valid = valid or result.wait() == 0 if not valid and len(cmds) > 1: @@ -429,22 +516,29 @@ def deepspeed_src_path(self, code_path): def builder(self): from torch.utils.cpp_extension import CppExtension + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] return CppExtension(name=self.absolute_name(), sources=self.strip_empty_entries(self.sources()), - include_dirs=self.strip_empty_entries(self.include_paths()), + include_dirs=include_dirs, extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())}, extra_link_args=self.strip_empty_entries(self.extra_ldflags())) - def load(self, verbose=True): - from deepspeed.git_version_info import installed_ops, torch_info - if installed_ops[self.name]: + def load(self, verbose=False): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name + from deepspeed.accelerator import get_accelerator + if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name: # Ensure the op we're about to load was compiled with the same # torch/cuda versions we are currently using at runtime. self.validate_torch_version(torch_info) if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder): self.validate_torch_op_version(torch_info) - return importlib.import_module(self.absolute_name()) + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module else: return self.jit_load(verbose) @@ -453,47 +547,69 @@ def jit_load(self, verbose=True): raise RuntimeError( f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" ) + from torch.utils.cpp_extension import verify_ninja_availability try: - import ninja # noqa: F401 - except ImportError: - raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") + verify_ninja_availability() + except RuntimeError as e: + raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") from e if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): - self.build_for_cpu = not assert_no_cuda_mismatch(self.name) + self.build_for_cpu = not torch.cuda.is_available() + saved_jit_mode = self.jit_mode self.jit_mode = True + torch_arch_list_present = "TORCH_CUDA_ARCH_LIST" in os.environ + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") + normalized_arch_list = torch_arch_list.strip() if torch_arch_list is not None else None + self._jit_arch_list = normalized_arch_list or None from torch.utils.cpp_extension import load start_build = time.time() - sources = [self.deepspeed_src_path(path) for path in self.sources()] - extra_include_paths = [self.deepspeed_src_path(path) for path in self.include_paths()] + sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()] + extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()] - # Torch will try and apply whatever CCs are in the arch list at compile time, - # we have already set the intended targets ourselves we know that will be - # needed at runtime. This prevents CC collisions such as multiple __half - # implementations. Stash arch list to reset after build. - torch_arch_list = None - if "TORCH_CUDA_ARCH_LIST" in os.environ: - torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") - os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - op_module = load(name=self.name, - sources=self.strip_empty_entries(sources), - extra_include_paths=self.strip_empty_entries(extra_include_paths), - extra_cflags=self.strip_empty_entries(self.cxx_args()), - extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()), - extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), - verbose=verbose) - - build_duration = time.time() - start_build - if verbose: - print(f"Time to load {self.name} op: {build_duration} seconds") + try: + nvcc_args = self.strip_empty_entries(self.nvcc_args()) + cxx_args = self.strip_empty_entries(self.cxx_args()) + + cxx_args.append("-UC10_USE_GLOG") + nvcc_args.append("-UC10_USE_GLOG") + if isinstance(self, CUDAOpBuilder): + if not self.build_for_cpu and self.enable_bf16: + cxx_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + + if self.is_rocm_pytorch(): + cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=cxx_args, + extra_cuda_cflags=nvcc_args, + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None, + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") - # Reset arch list so we are not silently removing it for other possible use cases - if torch_arch_list: - os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + __class__._loaded_ops[self.name] = op_module - return op_module + return op_module + finally: + if torch_arch_list_present: + os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + else: + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + self._jit_arch_list = None + self.jit_mode = saved_jit_mode class CUDAOpBuilder(OpBuilder): @@ -502,30 +618,57 @@ def compute_capability_args(self, cross_compile_archs=None): """ Returns nvcc compute capability compile flags. - 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. - 2. If neither is set default compute capabilities will be used - 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX + 1. Under ``jit_mode``, the precedence is: + a. preserved ``TORCH_CUDA_ARCH_LIST`` captured by ``jit_load()`` + b. live ``TORCH_CUDA_ARCH_LIST`` from the environment + c. runtime device probing when the process is not in a bad-fork context + d. an error when no explicit arch list exists in a bad-fork context + + JIT mode auto-adds ``+PTX`` to the highest compute capability when + no entry already carries it, then sets ``TORCH_CUDA_ARCH_LIST`` so + PyTorch can generate the ``-gencode`` flags itself. + 2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``. + 3. If neither is set default compute capabilities will be used. Format: - - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: + - ``TORCH_CUDA_ARCH_LIST`` may use ; or whitespace separators. Examples: - TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... - TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... + TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ... - - `cross_compile_archs` uses ; separator. + - ``cross_compile_archs`` uses ; separator. """ ccs = [] if self.jit_mode: - # Compile for underlying architectures since we know those at runtime - for i in range(torch.cuda.device_count()): - CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) - cc = f"{CC_MAJOR}.{CC_MINOR}" - if cc not in ccs: - ccs.append(cc) - ccs = sorted(ccs) - ccs[-1] += '+PTX' + arch_string = getattr(self, '_jit_arch_list', None) + if arch_string: + arch_string = arch_string.replace(' ', ';') + ccs = [cc.strip() for cc in arch_string.split(';') if cc.strip()] + else: + arch_string = os.environ.get('TORCH_CUDA_ARCH_LIST', '').strip() + if arch_string: + arch_string = arch_string.replace(' ', ';') + ccs = [cc.strip() for cc in arch_string.split(';') if cc.strip()] + else: + if hasattr(torch.cuda, '_is_in_bad_fork') and torch.cuda._is_in_bad_fork(): + raise RuntimeError( + f"DeepSpeed JIT builder for '{self.name}' cannot probe CUDA device capabilities " + "in a forked subprocess where CUDA has already been initialized. Set " + "TORCH_CUDA_ARCH_LIST to specify target architectures explicitly.") + for i in range(torch.cuda.device_count()): + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) + cc = f"{CC_MAJOR}.{CC_MINOR}" + if cc not in ccs: + ccs.append(cc) + if len(ccs) == 0: + raise RuntimeError(f"DeepSpeed JIT builder for '{self.name}' found no CUDA devices. Set " + "TORCH_CUDA_ARCH_LIST or make GPUs visible.") + + ccs = sorted(ccs, key=lambda cc: tuple(int(part.split('+')[0]) for part in cc.split('.'))) + if not any('+PTX' in cc for cc in ccs): + ccs[-1] += '+PTX' else: # Cross-compile mode, compile for various architectures # env override takes priority @@ -533,7 +676,7 @@ def compute_capability_args(self, cross_compile_archs=None): if cross_compile_archs_env is not None: if cross_compile_archs is not None: print( - f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`" + f"{WARNING} env var TORCH_CUDA_ARCH_LIST={cross_compile_archs_env} overrides cross_compile_archs={cross_compile_archs}" ) cross_compile_archs = cross_compile_archs_env.replace(' ', ';') else: @@ -546,11 +689,47 @@ def compute_capability_args(self, cross_compile_archs=None): raise RuntimeError( f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") - args = [] + # Canonicalize by numeric (major, minor) so the emitted -gencode + # sequence matches PyTorch's own dedupe (see #7871). For mixed inputs + # such as "8.0;8.0+PTX" or "8.0+PTX;8.0", PyTorch collapses to one + # sm_80 entry plus one compute_80 PTX entry. Track has_PTX per arch + # as the OR across all variants of that arch so any +PTX appearance + # carries through after dedupe. + canonical = {} for cc in ccs: - num = cc[0] + cc[2] + major = int(cc[0]) + minor_part = cc[1] + has_ptx = minor_part.endswith('+PTX') + minor = int(minor_part.split('+')[0]) + key = (major, minor) + canonical[key] = canonical.get(key, False) or has_ptx + canonical_archs = sorted(canonical.items()) + + self.enable_bf16 = True + for (major, _minor), _has_ptx in canonical_archs: + if major <= 7: + self.enable_bf16 = False + + # Keep TORCH_CUDA_ARCH_LIST in sync with the filtered arch list so + # PyTorch does not re-add archs that filter_ccs() removed. Emit one + # token per arch using the X.Y or X.Y+PTX form, matching PyTorch's + # canonical parsing where +PTX on an arch token already implies both + # the sm and PTX emissions for that arch. + arch_tokens = [f"{major}.{minor}{'+PTX' if has_ptx else ''}" for (major, minor), has_ptx in canonical_archs] + os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(arch_tokens) + + if self.jit_mode: + # Let PyTorch generate -gencode flags from the env var. + return [] + + # Non-JIT: return explicit flags per builder for extra_compile_args. + # Emit exactly one sm_X line per arch, followed by one compute_X PTX + # line when any variant of that arch carried +PTX. + args = [] + for (major, minor), has_ptx in canonical_archs: + num = f"{major}{minor}" args.append(f'-gencode=arch=compute_{num},code=sm_{num}') - if cc.endswith('+PTX'): + if has_ptx: args.append(f'-gencode=arch=compute_{num},code=compute_{num}') return args @@ -560,7 +739,7 @@ def filter_ccs(self, ccs: List[str]): Prune any compute capabilities that are not compatible with the builder. Should log which CCs have been pruned. """ - return ccs + return [cc.split('.') for cc in ccs] def version_dependent_macros(self): # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 @@ -575,25 +754,45 @@ def version_dependent_macros(self): version_ge_1_5 = ['-DVERSION_GE_1_5'] return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): return super().is_compatible(verbose) def builder(self): - self.build_for_cpu = not assert_no_cuda_mismatch(self.name) + try: + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except MissingCUDAException: + self.build_for_cpu = True + if self.build_for_cpu: from torch.utils.cpp_extension import CppExtension as ExtensionBuilder else: from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder - + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ {'cxx': self.strip_empty_entries(self.cxx_args()), \ - 'nvcc': self.strip_empty_entries(self.nvcc_args())} + 'nvcc': self.strip_empty_entries(self.nvcc_args())} + + if not self.build_for_cpu and self.enable_bf16: + compile_args['cxx'].append("-DBF16_AVAILABLE") + compile_args['nvcc'].append("-DBF16_AVAILABLE") + + if self.is_rocm_pytorch(): + compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1") + #cxx compiler args are required to compile cpp files + compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + #nvcc compiler args are required to compile hip files + compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + if self.get_rocm_gpu_arch(): + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() cuda_ext = ExtensionBuilder(name=self.absolute_name(), sources=self.strip_empty_entries(self.sources()), - include_dirs=self.strip_empty_entries(self.include_paths()), + include_dirs=include_dirs, libraries=self.strip_empty_entries(self.libraries_args()), - extra_compile_args=compile_args) + extra_compile_args=compile_args, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) if self.is_rocm_pytorch(): # hip converts paths to absolute, this converts back to relative @@ -626,7 +825,7 @@ def cxx_args(self): if sys.platform == "win32": return ['-O2'] else: - return ['-O3', '-std=c++14', '-g', '-Wno-reorder'] + return ['-O3', '-std=c++17', '-g', '-Wno-reorder'] def nvcc_args(self): if self.build_for_cpu: @@ -635,17 +834,32 @@ def nvcc_args(self): if self.is_rocm_pytorch(): ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() args += [ - '-std=c++14', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', + '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', '-U__HIP_NO_HALF2_OPERATORS__', '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR ] + self.enable_bf16 = True else: - cuda_major, _ = installed_cuda_version() + try: + nvcc_threads = int(os.getenv("DS_NVCC_THREADS", "")) + if nvcc_threads <= 0: + raise ValueError("") + except ValueError: + nvcc_threads = min(os.cpu_count(), 8) + + cuda_major, cuda_minor = installed_cuda_version() + if cuda_major > 10: + if cuda_major == 12 and cuda_minor >= 5: + std_lib = '-std=c++20' + else: + std_lib = '-std=c++17' + else: + std_lib = '-std=c++14' args += [ - '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', - '-std=c++17' if sys.platform == "win32" and cuda_major > 10 else '-std=c++14', - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__' + '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', std_lib, + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + f'--threads={nvcc_threads}' ] if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': args.append('--ptxas-options=-v') @@ -664,23 +878,32 @@ def libraries_args(self): class TorchCPUOpBuilder(CUDAOpBuilder): + def get_cuda_lib64_path(self): + import torch + if not self.is_rocm_pytorch(): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + if not os.path.exists(CUDA_LIB64): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib") + else: + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + return CUDA_LIB64 + def extra_ldflags(self): if self.build_for_cpu: return ['-fopenmp'] if not self.is_rocm_pytorch(): - return ['-lcurand'] + ld_flags = ['-lcurand'] + if not self.build_for_cpu: + ld_flags.append(f'-L{self.get_cuda_lib64_path()}') + return ld_flags return [] def cxx_args(self): - import torch args = [] if not self.build_for_cpu: - if not self.is_rocm_pytorch(): - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") - else: - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + CUDA_LIB64 = self.get_cuda_lib64_path() args += super().cxx_args() args += [ @@ -692,7 +915,7 @@ def cxx_args(self): CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() - CUDA_ENABLE = self.is_cuda_enable() + CUDA_ENABLE = self.get_cuda_compile_flag() args += [ CPU_ARCH, '-fopenmp', diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py new file mode 100644 index 000000000000..7084db8469f1 --- /dev/null +++ b/op_builder/cpu/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +from .comm import CCLCommBuilder, ShareMemCommBuilder +from .fused_adam import FusedAdamBuilder +from .cpu_adam import CPUAdamBuilder +from .no_impl import NotImplementedBuilder +from .async_io import AsyncIOBuilder diff --git a/op_builder/cpu/async_io.py b/op_builder/cpu/async_io.py new file mode 100644 index 000000000000..dcb9feabcfc3 --- /dev/null +++ b/op_builder/cpu/async_io.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import shutil +import subprocess + +from .builder import CPUOpBuilder + + +class AsyncIOBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_AIO" + NAME = "async_io" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.aio.{self.NAME}_op' + + def lib_sources(self): + src_list = [ + 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', 'csrc/aio/py_lib/deepspeed_py_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', + 'csrc/aio/common/deepspeed_aio_utils.cpp', 'csrc/aio/common/deepspeed_aio_common.cpp', + 'csrc/aio/common/deepspeed_aio_types.cpp', 'csrc/aio/py_lib/deepspeed_cpu_op.cpp', + 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', 'csrc/aio/py_lib/deepspeed_py_copy.cpp', + 'csrc/aio/py_lib/deepspeed_pin_tensor.cpp' + ] + return src_list + + def sources(self): + return self.lib_sources() + ['csrc/aio/py_lib/py_ds_aio.cpp'] + + def include_paths(self): + return ['csrc/aio/py_lib', 'csrc/aio/common'] + + def cxx_args(self): + # -O0 for improved debugging, since performance is bound by I/O + args = super().cxx_args() + import torch + TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) + if not (TORCH_MAJOR >= 2 and TORCH_MINOR >= 1): + args.remove('-std=c++17') + args.append('-std=c++14') + args += ['-Wall', '-O0', '-shared', '-fPIC', '-Wno-reorder'] + return args + + def extra_ldflags(self): + return ['-laio', '-fopenmp'] + + def check_for_libaio_pkg(self): + libs = dict( + dpkg=["-l", "libaio-dev", "apt"], + pacman=["-Q", "libaio", "pacman"], + rpm=["-q", "libaio-devel", "yum"], + ) + + found = False + for pkgmgr, data in libs.items(): + flag, lib, tool = data + path = shutil.which(pkgmgr) + if path is not None: + cmd = [pkgmgr, flag, lib] + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.wait() == 0: + found = True + else: + self.warning(f"{self.NAME}: please install the {lib} package with {tool}") + break + return found + + def is_compatible(self, verbose=False): + # Check for the existence of libaio by using distutils + # to compile and link a test program that calls io_submit, + # which is a function provided by libaio that is used in the async_io op. + # If needed, one can define -I and -L entries in CFLAGS and LDFLAGS + # respectively to specify the directories for libaio.h and libaio.so. + aio_compatible = self.has_function('io_submit', ('aio', )) + if verbose and not aio_compatible: + self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.") + + # Check for the libaio package via known package managers + # to print suggestions on which package to install. + self.check_for_libaio_pkg() + + self.warning( + "If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found." + ) + return super().is_compatible(verbose) and aio_compatible diff --git a/op_builder/cpu/builder.py b/op_builder/cpu/builder.py new file mode 100644 index 000000000000..d881842ad0b1 --- /dev/null +++ b/op_builder/cpu/builder.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class CPUOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} + + cpp_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + return cpp_ext + + def cxx_args(self): + args = ['-O3', '-g', '-Wno-reorder'] + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH] + return args + + def libraries_args(self): + return [] diff --git a/op_builder/cpu/comm.py b/op_builder/cpu/comm.py new file mode 100644 index 000000000000..8f5a35d0b664 --- /dev/null +++ b/op_builder/cpu/comm.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from .builder import CPUOpBuilder + + +class CCLCommBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_CCL_COMM" + NAME = "deepspeed_ccl_comm" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def sources(self): + return ['csrc/cpu/comm/ccl.cpp', 'csrc/cpu/comm/shm.cpp'] + + def include_paths(self): + includes = ['csrc/cpu/includes'] + return includes + + def cxx_args(self): + return ['-O2', '-fopenmp'] + + def is_compatible(self, verbose=False): + # TODO: add soft compatibility check for private binary release. + # a soft check, as in we know it can be trivially changed. + return super().is_compatible(verbose) + + def extra_ldflags(self): + ccl_root_path = os.environ.get("CCL_ROOT") + if ccl_root_path is None: + raise ValueError( + "Didn't find CCL_ROOT, install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment variable" + ) + else: + return ['-lccl', f'-L{ccl_root_path}/lib'] + + +class ShareMemCommBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_SHM_COMM" + NAME = "deepspeed_shm_comm" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def sources(self): + return ['csrc/cpu/comm/shm_interface.cpp', 'csrc/cpu/comm/shm.cpp'] + + def include_paths(self): + includes = ['csrc/cpu/includes'] + return includes + + def cxx_args(self): + return ['-O2', '-fopenmp'] + + def is_compatible(self, verbose=False): + # TODO: add soft compatibility check for private binary release. + # a soft check, as in we know it can be trivially changed. + return super().is_compatible(verbose) diff --git a/op_builder/cpu/cpu_adam.py b/op_builder/cpu/cpu_adam.py new file mode 100644 index 000000000000..0c8438aea40d --- /dev/null +++ b/op_builder/cpu/cpu_adam.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class CPUAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/cpu/fused_adam.py b/op_builder/cpu/fused_adam.py new file mode 100644 index 000000000000..34b43825b090 --- /dev/null +++ b/op_builder/cpu/fused_adam.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class FusedAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/cpu/no_impl.py b/op_builder/cpu/no_impl.py new file mode 100644 index 000000000000..69d114a9f1c0 --- /dev/null +++ b/op_builder/cpu/no_impl.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class NotImplementedBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on CPU backend.") + + def sources(self): + return [] diff --git a/op_builder/cpu_adagrad.py b/op_builder/cpu_adagrad.py index 6d70c93faac2..c05f71488950 100644 --- a/op_builder/cpu_adagrad.py +++ b/op_builder/cpu_adagrad.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os from .builder import TorchCPUOpBuilder @@ -18,30 +17,11 @@ def absolute_name(self): return f'deepspeed.ops.adagrad.{self.NAME}_op' def sources(self): - if self.build_for_cpu: - return ['csrc/adagrad/cpu_adagrad.cpp'] - - return ['csrc/adagrad/cpu_adagrad.cpp', 'csrc/common/custom_cuda_kernel.cu'] + return ['csrc/adagrad/cpu_adagrad.cpp'] def libraries_args(self): args = super().libraries_args() - if self.build_for_cpu: - return args - - if not self.is_rocm_pytorch(): - args += ['curand'] return args def include_paths(self): - import torch - if self.build_for_cpu: - CUDA_INCLUDE = [] - elif not self.is_rocm_pytorch(): - CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [ - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), - ] - return ['csrc/includes'] + CUDA_INCLUDE + return ['csrc/includes'] diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 29cdced0d9f2..7f4c0847a8c4 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os from .builder import TorchCPUOpBuilder @@ -18,31 +17,11 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - if self.build_for_cpu: - return ['csrc/adam/cpu_adam.cpp'] - - return ['csrc/adam/cpu_adam.cpp', 'csrc/common/custom_cuda_kernel.cu'] + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] def libraries_args(self): args = super().libraries_args() - if self.build_for_cpu: - return args - - if not self.is_rocm_pytorch(): - args += ['curand'] - return args def include_paths(self): - import torch - if self.build_for_cpu: - CUDA_INCLUDE = [] - elif not self.is_rocm_pytorch(): - CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] - else: - CUDA_INCLUDE = [ - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), - os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), - ] - return ['csrc/includes'] + CUDA_INCLUDE + return ['csrc/includes'] diff --git a/op_builder/cpu_lion.py b/op_builder/cpu_lion.py new file mode 100644 index 000000000000..9a60d99773b3 --- /dev/null +++ b/op_builder/cpu_lion.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import TorchCPUOpBuilder + + +class CPULionBuilder(TorchCPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_LION" + NAME = "cpu_lion" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lion.{self.NAME}_op' + + def sources(self): + return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/dc.py b/op_builder/dc.py new file mode 100644 index 000000000000..15b25bf3393e --- /dev/null +++ b/op_builder/dc.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import TorchCPUOpBuilder + + +class DeepCompileBuilder(TorchCPUOpBuilder): + BUILD_VAR = "DS_BUILD_DEEP_COMPILE" + NAME = "dc" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/compile/deepcompile.cpp', 'csrc/compile/init.cpp', 'csrc/compile/z1.cpp', 'csrc/compile/z2.cpp', + 'csrc/compile/z3.cpp', 'csrc/compile/util.cpp' + ] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + import os + import torch + if self.build_for_cpu: + CUDA_INCLUDE = [] + elif not self.is_rocm_pytorch(): + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + else: + CUDA_INCLUDE = [ + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), + ] + return ['csrc/includes', 'csrc/compile'] + CUDA_INCLUDE diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py new file mode 100644 index 000000000000..90e902f4e191 --- /dev/null +++ b/op_builder/evoformer_attn.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder, installed_cuda_version +import importlib +import os +from pathlib import Path +import sys + + +class EvoformerAttnBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_EVOFORMER_ATTN" + NAME = "evoformer_attn" + CUTLASS_IGNORE = "DS_IGNORE_CUTLASS_DETECTION" + CUTLASS_PYTHON_BINDINGS = "DS_USE_CUTLASS_PYTHON_BINDINGS" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + self.cutlass_path = os.environ.get("CUTLASS_PATH") + self._resolved_cutlass_path = None + + def absolute_name(self): + return f"deepspeed.ops.{self.NAME}_op" + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ["-lcurand"] + else: + return [] + + def sources(self): + src_dir = "csrc/deepspeed4science/evoformer_attn" + return [f"{src_dir}/attention.cpp", f"{src_dir}/attention_back.cu", f"{src_dir}/attention_cu.cu"] + + def nvcc_args(self): + if os.environ.get("DS_EVOFORMER_GPU_ARCH"): + self.warning("DS_EVOFORMER_GPU_ARCH is deprecated and ignored for Evoformer builds. " + "Use TORCH_CUDA_ARCH_LIST to control build targets.") + return super().nvcc_args() + + def filter_ccs(self, ccs): + """Keep only Tensor Core capable targets (>= 7.0).""" + retained = [] + pruned = [] + for cc in [cc.split('.') for cc in ccs]: + if int(cc[0]) >= 7: + retained.append(cc) + else: + pruned.append(cc) + if pruned: + self.warning(f"Evoformer: excluding targets below SM 7.0: {pruned}. Tensor Core required.") + return retained + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile kernels") + return False + + if self.cutlass_path != self.CUTLASS_IGNORE: + try: + self.include_paths() + except (RuntimeError, ImportError) as exc: + if verbose: + self.warning(str(exc)) + return False + # Check version in case it is a CUTLASS_PATH points to a CUTLASS checkout + if self._resolved_cutlass_path is not None: + changelog_path = self._resolved_cutlass_path / "CHANGELOG.md" + else: + changelog_path = None + if changelog_path is not None and changelog_path.exists(): + with open(changelog_path, "r") as f: + if "3.1.0" not in f.read(): + if verbose: + self.warning("Please use CUTLASS version >= 3.1.0") + return False + + # Check CUDA and GPU capabilities + cuda_okay = True + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split(".")[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 7: + if verbose: + self.warning("Please use a GPU with compute capability >= 7.0") + cuda_okay = False + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("Please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + @staticmethod + def _repo_root(): + return Path(__file__).resolve().parents[1] + + @staticmethod + def _dedupe_paths(paths): + deduped = [] + seen = set() + for path in paths: + path = Path(path).expanduser() + key = str(path) + if key not in seen: + seen.add(key) + deduped.append(path) + return deduped + + @staticmethod + def _env_paths(*names): + paths = [] + for name in names: + value = os.environ.get(name) + if not value: + continue + paths.extend(Path(path) for path in value.split(os.pathsep) if path) + return paths + + @staticmethod + def _python_package_cutlass_paths(): + try: + cutlass_library = importlib.import_module("cutlass_library") + except ImportError: + return [] + + candidates = [] + source_path = getattr(cutlass_library, "source_path", None) + if source_path is not None: + candidates.append(Path(source_path)) + + package_file = getattr(cutlass_library, "__file__", None) + if package_file is not None: + package_dir = Path(package_file).resolve().parent + candidates.extend([package_dir / "source", package_dir.parent, package_dir]) + return candidates + + def _candidate_cutlass_paths(self): + if self.cutlass_path == self.CUTLASS_PYTHON_BINDINGS: + candidates = self._python_package_cutlass_paths() + if candidates: + return candidates + self.warning("Please pip install nvidia-cutlass") + raise ImportError("Unable to locate CUTLASS from the nvidia-cutlass Python package") + + if self.cutlass_path: + return [Path(self.cutlass_path)] + + repo_root = self._repo_root() + python_prefixes = self._dedupe_paths([Path(sys.prefix), Path(sys.exec_prefix), Path(sys.base_prefix)]) + prefix_paths = self._env_paths("CUTLASS_ROOT", "CUTLASS_HOME", "CONDA_PREFIX", "VIRTUAL_ENV", + "CMAKE_PREFIX_PATH", "CUDA_HOME", "CUDA_PATH") + include_paths = self._env_paths("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH") + + return self._dedupe_paths([ + *self._python_package_cutlass_paths(), + *prefix_paths, + *python_prefixes, + *include_paths, + Path.cwd() / "cutlass", + repo_root / "cutlass", + repo_root.parent / "cutlass", + Path("/usr/local/cutlass"), + Path("/opt/cutlass"), + Path("/usr/local"), + Path("/usr"), + ]) + + @staticmethod + def _cutlass_include_dirs(cutlass_path): + cutlass_path = cutlass_path.expanduser().resolve() + if not cutlass_path.is_dir(): + return [] + + if (cutlass_path / "include" / "cutlass" / "cutlass.h").is_file(): + include_root = cutlass_path / "include" + util_include = cutlass_path / "tools" / "util" / "include" + elif (cutlass_path / "cutlass" / "cutlass.h").is_file(): + include_root = cutlass_path + util_include = cutlass_path.parent / "tools" / "util" / "include" + else: + return [] + + include_dirs = [include_root] + if util_include.is_dir(): + include_dirs.append(util_include) + return [str(include_dir) for include_dir in include_dirs] + + def include_paths(self): + # Assume the user knows best and CUTLASS location is already setup externally + if self.cutlass_path == self.CUTLASS_IGNORE: + return [] + + for cutlass_path in self._candidate_cutlass_paths(): + include_dirs = self._cutlass_include_dirs(cutlass_path) + if include_dirs: + self._resolved_cutlass_path = cutlass_path.expanduser().resolve() + return include_dirs + + if self.cutlass_path: + raise RuntimeError(f"CUTLASS_PATH {self.cutlass_path} does not contain CUTLASS headers") + + raise RuntimeError("Unable to locate CUTLASS. Install nvidia-cutlass, clone CUTLASS next to DeepSpeed, " + "or set CUTLASS_PATH to the CUTLASS checkout.") diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py new file mode 100644 index 000000000000..14276aa7c1d4 --- /dev/null +++ b/op_builder/fp_quantizer.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +try: + from packaging import version as pkg_version +except ImportError: + pkg_version = None + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class FPQuantizerBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FP_QUANTIZER" + NAME = "fp_quantizer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 8: + if verbose: + self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + + try: + import triton + except ImportError: + if verbose: + self.warning( + "please install triton==2.3.0, 2.3.1 or 3.0.0 if you want to use the FP Quantizer Kernels") + return False + + # triton 2.3.{0,1} and 3.0.0 are ok. + allowed_versions = ("2.3", "3.0", "3.1", "3.2") + if pkg_version: + allowed = (pkg_version.parse(v) for v in allowed_versions) + installed_triton = pkg_version.parse(triton.__version__) + triton_mismatch = all(installed_triton.major != a.major or installed_triton.minor != a.minor + for a in allowed) + else: + installed_triton = triton.__version__ + major, minor, _ = installed_triton.split(".") + allowed = (v.split(".") for v in allowed_versions) + triton_mismatch = all(major != v[0] or minor != v[1] for v in allowed) + + if triton_mismatch: + if verbose: + self.warning( + f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.{0,1} and 3.0.0 are known to be compatible with these kernels" + ) + return False + + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in [cc.split('.') for cc in ccs]: + if int(cc[0]) >= 8: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def sources(self): + return [ + "csrc/fp_quantizer/fp_quantize_impl.cu", + "csrc/fp_quantizer/fp_quantize.cpp", + ] + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ['-lcurand'] + else: + return [] + + def include_paths(self): + return ['csrc/fp_quantizer/includes', 'csrc/includes'] + + @staticmethod + def get_default_quant_dtype(): + import torch + return torch.uint8 + + @staticmethod + def get_quant_range(q_bits=None): + if q_bits == 8: + return 480 + elif q_bits == 6: + return 28. + elif q_bits == 12: + return 510. + else: + assert (0), \ + "Please specify the right quantization range for the selected precision!" diff --git a/op_builder/fused_lion.py b/op_builder/fused_lion.py new file mode 100644 index 000000000000..b900a8f2369d --- /dev/null +++ b/op_builder/fused_lion.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys + + +class FusedLionBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_LION" + NAME = "fused_lion" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lion.{self.NAME}_op' + + def sources(self): + return ['csrc/lion/fused_lion_frontend.cpp', 'csrc/lion/multi_tensor_lion.cu'] + + def include_paths(self): + return ['csrc/includes', 'csrc/lion'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo', '--use_fast_math'] + + self.compute_capability_args()) + return nvcc_flags diff --git a/op_builder/gds.py b/op_builder/gds.py new file mode 100644 index 000000000000..d815ccea387a --- /dev/null +++ b/op_builder/gds.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from .async_io import AsyncIOBuilder + + +class GDSBuilder(AsyncIOBuilder): + BUILD_VAR = "DS_BUILD_GDS" + NAME = "gds" + + def __init__(self): + super().__init__() + + def absolute_name(self): + return f'deepspeed.ops.gds.{self.NAME}_op' + + def lib_sources(self): + src_list = ['csrc/gds/py_lib/deepspeed_py_gds_handle.cpp', 'csrc/gds/py_lib/deepspeed_gds_op.cpp'] + return super().lib_sources() + src_list + + def sources(self): + return self.lib_sources() + ['csrc/gds/py_lib/py_ds_gds.cpp'] + + def cxx_args(self): + return super().cxx_args() + ['-lcufile'] + + def include_paths(self): + import torch + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE + + def extra_ldflags(self): + return super().extra_ldflags() + ['-lcufile'] + + def is_compatible(self, verbose=False): + if self.is_rocm_pytorch(): + if verbose: + self.warning(f'{self.NAME} is not compatible with ROCM') + return False + + try: + import torch.utils.cpp_extension + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile GDS") + return False + + CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME + if CUDA_HOME is None: + if verbose: + self.warning("Please install torch CUDA if trying to pre-compile GDS with CUDA") + return False + + CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64") + gds_compatible = self.has_function(funcname="cuFileDriverOpen", + libraries=("cufile", ), + library_dirs=( + CUDA_HOME, + CUDA_LIB64, + ), + verbose=verbose) + + return gds_compatible and super().is_compatible(verbose) diff --git a/op_builder/hpu/__init__.py b/op_builder/hpu/__init__.py new file mode 100644 index 000000000000..5ad1b9a7f891 --- /dev/null +++ b/op_builder/hpu/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +from .cpu_adam import CPUAdamBuilder +from .fused_adam import FusedAdamBuilder +from .transformer_inference import InferenceBuilder +from .no_impl import NotImplementedBuilder diff --git a/op_builder/hpu/builder.py b/op_builder/hpu/builder.py new file mode 100644 index 000000000000..11e710a8ee48 --- /dev/null +++ b/op_builder/hpu/builder.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class CPUOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} + + cpp_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + return cpp_ext + + def cxx_args(self): + args = ['-O3', '-g', '-Wno-reorder'] + return args + + def libraries_args(self): + return [] diff --git a/op_builder/hpu/cpu_adam.py b/op_builder/hpu/cpu_adam.py new file mode 100644 index 000000000000..58eea2698ebb --- /dev/null +++ b/op_builder/hpu/cpu_adam.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class CPUAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/hpu/fp_quantizer.py b/op_builder/hpu/fp_quantizer.py new file mode 100644 index 000000000000..c74affb55045 --- /dev/null +++ b/op_builder/hpu/fp_quantizer.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class FPQuantizerBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_FP_QUANTIZER" + NAME = "fp_quantizer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' + + def sources(self): + return [] + + def load(self, verbose=True): + return FPQuantizer + + @staticmethod + def get_default_quant_dtype(): + return torch.float8_e4m3fn + + @staticmethod + def get_quant_range(q_bits=None): + import habana_frameworks.torch.utils.experimental as htexp + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + dtype = torch.float8_e4m3fnuz + else: + dtype = torch.float8_e4m3fn + return torch.finfo(dtype).max + + +class FPQuantizer: + CUDA_IMPL = False + + @classmethod + def selective_dequantize(cls, val_q, scales, indexes, group_size, q_mantisa_bits, q_exponent_bits): + assert False, "Selective dequantize isn't implemented for HPU!" + + @classmethod + def dequantize(cls, fp_out, input_q, scale, group_size, q_mantisa_bits, q_exponent_bits): + orig_shape = fp_out.shape + orig_dtype = fp_out.dtype + dequant_out = torch.ops.hpu.cast_from_fp8(input_q, (1.0 / scale), orig_dtype).view(orig_shape) + fp_out.copy_(dequant_out) + return fp_out + + @classmethod + def quantize(cls, out, val, scale, group_size, stochastic_rounding, q_bits, q_mantisa_bits): + assert q_bits == 8, "Quantize on HPU only supports quantization to FP8" + assert q_mantisa_bits == 3, "Quantize on HPU only supports q_mantissa_bits = 3" + assert out.dtype.is_floating_point, "Quantization on HPU is only to float dtypes" + + num_groups, group_size = out.shape + + # Reshape the tensor + val_reshaped = val.view(num_groups, group_size).float() + # Calculate the scale + max_vals = val_reshaped.abs().max(dim=1, keepdim=True)[0] + q_range = torch.finfo(out.dtype).max + tmp_scale = q_range / max_vals + scale.copy_(tmp_scale) + # Copy quantized + quant, _ = torch.ops.hpu.cast_to_fp8_v2(val_reshaped, scale, stochastic_rounding, dtype=out.dtype) + out.copy_(quant) + + return out + + @classmethod + def get_scales(cls, out, num_groups): + return out diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py new file mode 100644 index 000000000000..5acb121668e3 --- /dev/null +++ b/op_builder/hpu/fused_adam.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + +try: + import torch + import math +except ImportError as e: + pass + + +class HPUFusedAdam: + htcore = None + is_lazy_mode = None + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + + if HPUFusedAdam.htcore is None: + from habana_frameworks.torch import core as htcore + from habana_frameworks.torch.utils.internal import is_lazy + HPUFusedAdam.htcore = htcore + HPUFusedAdam.is_lazy_mode = is_lazy() + + htcore = HPUFusedAdam.htcore + + htcore.step_closure._mark_step_if_lazy() + step_size = lr + if bias_correction: + bias_correction1 = 1.0 - pow(beta1, step) + bias_correction2 = 1.0 - pow(beta2, step) + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + neg_step = -step_size + neg_step_t = (torch.tensor([neg_step], dtype=torch.float, + requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, + non_blocking=True)) + + weight_decay = weight_decay if adam_w_mode else 0 + + # since lr is fed into the kernel as tensor, perform the scalar multiplication of wd here + # NOTE: TODO if lr is updated every step, then we need to convert it as tensor and + # perform weight decay unconditonally. + modified_wd = 1.0 - weight_decay * lr + + if HPUFusedAdam.is_lazy_mode: + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd, + ) + else: + modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to( + tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True)) + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd_t, + modified_wd != 1.0, + ) + + htcore.step_closure._mark_step_if_lazy() + + +class FusedAdamBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] + + def load(self, verbose=True): + return HPUFusedAdam diff --git a/op_builder/hpu/no_impl.py b/op_builder/hpu/no_impl.py new file mode 100644 index 000000000000..140d65b48def --- /dev/null +++ b/op_builder/hpu/no_impl.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class NotImplementedBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on HPU backend.") + + def sources(self): + return [] diff --git a/op_builder/hpu/transformer_inference.py b/op_builder/hpu/transformer_inference.py new file mode 100644 index 000000000000..e397c99200ec --- /dev/null +++ b/op_builder/hpu/transformer_inference.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +import importlib + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class InferenceBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE" + NAME = "transformer_inference" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=self.NAME) + + def absolute_name(self): + return f"deepspeed.ops.transformer.inference.{self.NAME}_op" + + def sources(self): + return [] + + def load(self, verbose=True): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + from deepspeed.git_version_info import installed_ops # noqa: F401 + if installed_ops.get(self.name, False): + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py new file mode 100755 index 000000000000..d3b0f3aaeeb9 --- /dev/null +++ b/op_builder/inference_core_ops.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class InferenceCoreBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_INFERENCE_CORE_OPS" + NAME = "inference_core_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.kernels{self.NAME}' + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in [cc.split('.') for cc in ccs]: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." + + def sources(self): + sources = [ + "inference/v2/kernels/core_ops/core_ops.cpp", + "inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp", + "inference/v2/kernels/core_ops/bias_activations/bias_activation_cuda.cu", + "inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp", + "inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu", + "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp", + "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu", + "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp", + "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu", + "inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp", + "inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + return [] + + def include_paths(self): + sources = [ + 'inference/v2/kernels/core_ops/bias_activations', + 'inference/v2/kernels/core_ops/blas_kernels', + 'inference/v2/kernels/core_ops/cuda_layer_norm', + 'inference/v2/kernels/core_ops/cuda_rms_norm', + 'inference/v2/kernels/core_ops/gated_activations', + 'inference/v2/kernels/core_ops/cuda_linear', + 'inference/v2/kernels/includes', + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + + return sources diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py new file mode 100644 index 000000000000..5b2299e9c5cc --- /dev/null +++ b/op_builder/inference_cutlass_builder.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class InferenceCutlassBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_CUTLASS_OPS" + NAME = "cutlass_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.kernels.cutlass_ops.{self.NAME}' + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in [cc.split('.') for cc in ccs]: + if int(cc[0]) >= 8: + # Only support Ampere and newer + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." + + def sources(self): + sources = [ + "inference/v2/kernels/cutlass_ops/cutlass_ops.cpp", + "inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu", + "inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + import dskernels + lib_path = dskernels.library_path() + prefix = self.get_prefix() + lib_path = os.path.join(prefix, lib_path) + lib_path = self.deepspeed_src_path(lib_path) + + args = [f'-L{lib_path}', '-ldeepspeedft'] + if self.jit_load: + args.append(f'-Wl,-rpath,{lib_path}') + return args + + def include_paths(self): + sources = [ + 'inference/v2/kernels/includes', + 'inference/v2/kernels/cutlass_ops/mixed_gemm', + 'inference/v2/kernels/cutlass_ops/moe_gemm', + 'inference/v2/kernels/cutlass_ops/shared_resources/', + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources diff --git a/op_builder/mlu/__init__.py b/op_builder/mlu/__init__.py new file mode 100644 index 000000000000..db12afbbf20e --- /dev/null +++ b/op_builder/mlu/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Copyright (c) 2024 Cambricon Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +# MLU related operators will be added in the future. +from .no_impl import NotImplementedBuilder +from .cpu_adagrad import CPUAdagradBuilder +from .cpu_adam import CPUAdamBuilder +from .fused_adam import FusedAdamBuilder diff --git a/op_builder/mlu/builder.py b/op_builder/mlu/builder.py new file mode 100644 index 000000000000..17b9723ffcc1 --- /dev/null +++ b/op_builder/mlu/builder.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Copyright (c) 2024 Cambricon Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class MLUOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} + + cpp_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + return cpp_ext + + def cxx_args(self): + return ['-O3', '-g', '-Wno-reorder'] + + def libraries_args(self): + return [] diff --git a/op_builder/mlu/cpu_adagrad.py b/op_builder/mlu/cpu_adagrad.py new file mode 100644 index 000000000000..68b7bbe514ee --- /dev/null +++ b/op_builder/mlu/cpu_adagrad.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Copyright (c) 2024 Cambricon Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import MLUOpBuilder + + +class CPUAdagradBuilder(MLUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAGRAD" + NAME = "cpu_adagrad" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adagrad.{self.NAME}_op' + + def sources(self): + return ['csrc/adagrad/cpu_adagrad.cpp'] + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/mlu/cpu_adam.py b/op_builder/mlu/cpu_adam.py new file mode 100644 index 000000000000..b3c8e476bf39 --- /dev/null +++ b/op_builder/mlu/cpu_adam.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Copyright (c) 2024 Cambricon Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import MLUOpBuilder + + +class CPUAdamBuilder(MLUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/mlu/fused_adam.py b/op_builder/mlu/fused_adam.py new file mode 100644 index 000000000000..2343e55568c4 --- /dev/null +++ b/op_builder/mlu/fused_adam.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# Copyright (c) 2024 Cambricon Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import MLUOpBuilder + +try: + import torch +except ImportError as e: + pass + + +class MLUFusedAdam: + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + + torch.ops.torch_mlu.fused_adam(noop_flag_buffer, tensor_lists[0], tensor_lists[1], tensor_lists[2], + tensor_lists[3], lr, beta1, beta2, epsilon, step, adam_w_mode, bias_correction, + weight_decay) + + +class FusedAdamBuilder(MLUOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] + + def load(self, verbose=True): + return MLUFusedAdam diff --git a/op_builder/mlu/no_impl.py b/op_builder/mlu/no_impl.py new file mode 100644 index 000000000000..375c148b4a5e --- /dev/null +++ b/op_builder/mlu/no_impl.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Copyright (c) 2024 Cambricon Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import MLUOpBuilder + + +class NotImplementedBuilder(MLUOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on MLU backend.") + + def sources(self): + return [] + + def cxx_args(self): + return [] + + def extra_ldflags(self): + return [] + + def include_paths(self): + return [] diff --git a/op_builder/npu/__init__.py b/op_builder/npu/__init__.py new file mode 100644 index 000000000000..e512e25a2a29 --- /dev/null +++ b/op_builder/npu/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +from .fused_adam import FusedAdamBuilder +from .async_io import AsyncIOBuilder +from .no_impl import NotImplementedBuilder +from .cpu_adam import CPUAdamBuilder +from .cpu_adagrad import CPUAdagradBuilder +from .cpu_lion import CPULionBuilder +from .inference import InferenceBuilder diff --git a/op_builder/npu/async_io.py b/op_builder/npu/async_io.py new file mode 100644 index 000000000000..b87495758f0b --- /dev/null +++ b/op_builder/npu/async_io.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import shutil +import subprocess + +from .builder import NPUOpBuilder + + +class AsyncIOBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_AIO" + NAME = "async_io" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.aio.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/py_ds_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio.cpp', 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', + 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', 'csrc/aio/common/deepspeed_aio_utils.cpp', + 'csrc/aio/common/deepspeed_aio_common.cpp', 'csrc/aio/common/deepspeed_aio_types.cpp', + 'csrc/aio/py_lib/deepspeed_pin_tensor.cpp', 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', + 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', 'csrc/aio/py_lib/deepspeed_cpu_op.cpp' + ] + + def include_paths(self): + args = super().include_paths() + args += ['csrc/aio/py_lib', 'csrc/aio/common'] + return args + + def cxx_args(self): + args = super().cxx_args() + # -O0 for improved debugging, since performance is bound by I/O + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + import torch # Keep this import here to avoid errors when building DeepSpeed wheel without torch installed + TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) + if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1: + CPP_STD = '-std=c++17' + else: + CPP_STD = '-std=c++14' + return args + [ + '-g', + '-Wall', + '-O0', + CPP_STD, + '-shared', + '-fPIC', + '-Wno-reorder', + CPU_ARCH, + '-fopenmp', + SIMD_WIDTH, + '-laio', + ] + + def extra_ldflags(self): + args = super().extra_ldflags() + return args + ['-laio'] + + def check_for_libaio_pkg(self): + libs = dict( + dpkg=["-l", "libaio-dev", "apt"], + pacman=["-Q", "libaio", "pacman"], + rpm=["-q", "libaio-devel", "yum"], + ) + + found = False + for pkgmgr, data in libs.items(): + flag, lib, tool = data + path = shutil.which(pkgmgr) + if path is not None: + cmd = [pkgmgr, flag, lib] + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.wait() == 0: + found = True + else: + self.warning(f"{self.NAME}: please install the {lib} package with {tool}") + break + return found + + def is_compatible(self, verbose=False): + # Check for the existence of libaio by using distutils + # to compile and link a test program that calls io_submit, + # which is a function provided by libaio that is used in the async_io op. + # If needed, one can define -I and -L entries in CFLAGS and LDFLAGS + # respectively to specify the directories for libaio.h and libaio.so. + aio_compatible = self.has_function('io_pgetevents', ('aio', )) + if verbose and not aio_compatible: + self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.") + + # Check for the libaio package via known package managers + # to print suggestions on which package to install. + self.check_for_libaio_pkg() + + self.warning( + "If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found." + ) + return super().is_compatible(verbose) and aio_compatible diff --git a/op_builder/npu/builder.py b/op_builder/npu/builder.py new file mode 100644 index 000000000000..0dea2e78915e --- /dev/null +++ b/op_builder/npu/builder.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import re +import os +try: + import torch_npu +except ImportError as e: + pass + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class NPUOpBuilder(OpBuilder): + _ascend_path = None + _torch_npu_path = None + _cann_version = None + + def __init__(self, name): + super().__init__(name) + self._ascend_path = self.installed_cann_path() + self._torch_npu_path = os.path.join(os.path.dirname(os.path.abspath(torch_npu.__file__))) + try: + self._cann_version = self.installed_cann_version(self.name) + except BaseException: + print(f"{self.name} ascend_cann is missing, npu ops cannot be compiled!") + + def cann_defs(self): + if self._cann_version: + return '-D__ENABLE_CANN__' + return '-D__DISABLE_CANN__' + + def installed_cann_path(self): + if "ASCEND_HOME_PATH" in os.environ or os.path.exists(os.environ["ASCEND_HOME_PATH"]): + return os.environ["ASCEND_HOME_PATH"] + return None + + def installed_cann_version(self, name=""): + ascend_path = self.installed_cann_path() + assert ascend_path is not None, "CANN_HOME does not exist, unable to compile NPU op(s)" + cann_version = "" + for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)): + if cann_version: + break + install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)] + if install_files: + filepath = os.path.join(dirpath, install_files[0]) + with open(filepath, "r") as f: + for line in f: + if line.find("version") != -1: + cann_version = line.strip().split("=")[-1] + break + return cann_version + + def include_paths(self): + paths = super().include_paths() + paths += [os.path.join(self._ascend_path, 'include'), os.path.join(self._torch_npu_path, 'include')] + return paths + + def cxx_args(self): + args = super().cxx_args() + args += ['-O3', '-std=c++17', '-g', '-Wno-reorder', '-fopenmp'] + args += ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath'] + args += [ + self.cann_defs(), + self.cpu_arch(), + self.simd_width(), '-L' + os.path.join(self._ascend_path, 'lib64'), + '-L' + os.path.join(self._torch_npu_path, 'lib') + ] + return args + + def extra_ldflags(self): + flags = super().extra_ldflags() + flags += [ + '-L' + os.path.join(self._ascend_path, 'lib64'), '-lascendcl', + '-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu' + ] + return flags diff --git a/op_builder/npu/cpu_adagrad.py b/op_builder/npu/cpu_adagrad.py new file mode 100644 index 000000000000..161bc82efe1c --- /dev/null +++ b/op_builder/npu/cpu_adagrad.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import NPUOpBuilder + + +class CPUAdagradBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAGRAD" + NAME = "cpu_adagrad" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adagrad.{self.NAME}_op' + + def sources(self): + return ['csrc/adagrad/cpu_adagrad.cpp'] + + def include_paths(self): + args = super().include_paths() + args += ['csrc/includes'] + return args diff --git a/op_builder/npu/cpu_adam.py b/op_builder/npu/cpu_adam.py new file mode 100644 index 000000000000..a4e9569c0f33 --- /dev/null +++ b/op_builder/npu/cpu_adam.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import NPUOpBuilder + + +class CPUAdamBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def include_paths(self): + args = super().include_paths() + args += ['csrc/includes'] + return args diff --git a/op_builder/npu/cpu_lion.py b/op_builder/npu/cpu_lion.py new file mode 100644 index 000000000000..6917e0fd03d0 --- /dev/null +++ b/op_builder/npu/cpu_lion.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import NPUOpBuilder + + +class CPULionBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_LION" + NAME = "cpu_lion" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lion.{self.NAME}_op' + + def sources(self): + return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp'] + + def include_paths(self): + args = super().include_paths() + args += ['csrc/includes'] + return args diff --git a/op_builder/npu/fused_adam.py b/op_builder/npu/fused_adam.py new file mode 100644 index 000000000000..d32103db7055 --- /dev/null +++ b/op_builder/npu/fused_adam.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import NPUOpBuilder + +try: + import torch_npu +except ImportError as e: + pass + + +class NPUFusedAdam: + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + bias_correction1 = beta1**(step - 1) + bias_correction2 = beta2**(step - 1) + + # iteration group['params'] + for i in range(len(tensor_lists[0])): + grad_flat = tensor_lists[0][i] + param_flat = tensor_lists[1][i] + m_flat = tensor_lists[2][i] + v_flat = tensor_lists[3][i] + + if adam_w_mode: + param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam_w( + bias_correction1, + bias_correction2, + lr, + weight_decay, + beta1, + beta2, + epsilon, + grad_flat, + None, # max_grad_norm + False, # amsgrad + False, # maximize + out=(param_flat.data, m_flat, v_flat)) + else: + param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam( + bias_correction1, + bias_correction2, + lr, + beta1, + beta2, + epsilon, + grad_flat, + False, # use_locking + False, # use_nesterov + out=(param_flat.data, m_flat, v_flat)) + + +class FusedAdamBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] + + def load(self, verbose=True): + return NPUFusedAdam diff --git a/op_builder/npu/inference.py b/op_builder/npu/inference.py new file mode 100644 index 000000000000..46f28c0d4011 --- /dev/null +++ b/op_builder/npu/inference.py @@ -0,0 +1,307 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import IntEnum +from .builder import NPUOpBuilder + +try: + import torch + import torch_npu +except ImportError as e: + pass + + +class ActivationFuncType(IntEnum): + UNKNOWN = 0 + GELU = 1 + ReLU = 2 + GATED_GELU = 3 + GATED_SILU = 4 + + +class InferenceContext: + _workspace = None + + _seed = 42 + _curr_offset = 0 + _stream = 0 + _free_memory_size = 0 + _num_tokens = 1 + _attention_unfused_workspace_offset = 0 + _workSpaceSize = 0 + + workSpaceSize = 0 + kv_caches = None + + @staticmethod + def reset_tokens(initial_tokens=1): + InferenceContext._num_tokens = initial_tokens + + @staticmethod + def current_tokens(): + return InferenceContext._num_tokens + + @staticmethod + def GetWorkSpace(): + return InferenceContext._workspace + + +class NPUInference: + + @staticmethod + def layer_norm(inputs, gamma, beta, epsilon): + return torch.nn.functional.layer_norm(inputs, [inputs.shape[-1]], gamma, beta, eps=epsilon) + + @staticmethod + def _qkv_gemm(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): + inp_norm = torch.nn.functional.layer_norm(inputs, (inputs.shape[2], ), gamma, beta, eps) + weight = weight.t() if transpose else weight + tmp = torch.matmul(inp_norm, weight) + if add_bias: + tmp += bias + output = [tmp, inp_norm] + return output + + @staticmethod + def qkv_gemm_fp16(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): + return NPUInference._qkv_gemm(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose) + + @staticmethod + def qkv_gemm_bf16(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): + return NPUInference._qkv_gemm(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose) + + @staticmethod + def qkv_gemm_fp32(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): + return NPUInference._qkv_gemm(inputs, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose) + + @staticmethod + def _bias_add_transform_0213(vals, bias, hidden_dim, seq_length, seq_offset, heads, num_kv, rotary_dim, + rotate_half, rotate_every_two, rope_theta): + bsz, _, _ = vals.shape + q = vals[..., :hidden_dim].reshape(bsz, seq_length, heads, -1) + k = vals[..., hidden_dim:hidden_dim + num_kv * (hidden_dim // heads)].reshape(bsz, seq_length, num_kv, -1) + v = vals[..., hidden_dim + num_kv * (hidden_dim // heads):] + + if rotary_dim > 0 and rotate_every_two: + # sin, cos may use cache + seq_id = torch.arange(0, seq_length).to("npu") + inv_freq = torch.arange(0, rotary_dim, 2) / rotary_dim + inv_freq = inv_freq.to("npu") + inv_freq = 1.0 / torch.pow(rope_theta, inv_freq) + inv_freq = torch.outer(seq_id, inv_freq) + sin = inv_freq.sin() + cos = inv_freq.cos() + # shape: [bsz=1, seq_len, heads=1, rotary_dim] + sin = sin.view(-1, seq_length, 1, rotary_dim // 2).repeat_interleave(2, dim=-1) + cos = cos.view(-1, seq_length, 1, rotary_dim // 2).repeat_interleave(2, dim=-1) + + q_pos, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_pos, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_pos = torch_npu.npu_rotary_mul(q_pos, cos, sin) + q = torch.cat([q_pos, q_pass], dim=-1) + k_pos = torch_npu.npu_rotary_mul(k_pos, cos, sin) + k = torch.cat([k_pos, k_pass], dim=-1) + + output = q.reshape(bsz, seq_length, -1).contiguous() # [b, s, H] + k_cache = k.reshape(bsz, seq_length, heads, -1).transpose(1, 2).contiguous() # [b, n, s, d] + v_cache = v.reshape(bsz, seq_length, heads, -1).transpose(1, 2).contiguous() # [b, n, s, d] + return output, k_cache, v_cache + + @staticmethod + def _softmax_context(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads, num_kv, + norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id, + num_layers, alibi, rope_theta): + bsz, seq_len, k = query_key_value.size() + k = k // (heads + 2 * (num_kv if num_kv > 0 else heads)) + hidden_dim = heads * k + + is_promt = seq_len > 1 + if not InferenceContext.kv_caches: + InferenceContext.kv_caches = [[None, None] for _ in range(num_layers)] + if is_promt: + InferenceContext.reset_tokens(seq_len) + InferenceContext.kv_caches[layer_id] = [None, None] + + soft_len = InferenceContext.current_tokens() + workspace = InferenceContext.GetWorkSpace() + seq_offset = 0 if is_promt else soft_len - 1 + + q, k, v = NPUInference._bias_add_transform_0213(vals=query_key_value, + bias=None, + hidden_dim=hidden_dim, + seq_length=seq_len, + seq_offset=seq_offset, + heads=heads, + num_kv=num_kv if num_kv > 0 else heads, + rotary_dim=rotary_dim, + rotate_half=rotate_half, + rotate_every_two=rotate_every_two, + rope_theta=rope_theta) + + if not is_promt: + k_cache, v_cache = InferenceContext.kv_caches[layer_id] + if k_cache is not None: + k = torch.cat([k_cache, k], dim=2) + v = torch.cat([v_cache, v], dim=2) + InferenceContext.kv_caches[layer_id] = [k, v] + seq_len = k.shape[2] + + layer_scale = max(1, layer_id) if len(alibi.size()) > 1 else 1.0 + alpha = norm_factor * norm_factor / layer_scale + + output = torch_npu.npu_fusion_attention(q, + k.transpose(1, 2).reshape(bsz, seq_len, -1).contiguous(), + v.transpose(1, 2).reshape(bsz, seq_len, -1).contiguous(), + heads, + "BSH", + pse=None, + padding_mask=None, + atten_mask=attn_mask.bool(), + scale=alpha, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1, + inner_precise=0)[0] + + return output, k, v + + @staticmethod + def softmax_context_fp16(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads, num_kv, + norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id, + num_layers, alibi, rope_theta): + return NPUInference._softmax_context(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, + heads, num_kv, norm_factor, triangular_masking, local_attention, + window_size, no_masking, layer_id, num_layers, alibi, rope_theta) + + @staticmethod + def softmax_context_bf16(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads, num_kv, + norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id, + num_layers, alibi, rope_theta): + return NPUInference._softmax_context(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, + heads, num_kv, norm_factor, triangular_masking, local_attention, + window_size, no_masking, layer_id, num_layers, alibi, rope_theta) + + @staticmethod + def softmax_context_fp32(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads, num_kv, + norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id, + num_layers, alibi, rope_theta): + return NPUInference._softmax_context(query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, + heads, num_kv, norm_factor, triangular_masking, local_attention, + window_size, no_masking, layer_id, num_layers, alibi, rope_theta) + + @staticmethod + def _vector_matmul(input, weight, async_op, q_scale, q_int8, transposed_mode): + if transposed_mode: + return torch.matmul(input, weight.t()) + return torch.matmul(input, weight) + + @staticmethod + def vector_matmul_fp16(input, weight, async_op, q_scale, q_int8, transposed_mode): + return NPUInference._vector_matmul(input, weight, async_op, q_scale, q_int8, transposed_mode) + + @staticmethod + def vector_matmul_bf16(input, weight, async_op, q_scale, q_int8, transposed_mode): + return NPUInference._vector_matmul(input, weight, async_op, q_scale, q_int8, transposed_mode) + + @staticmethod + def vector_matmul_fp32(input, weight, async_op, q_scale, q_int8, transposed_mode): + return NPUInference._vector_matmul(input, weight, async_op, q_scale, q_int8, transposed_mode) + + @staticmethod + def _mlp_gemm(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, pre_layer_norm, + mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type, transpose): + if mlp_after_attn: + residual_add = torch.nn.functional.layer_norm(input + residual + input_bias, (input.shape[-1], ), gamma, + beta, eps) + else: + residual_add = torch.nn.functional.layer_norm(input, (input.shape[-1], ), gamma, beta, eps) + + weight_interm = weight_interm.t() if transpose else weight_interm + tmp = torch.matmul(residual_add, weight_interm) + if mlp_act_func_type == ActivationFuncType.GELU: + tmp = torch.nn.functional.gelu(tmp + bias) + elif mlp_act_func_type == ActivationFuncType.ReLU: + tmp = torch.nn.functional.relu(tmp + bias) + else: + raise Exception('Unsupported ActivationFuncType {}'.format(mlp_act_func_type)) + output = torch.matmul(tmp, weight_out.t()) + return output, residual_add + + @staticmethod + def mlp_gemm_fp16(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, pre_layer_norm, + mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type, transpose): + return NPUInference._mlp_gemm(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, + pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, + mlp_act_func_type, transpose) + + @staticmethod + def mlp_gemm_bf16(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, pre_layer_norm, + mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type, transpose): + return NPUInference._mlp_gemm(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, + pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, + mlp_act_func_type, transpose) + + @staticmethod + def mlp_gemm_fp32(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, pre_layer_norm, + mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type, transpose): + return NPUInference._mlp_gemm(input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, + pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, + mlp_act_func_type, transpose) + + @staticmethod + def _residual_add_bias(hidden_state, residual, attention_output, attention_bias, final_bias, mp_size, + mlp_after_attn, add_bias, pre_layer_norm): + if mlp_after_attn: + if pre_layer_norm: + tmp = (residual.float() + attention_output.float() + attention_bias.float() + + final_bias.float()) / mp_size + hidden_state.float() + else: + tmp = residual.float() + hidden_state.float() + final_bias.float() + else: + if add_bias: + residual += attention_bias.float() + tmp = hidden_state.float() + attention_output.float() + (residual.float() + final_bias.float()) / mp_size + + input_dtype = hidden_state.dtype + residual.set_(tmp.to(input_dtype)) + + @staticmethod + def residual_add_bias_fp16(hidden_state, residual, attention_output, attention_bias, final_bias, mp_size, + mlp_after_attn, add_bias, pre_layer_norm): + return NPUInference._residual_add_bias(hidden_state, residual, attention_output, attention_bias, final_bias, + mp_size, mlp_after_attn, add_bias, pre_layer_norm) + + @staticmethod + def residual_add_bias_bf16(hidden_state, residual, attention_output, attention_bias, final_bias, mp_size, + mlp_after_attn, add_bias, pre_layer_norm): + return NPUInference._residual_add_bias(hidden_state, residual, attention_output, attention_bias, final_bias, + mp_size, mlp_after_attn, add_bias, pre_layer_norm) + + @staticmethod + def residual_add_bias_fp32(hidden_state, residual, attention_output, attention_bias, final_bias, mp_size, + mlp_after_attn, add_bias, pre_layer_norm): + return NPUInference._residual_add_bias(hidden_state, residual, attention_output, attention_bias, final_bias, + mp_size, mlp_after_attn, add_bias, pre_layer_norm) + + +class InferenceBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE" + NAME = "transformer_inference" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.transformer.inference.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] + + def load(self, verbose=True): + return NPUInference diff --git a/op_builder/npu/no_impl.py b/op_builder/npu/no_impl.py new file mode 100644 index 000000000000..5b1771fabc22 --- /dev/null +++ b/op_builder/npu/no_impl.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import NPUOpBuilder + + +class NotImplementedBuilder(NPUOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on NPU backend.") + + def sources(self): + return [] + + def cxx_args(self): + return [] + + def extra_ldflags(self): + return [] + + def include_paths(self): + return [] diff --git a/op_builder/quantizer.py b/op_builder/quantizer.py index a64d1603d1e5..0b5348e5af96 100644 --- a/op_builder/quantizer.py +++ b/op_builder/quantizer.py @@ -22,11 +22,17 @@ def sources(self): 'csrc/quantization/pt_binding.cpp', 'csrc/quantization/fake_quantizer.cu', 'csrc/quantization/quantize.cu', + 'csrc/quantization/quantize_intX.cu', 'csrc/quantization/dequantize.cu', + 'csrc/quantization/swizzled_quantize.cu', + 'csrc/quantization/quant_reduce.cu', ] def include_paths(self): return ['csrc/includes'] def extra_ldflags(self): - return ['-lcurand'] + if not self.is_rocm_pytorch(): + return ['-lcurand'] + else: + return [] diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py new file mode 100644 index 000000000000..d96f437e16ca --- /dev/null +++ b/op_builder/ragged_ops.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class RaggedOpsBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_RAGGED_DEVICE_OPS" + NAME = "ragged_device_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.kernels.ragged_ops.{self.NAME}' + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in [cc.split('.') for cc in ccs]: + if int(cc[0]) >= 8: + # Blocked flash has a dependency on Ampere + newer + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." + + def sources(self): + sources = [ + "inference/v2/kernels/ragged_ops/ragged_ops.cpp", + "inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp", + "inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp", + "inference/v2/kernels/ragged_ops/embed/embed.cpp", + "inference/v2/kernels/ragged_ops/embed/embed_cuda.cu", + "inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp", + "inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu", + "inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp", + "inference/v2/kernels/ragged_ops/logits_gather/logits_gather_cuda.cu", + "inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp", + "inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter_cuda.cu", + "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp", + "inference/v2/kernels/ragged_ops/moe_gather/moe_gather_cuda.cu", + "inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating_cuda.cu", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + import dskernels + lib_path = dskernels.library_path() + + prefix = self.get_prefix() + lib_path = os.path.join(prefix, lib_path) + lib_path = self.deepspeed_src_path(lib_path) + + args = [f'-L{lib_path}', '-lblockedflash'] + if self.jit_load: + args.append(f'-Wl,-rpath,{lib_path}') + return args + + def include_paths(self): + sources = [ + 'inference/v2/kernels/includes', + 'inference/v2/kernels/ragged_ops', + 'inference/v2/kernels/ragged_ops/atom_builder', + 'inference/v2/kernels/ragged_ops/blocked_flash', + 'inference/v2/kernels/ragged_ops/embed', + 'inference/v2/kernels/ragged_ops/includes', + 'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary', + 'inference/v2/kernels/ragged_ops/logits_gather', + 'inference/v2/kernels/ragged_ops/moe_gather', + 'inference/v2/kernels/ragged_ops/moe_scatter', + 'inference/v2/kernels/ragged_ops/ragged_helpers', + 'inference/v2/kernels/ragged_ops/top_k_gating', + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py new file mode 100755 index 000000000000..6cf4e4b3153d --- /dev/null +++ b/op_builder/ragged_utils.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class RaggedUtilsBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_RAGGED_OPS" + NAME = "ragged_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.{self.NAME}' + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in [cc.split('.') for cc in ccs]: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." + + def sources(self): + sources = [ + "inference/v2/ragged/csrc/fast_host_buffer.cu", + "inference/v2/ragged/csrc/ragged_ops.cpp", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + return [] + + def include_paths(self): + include_dirs = ['inference/v2/ragged/includes', 'inference/v2/kernels/includes'] + prefix = self.get_prefix() + includes = [os.path.join(prefix, include_dir) for include_dir in include_dirs] + + return includes diff --git a/op_builder/random_ltd.py b/op_builder/random_ltd.py index 3fdc777215da..54af7150fb36 100644 --- a/op_builder/random_ltd.py +++ b/op_builder/random_ltd.py @@ -31,7 +31,4 @@ def sources(self): def include_paths(self): includes = ['csrc/includes'] - if self.is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)] return includes diff --git a/op_builder/sdaa/__init__.py b/op_builder/sdaa/__init__.py new file mode 100755 index 000000000000..2a6eb8bbfa2c --- /dev/null +++ b/op_builder/sdaa/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. + +# SPDX-License-Identifier: Apache-2.0 + +# BSD 3- Clause License Copyright (c) 2023, Tecorigin Co., Ltd. All rights +# reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY,OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY +# WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +from .no_impl import NotImplementedBuilder +from .cpu_adam import CPUAdamBuilder +from .fused_adam import FusedAdamBuilder diff --git a/op_builder/sdaa/builder.py b/op_builder/sdaa/builder.py new file mode 100755 index 000000000000..81f0e98c0768 --- /dev/null +++ b/op_builder/sdaa/builder.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# BSD 3- Clause License Copyright (c) 2023, Tecorigin Co., Ltd. All rights +# reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY,OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY +# WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class SDAAOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} + + cpp_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + return cpp_ext + + def cxx_args(self): + return ['-O3', '-g', '-Wno-reorder'] + + def libraries_args(self): + return [] diff --git a/op_builder/sdaa/cpu_adam.py b/op_builder/sdaa/cpu_adam.py new file mode 100755 index 000000000000..b38a71a6275d --- /dev/null +++ b/op_builder/sdaa/cpu_adam.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# BSD 3- Clause License Copyright (c) 2023, Tecorigin Co., Ltd. All rights +# reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY,OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY +# WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. + +# DeepSpeed Team + +from .builder import SDAAOpBuilder + + +class CPUAdamBuilder(SDAAOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/sdaa/fused_adam.py b/op_builder/sdaa/fused_adam.py new file mode 100755 index 000000000000..73a2dff41459 --- /dev/null +++ b/op_builder/sdaa/fused_adam.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# BSD 3- Clause License Copyright (c) 2023, Tecorigin Co., Ltd. All rights +# reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY,OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY +# WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. + +# DeepSpeed Team + +from .builder import SDAAOpBuilder + +try: + import torch +except ImportError as e: + pass + + +class SDAAFusedAdam: + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + g_tensor_lis, p_tensor_lis, m_tensor_lis, v_tensor_lis = tensor_lists + torch.ops.sdaa.fused_adam(g_tensor_lis, p_tensor_lis, m_tensor_lis, v_tensor_lis, [], beta1, beta2, epsilon, + lr, weight_decay, adam_w_mode, step, bias_correction) + + +class FusedAdamBuilder(SDAAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] + + def load(self, verbose=True): + return SDAAFusedAdam diff --git a/op_builder/sdaa/no_impl.py b/op_builder/sdaa/no_impl.py new file mode 100755 index 000000000000..10a8b8f48652 --- /dev/null +++ b/op_builder/sdaa/no_impl.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# BSD 3- Clause License Copyright (c) 2023, Tecorigin Co., Ltd. All rights +# reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY,OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY +# WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. + +# DeepSpeed Team + +from .builder import SDAAOpBuilder + + +class NotImplementedBuilder(SDAAOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on SDAA backend.") + + def sources(self): + return [] + + def cxx_args(self): + return [] + + def extra_ldflags(self): + return [] + + def include_paths(self): + return [] diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py index 7fae7d997dd2..05d8ed3753a6 100644 --- a/op_builder/sparse_attn.py +++ b/op_builder/sparse_attn.py @@ -27,44 +27,50 @@ def sources(self): def cxx_args(self): return ['-O2', '-fopenmp'] - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): # Check to see if llvm and cmake are installed since they are dependencies #required_commands = ['llvm-config|llvm-config-9', 'cmake'] #command_status = list(map(self.command_exists, required_commands)) #deps_compatible = all(command_status) if self.is_rocm_pytorch(): - self.warning(f'{self.NAME} is not compatible with ROCM') + if verbose: + self.warning(f'{self.NAME} is not compatible with ROCM') return False try: import torch except ImportError: - self.warning(f"unable to import torch, please install it first") + if verbose: + self.warning("unable to import torch, please install it first") return False # torch-cpu will not have a cuda version if torch.version.cuda is None: cuda_compatible = False - self.warning(f"{self.NAME} cuda is not available from torch") + if verbose: + self.warning(f"{self.NAME} cuda is not available from torch") else: major, minor = torch.version.cuda.split('.')[:2] cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11) if not cuda_compatible: - self.warning(f"{self.NAME} requires CUDA version 10.1+") + if verbose: + self.warning(f"{self.NAME} requires CUDA version 10.1+") TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) - torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5 + torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5) if not torch_compatible: - self.warning(f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}') - + if verbose: + self.warning( + f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}') try: import triton except ImportError: # auto-install of triton is broken on some systems, reverting to manual install for now - # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710 - self.warning(f"please install triton==1.0.0 if you want to use sparse attention") + # see this issue: https://github.com/deepspeedai/DeepSpeed/issues/1710 + if verbose: + self.warning("please install triton==1.0.0 if you want to use sparse attention") return False if pkg_version: @@ -75,7 +81,9 @@ def is_compatible(self, verbose=True): triton_mismatch = installed_triton != "1.0.0" if triton_mismatch: - self.warning(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible") + if verbose: + self.warning( + f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible") return False return super().is_compatible(verbose) and torch_compatible and cuda_compatible diff --git a/op_builder/spatial_inference.py b/op_builder/spatial_inference.py index 59caf57f938d..57714c8c6bf5 100644 --- a/op_builder/spatial_inference.py +++ b/op_builder/spatial_inference.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import os from .builder import CUDAOpBuilder, installed_cuda_version @@ -17,22 +18,25 @@ def __init__(self, name=None): def absolute_name(self): return f'deepspeed.ops.spatial.{self.NAME}_op' - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True - if not self.is_rocm_pytorch() and torch.cuda.is_available(): - sys_cuda_major, _ = installed_cuda_version() - torch_cuda_major = int(torch.version.cuda.split('.')[0]) - cuda_capability = torch.cuda.get_device_properties(0).major - if cuda_capability >= 8: - if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") - cuda_okay = False + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False return super().is_compatible(verbose) and cuda_okay def sources(self): diff --git a/op_builder/transformer.py b/op_builder/transformer.py index 893145d44d94..8db30fdc6791 100644 --- a/op_builder/transformer.py +++ b/op_builder/transformer.py @@ -33,7 +33,4 @@ def sources(self): def include_paths(self): includes = ['csrc/includes'] - if self.is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)] return includes diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index c7b95883cebf..2507ee4ee692 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import os from .builder import CUDAOpBuilder, installed_cuda_version @@ -17,31 +18,35 @@ def __init__(self, name=None): def absolute_name(self): return f'deepspeed.ops.transformer.inference.{self.NAME}_op' - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True - if not self.is_rocm_pytorch() and torch.cuda.is_available(): - sys_cuda_major, _ = installed_cuda_version() - torch_cuda_major = int(torch.version.cuda.split('.')[0]) - cuda_capability = torch.cuda.get_device_properties(0).major - if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") - cuda_okay = False - if cuda_capability >= 8: - if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if not os.environ.get("DS_IGNORE_CUDA_DETECTION"): + if not self.is_rocm_pytorch() and torch.cuda.is_available(): + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major + if cuda_capability < 6: + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False return super().is_compatible(verbose) and cuda_okay def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 6: ccs_retained.append(cc) else: @@ -56,10 +61,12 @@ def sources(self): 'csrc/transformer/inference/csrc/gelu.cu', 'csrc/transformer/inference/csrc/relu.cu', 'csrc/transformer/inference/csrc/layer_norm.cu', + 'csrc/transformer/inference/csrc/rms_norm.cu', 'csrc/transformer/inference/csrc/softmax.cu', 'csrc/transformer/inference/csrc/dequantize.cu', 'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu', 'csrc/transformer/inference/csrc/transform.cu', + 'csrc/transformer/inference/csrc/pointwise_ops.cu', ] def extra_ldflags(self): @@ -70,3 +77,14 @@ def extra_ldflags(self): def include_paths(self): return ['csrc/transformer/inference/includes', 'csrc/includes'] + + def nvcc_args(self): + args = super().nvcc_args() + """BF16 is supported on AMD, but including `cuda_bf16.h` (`` after hipification) + in host-only translation units (*.cpp files) fails because GPU-specific builtins are pulled in with the BF16 type. + This cannot be avoided via forward declarations for this transformer_inference extension, + since `pt_binding.cpp` code explicitly requires the BF16 header, so disable it for now. + """ + if self.is_rocm_pytorch(): + self.enable_bf16 = False + return args diff --git a/op_builder/utils.py b/op_builder/utils.py index c10953397775..927ff1b361bd 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -17,4 +17,4 @@ def absolute_name(self): return f'deepspeed.ops.{self.NAME}_op' def sources(self): - return ['csrc/utils/flatten_unflatten.cpp'] + return ['csrc/utils/tensor_cast.cpp', 'csrc/utils/py_ds_utils.cpp'] diff --git a/op_builder/xpu/__init__.py b/op_builder/xpu/__init__.py new file mode 100755 index 000000000000..a91775a3bb13 --- /dev/null +++ b/op_builder/xpu/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cpu_adam import CPUAdamBuilder +from .cpu_adagrad import CPUAdagradBuilder +from .fused_adam import FusedAdamBuilder +from .async_io import AsyncIOBuilder +from .flash_attn import FlashAttentionBuilder +from .no_impl import NotImplementedBuilder +from .packbits import PackbitsBuilder diff --git a/op_builder/xpu/async_io.py b/op_builder/xpu/async_io.py new file mode 100644 index 000000000000..8ec030880368 --- /dev/null +++ b/op_builder/xpu/async_io.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import shutil +import subprocess + +from .builder import OpBuilder + + +class AsyncIOBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_AIO" + NAME = "async_io" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.aio.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/aio/py_lib/deepspeed_py_copy.cpp', + 'csrc/aio/py_lib/py_ds_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio.cpp', + 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', + 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', + 'csrc/aio/common/deepspeed_aio_utils.cpp', + 'csrc/aio/common/deepspeed_aio_common.cpp', + 'csrc/aio/common/deepspeed_aio_types.cpp', + 'csrc/aio/py_lib/deepspeed_pin_tensor.cpp', + 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', + 'csrc/aio/py_lib/deepspeed_cpu_op.cpp', + 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', + ] + + def include_paths(self): + return ['csrc/aio/py_lib', 'csrc/aio/common'] + + def cxx_args(self): + import torch + # -O0 for improved debugging, since performance is bound by I/O + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) + if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1: + CPP_STD = '-std=c++17' + else: + CPP_STD = '-std=c++14' + return [ + '-g', + '-Wall', + '-O0', + CPP_STD, + '-shared', + '-fPIC', + '-Wno-reorder', + CPU_ARCH, + '-fopenmp', + SIMD_WIDTH, + '-laio', + ] + + def extra_ldflags(self): + return ['-laio'] + + def check_for_libaio_pkg(self): + libs = dict( + dpkg=["-l", "libaio-dev", "apt"], + pacman=["-Q", "libaio", "pacman"], + rpm=["-q", "libaio-devel", "yum"], + ) + + found = False + for pkgmgr, data in libs.items(): + flag, lib, tool = data + path = shutil.which(pkgmgr) + if path is not None: + cmd = [pkgmgr, flag, lib] + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.wait() == 0: + found = True + else: + self.warning(f"{self.NAME}: please install the {lib} package with {tool}") + break + return found + + def is_compatible(self, verbose=False): + # Check for the existence of libaio by using distutils + # to compile and link a test program that calls io_submit, + # which is a function provided by libaio that is used in the async_io op. + # If needed, one can define -I and -L entries in CFLAGS and LDFLAGS + # respectively to specify the directories for libaio.h and libaio.so. + aio_compatible = self.has_function('io_pgetevents', ('aio', )) + if verbose and not aio_compatible: + self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.") + + # Check for the libaio package via known package managers + # to print suggestions on which package to install. + self.check_for_libaio_pkg() + + self.warning( + "If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found." + ) + return super().is_compatible(verbose) and aio_compatible diff --git a/op_builder/xpu/builder.py b/op_builder/xpu/builder.py new file mode 100644 index 000000000000..012a9840ef26 --- /dev/null +++ b/op_builder/xpu/builder.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import time +import importlib + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class SYCLOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import SyclExtension + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + print("sycl sources = {}".format(self.sources())) + sycl_ext = SyclExtension(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + extra_compile_args={ + 'cxx': self.strip_empty_entries(self.cxx_args()), + }, + extra_link_args=self.strip_empty_entries(self.fixed_aotflags())) + return sycl_ext + + def version_dependent_macros(self): + try: + from op_builder.builder import TORCH_MAJOR, TORCH_MINOR + except ImportError: + from deepspeed.ops.op_builder.builder import TORCH_MAJOR, TORCH_MINOR + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + + def _sycl_env_paths(self): + """Find the SYCL include and lib directories from the Python environment. + + When using PyTorch XPU wheels, libsycl.so and SYCL headers are + installed into the Python environment (e.g. conda env). The system + ``icpx`` compiler ships its own (potentially newer) SYCL headers and + runtime. To avoid ABI mismatches we must compile and link against the + *same* SYCL version that PyTorch was built with. + + Returns (include_dir, lib_dir) – either or both may be ``None`` when + the paths do not exist. + """ + import sys + prefix = sys.prefix # e.g. /home/user/miniforge3/envs/myenv + inc = os.path.join(prefix, 'include') + lib = os.path.join(prefix, 'lib') + sycl_inc = inc if os.path.isdir(os.path.join(inc, 'sycl')) else None + sycl_lib = lib if os.path.isfile(os.path.join(lib, 'libsycl.so')) else None + return sycl_inc, sycl_lib + + def cxx_args(self): + cxx_flags = [ + '-fsycl', + '-fsycl-targets=spir64', + '-g', + '-gdwarf-4', + '-O3', + '-std=c++17', + '-fPIC', + '-DMKL_ILP64', + '-fno-strict-aliasing', + ] + # Use SYCL headers from the Python environment so that compiled code + # references symbols present in the *environment's* libsycl.so rather + # than the (possibly newer) system oneAPI installation. + sycl_inc, _ = self._sycl_env_paths() + if sycl_inc: + cxx_flags = [f'-isystem', sycl_inc] + cxx_flags + if os.environ.get('USE_MKL_GEMM'): + cxx_flags.append('-DUSE_MKL_GEMM') + return cxx_flags + + def extra_ldflags(self): + import torch + torch_lib_dir = os.path.join(os.path.dirname(torch.__file__), 'lib') + flags = [ + '-fPIC', + '-fsycl', + '-fsycl-targets=spir64', + '-Xs "-options -cl-intel-enable-auto-large-GRF-mode"', + '-fsycl-max-parallel-link-jobs=8', + '-Wl,-export-dynamic', + f'-L{torch_lib_dir}', + f'-Wl,-rpath,{torch_lib_dir}', + ] + # Link against the Python environment's libsycl.so to match the + # headers we compiled against (see cxx_args). + _, sycl_lib = self._sycl_env_paths() + if sycl_lib: + flags = [f'-L{sycl_lib}', f'-Wl,-rpath,{sycl_lib}'] + flags + return flags + + def fixed_aotflags(self): + return [ + '-fsycl', '-fsycl-targets=spir64', '-fsycl-max-parallel-link-jobs=8', + '-Xs "-options -cl-intel-enable-auto-large-GRF-mode"' + ] + + def load(self, verbose=True): + from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name # noqa: F401 + from deepspeed.accelerator import get_accelerator + if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name: + return importlib.import_module(self.absolute_name()) + else: + return self.jit_load(verbose) + + def jit_load(self, verbose=True): + if not self.is_compatible(verbose): + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" + ) + from torch.utils.cpp_extension import verify_ninja_availability + try: + verify_ninja_availability() + except RuntimeError as e: + raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") from e + + self.jit_mode = True + from torch.utils.cpp_extension import load + + start_build = time.time() + # Recognize relative paths as absolute paths for jit load + + sources = [self.deepspeed_src_path(path) for path in self.sources()] + extra_include_paths = [self.deepspeed_src_path(path) for path in self.include_paths()] + + # Set CXX to icpx (Intel oneAPI DPC++ compiler) so that .cpp/.dp.cpp + # files containing SYCL code are compiled with the SYCL-aware compiler. + # PyTorch's cpp_extension only routes .sycl files to icpx by default. + saved_env = {} + for var in ('CXX', 'LIBRARY_PATH', 'CPATH'): + saved_env[var] = os.environ.get(var) + os.environ['CXX'] = 'icpx' + + # Point icpx at the Python environment's SYCL headers and libraries so + # the compiled extension uses the same SYCL ABI as PyTorch. + sycl_inc, sycl_lib = self._sycl_env_paths() + if sycl_lib: + lib_path = os.environ.get('LIBRARY_PATH', '') + os.environ['LIBRARY_PATH'] = f'{sycl_lib}:{lib_path}' if lib_path else sycl_lib + if sycl_inc: + cpath = os.environ.get('CPATH', '') + os.environ['CPATH'] = f'{sycl_inc}:{cpath}' if cpath else sycl_inc + + try: + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=self.strip_empty_entries(self.cxx_args()), + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + verbose=verbose) + finally: + # Restore original environment + for var, val in saved_env.items(): + if val is None: + os.environ.pop(var, None) + else: + os.environ[var] = val + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + return op_module diff --git a/op_builder/xpu/cpu_adagrad.py b/op_builder/xpu/cpu_adagrad.py new file mode 100644 index 000000000000..18f80848e1b8 --- /dev/null +++ b/op_builder/xpu/cpu_adagrad.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import SYCLOpBuilder + + +class CPUAdagradBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAGRAD" + NAME = "cpu_adagrad" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adagrad.{self.NAME}_op' + + def sources(self): + return ['csrc/xpu/adagrad/cpu_adagrad.cpp', 'csrc/xpu/common/custom_cuda_kernel.dp.cpp'] + + def include_paths(self): + return ['csrc/xpu/includes'] diff --git a/op_builder/xpu/cpu_adam.py b/op_builder/xpu/cpu_adam.py new file mode 100644 index 000000000000..9d5fdcd3e8ab --- /dev/null +++ b/op_builder/xpu/cpu_adam.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import SYCLOpBuilder + + +class CPUAdamBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/xpu/flash_attn.py b/op_builder/xpu/flash_attn.py new file mode 100644 index 000000000000..88a639c62a40 --- /dev/null +++ b/op_builder/xpu/flash_attn.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from .builder import SYCLOpBuilder + + +class FlashAttentionBuilderObject(): + + def __init__(self): + pass + + # general functions + def flash_attn_func_v2(self, q, k, v, dropout_p, softmax_scale, is_causal): + import torch + return torch.nn.functional.scaled_dot_product_attention(q, + k, + v, + dropout_p=dropout_p, + is_causal=is_causal, + scale=softmax_scale) + + +class FlashAttentionBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_FlashAttention" + NAME = "flash_attn" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def sources(self): + return + + def include_paths(self): + return [] + + def extra_ldflags(self): + return [] + + def cxx_args(self): + return [] + + def load(self): + return FlashAttentionBuilderObject() diff --git a/op_builder/xpu/fused_adam.py b/op_builder/xpu/fused_adam.py new file mode 100644 index 000000000000..0e0f1a66f8e6 --- /dev/null +++ b/op_builder/xpu/fused_adam.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from .builder import SYCLOpBuilder + + +class FusedAdamBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/xpu/adam/fused_adam_frontend.cpp', 'csrc/xpu/adam/multi_tensor_adam.dp.cpp'] + + def include_paths(self): + return ['csrc/xpu/includes', 'csrc/xpu/adam'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() diff --git a/op_builder/xpu/no_impl.py b/op_builder/xpu/no_impl.py new file mode 100644 index 000000000000..8b294f70c279 --- /dev/null +++ b/op_builder/xpu/no_impl.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import SYCLOpBuilder + + +class NotImplementedBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on XPU backend.") + + def sources(self): + return [] + + def cxx_args(self): + return [] + + def extra_ldflags(self): + return [] + + def include_paths(self): + return [] diff --git a/op_builder/xpu/packbits.py b/op_builder/xpu/packbits.py new file mode 100644 index 000000000000..cf5b5ebc59e4 --- /dev/null +++ b/op_builder/xpu/packbits.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from .builder import SYCLOpBuilder + + +class PackbitsBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_PACK_BITS" + NAME = "pack_bits" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def sources(self): + return ['csrc/xpu/packbits/packing.cpp'] + + def include_paths(self): + return ['csrc/xpu/includes'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() diff --git a/release/bump_patch_version.py b/release/bump_patch_version.py index 84cb45a8eac8..20827011d368 100644 --- a/release/bump_patch_version.py +++ b/release/bump_patch_version.py @@ -3,12 +3,20 @@ # DeepSpeed Team +import argparse from packaging import version as pkg_version -with open('../version.txt') as fd: - version = pkg_version.parse(fd.read()) +parser = argparse.ArgumentParser() -with open('../version.txt', 'w') as fd: - fd.write(f'{version.major}.{version.minor}.{version.micro + 1}\n') +parser.add_argument("--current_version", + type=str, + help="The current version being published to help set the next version.") -print(f'{version} -> {version.major}.{version.minor}.{version.micro + 1}') +args = parser.parse_args() + +current_version = pkg_version.parse(args.current_version) + +with open('./version.txt', 'w') as fd: + fd.write(f'{current_version.major}.{current_version.minor}.{current_version.micro + 1}\n') + +print(f'{current_version} -> {current_version.major}.{current_version.minor}.{current_version.micro + 1}') diff --git a/release/check_release_version.py b/release/check_release_version.py new file mode 100644 index 000000000000..148fa8aa3c42 --- /dev/null +++ b/release/check_release_version.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +from packaging import version as pkg_version + +parser = argparse.ArgumentParser() + +parser.add_argument("--release_version", type=str, help="The new version being published.") + +args = parser.parse_args() + +release_version = pkg_version.parse(args.release_version) + +with open('./version.txt') as fd: + repo_version = pkg_version.parse(fd.read()) + +assert repo_version == release_version, f"{repo_version=} does not match {release_version=}, unable to proceed" diff --git a/release/release.sh b/release/release.sh index 1366532e8b06..cc3ee2feae62 100644 --- a/release/release.sh +++ b/release/release.sh @@ -25,13 +25,20 @@ if [ "${version}" != `cat version.txt` ]; then exit 1 fi +echo "checking that the version is valid" +python release/check_release_version.py --release_version ${version} +if [ $? != 0 ]; then + echo 'please check the version number selected' + exit 1 +fi + python -c "import twine" if [ $? != 0 ]; then echo 'please install twine via pip' exit 1 fi -DS_BUILD_STRING="" python setup.py sdist +DS_BUILD_STRING="" python -m build --sdist if [ ! -f dist/deepspeed-${version}.tar.gz ]; then echo "prepared version does not match version given ($version), bump version first?" @@ -45,5 +52,4 @@ git tag v${version} git push origin v${version} echo "bumping up patch version" -cd - -python bump_patch_version.py +python release/bump_patch_version.py --current_version ${version} diff --git a/requirements/requirements-cpu.txt b/requirements/requirements-cpu.txt new file mode 100755 index 000000000000..e69de29bb2d1 diff --git a/requirements/requirements-deepcompile.txt b/requirements/requirements-deepcompile.txt new file mode 100644 index 000000000000..9a635b910d93 --- /dev/null +++ b/requirements/requirements-deepcompile.txt @@ -0,0 +1 @@ +scipy diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 6192949b2148..10a24fc697e9 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,17 +1,21 @@ -clang-format>=14.0.6 +accelerate +clang-format==18.1.3 +comet_ml>=3.41.0 +deepspeed-kernels ; sys_platform == 'linux' docutils<0.18 future importlib-metadata>=4 -megatron-lm==1.1.5 -pre-commit>=2.20.0 -pytest +mup +pre-commit>=3.2.0 +pytest>=7.2.0,<8.4.0 pytest-forked pytest-randomly pytest-xdist +qtorch==0.3.0 recommonmark sphinx sphinx-rtd-theme tensorboard torchvision -transformers +transformers>=4.51.3 wandb diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt index 848a7f7a485d..b7fd13787e8b 100644 --- a/requirements/requirements-inf.txt +++ b/requirements/requirements-inf.txt @@ -1,5 +1,7 @@ google lm-eval==0.3.0 protobuf -transformers -transformers[sentencepiece] +qtorch +safetensors +sentencepiece +transformers>=4.32.1 diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index fcd0ec5a9a6a..aaac814354c4 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -1,9 +1,11 @@ -autodoc_pydantic +autodoc_pydantic>=2.0.0 docutils<0.18 hjson packaging psutil py-cpuinfo -pydantic -torch +pydantic>=2.0.0 +recommonmark +sphinx_rtd_theme +torch>=2.0.0 tqdm diff --git a/requirements/requirements-sd.txt b/requirements/requirements-sd.txt index 7b988876f54d..0b2ce8c2b56f 100644 --- a/requirements/requirements-sd.txt +++ b/requirements/requirements-sd.txt @@ -1,2 +1,2 @@ -diffusers -triton==2.0.0.dev20221202 +diffusers>=0.25.0 +triton>=2.1.0 diff --git a/requirements/requirements-sparse_pruning.txt b/requirements/requirements-sparse_pruning.txt new file mode 100755 index 000000000000..3b96b4134cdb --- /dev/null +++ b/requirements/requirements-sparse_pruning.txt @@ -0,0 +1 @@ +neural-compressor==2.1.0 diff --git a/requirements/requirements-triton.txt b/requirements/requirements-triton.txt new file mode 100644 index 000000000000..3b382f83f2ae --- /dev/null +++ b/requirements/requirements-triton.txt @@ -0,0 +1 @@ +triton==2.1.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 6840d6dbcc98..1bbd21dd5e32 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,9 +1,11 @@ +einops hjson +msgpack ninja numpy packaging>=20.0 psutil py-cpuinfo -pydantic -torch +pydantic>=2.0.0 +torch>=2.0.0 tqdm diff --git a/scripts/check-extraindexurl.py b/scripts/check-extraindexurl.py new file mode 100755 index 000000000000..017939af95ac --- /dev/null +++ b/scripts/check-extraindexurl.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from __future__ import annotations +'''Copyright The Microsoft DeepSpeed Team''' +""" +Checks each file in sys.argv for the string "--extra-index-url". +Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py +""" + +import subprocess +import sys + + +def err(s: str) -> None: + print(s, file=sys.stderr) + + +print(*sys.argv[1:]) + +# There are many ways we could search for the string "--extra-index-url", but `git +# grep --no-index` is nice because +# - it's very fast (as compared to iterating over the file in Python) +# - we can reasonably assume it's available on all machines +# - unlike plain grep, which is slower and has different flags on MacOS versus +# Linux, git grep is always the same. +excluded_file = ".github/workflows/xpu-max1100.yml" +res = subprocess.run( + [ + "git", "grep", "-Hn", "--no-index", "-e", r"--extra-index-url", "--", f":(exclude){excluded_file}", + *sys.argv[1:] + ], + capture_output=True, +) +if res.returncode == 0: + err('Error: The string "--extra-index-url" was found.\nPlease replace all calls to --extra-index-url with "--index-url"' + ) + err(res.stdout.decode("utf-8")) + sys.exit(1) +elif res.returncode == 2: + err(f"Error invoking grep on {', '.join(sys.argv[1:])}:") + err(res.stderr.decode("utf-8")) + sys.exit(2) diff --git a/scripts/check-license.py b/scripts/check-license.py index 67caa30a3e3f..0d0e1e578faa 100755 --- a/scripts/check-license.py +++ b/scripts/check-license.py @@ -19,20 +19,24 @@ def err(s: str) -> None: COPYRIGHT = [ - r"^\(\/\/\|#\) Copyright (c) Microsoft Corporation.$", r"^\(\/\/\|#\) SPDX-License-Identifier: Apache-2.0$", - r"^\(\/\/\|#\) DeepSpeed Team$" + # (r"^# Copyright (c) Microsoft Corporation.$", r"^\/\/ Copyright (c) Microsoft Corporation.$"), + (r"^# SPDX-License-Identifier: Apache-2.0$", r"^\/\/ SPDX-License-Identifier: Apache-2.0$"), + (r"^# DeepSpeed Team$", r"^\/\/ DeepSpeed Team$"), ] success = True failures = [] for f in sys.argv[1:]: for copyright_line in COPYRIGHT: - if not success: - break - res = subprocess.run(["git", "grep", "--quiet", "-e", copyright_line, f], capture_output=True) + cmd = ["git", "grep", "--quiet"] + for line in copyright_line: + cmd.extend(["-e", line]) + cmd.append(f) + res = subprocess.run(cmd, capture_output=True) if res.returncode == 1: success = False failures.append(f) + break elif res.returncode == 2: err(f"Error invoking grep on {', '.join(sys.argv[1:])}:") err(res.stderr.decode("utf-8")) diff --git a/scripts/check-torchcuda.py b/scripts/check-torchcuda.py index 04207173e227..639d11ad54ca 100755 --- a/scripts/check-torchcuda.py +++ b/scripts/check-torchcuda.py @@ -19,6 +19,8 @@ def err(s: str) -> None: print(s, file=sys.stderr) +print(*sys.argv[1:]) + # There are many ways we could search for the string "torch.cuda", but `git # grep --no-index` is nice because # - it's very fast (as compared to iterating over the file in Python) @@ -52,3 +54,27 @@ def err(s: str) -> None: err(f"Error invoking grep on {', '.join(sys.argv[1:])}:") err(res.stderr.decode("utf-8")) sys.exit(2) + +files = [] +for file in sys.argv[1:]: + if file.endswith(".py"): + files.append(file) + +if len(files) > 0: + res = subprocess.run( + ["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files], + capture_output=True, + ) + if res.returncode == 0: + err(''' +Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor. + Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)", + and add the following import line: + 'from deepspeed.accelerator import get_accelerator' +''') + err(res.stdout.decode("utf-8")) + sys.exit(1) + elif res.returncode == 2: + err(f"Error invoking grep on {', '.join(files)}:") + err(res.stderr.decode("utf-8")) + sys.exit(2) diff --git a/scripts/check-torchdist.py b/scripts/check-torchdist.py index f0328aca6469..734c449a1135 100755 --- a/scripts/check-torchdist.py +++ b/scripts/check-torchdist.py @@ -25,8 +25,9 @@ def err(s: str) -> None: # - we can reasonably assume it's available on all machines # - unlike plain grep, which is slower and has different flags on MacOS versus # Linux, git grep is always the same. +# allowing `torch.distributed.nn` res = subprocess.run( - ["git", "grep", "-Hn", "--no-index", r"torch\.distributed", *sys.argv[1:]], + ["git", "grep", "-Hn", "--no-index", "-P", r"torch\.distributed |torch\.distributed(?!\.nn)", *sys.argv[1:]], capture_output=True, ) if res.returncode == 0: diff --git a/scripts/replace_copyright.py b/scripts/replace_copyright.py index c0697509d29b..03a8c63f9abc 100644 --- a/scripts/replace_copyright.py +++ b/scripts/replace_copyright.py @@ -115,7 +115,7 @@ def get_header_c(fp): # multiline comment not closed on same line in_multiline = True elif l.endswith(C_ML_CLOSE): - # Ended a multline comment + # Ended a multiline comment in_multiline = False elif not in_multiline or l.startswith(C_SL_COMMENT) or l.isspace(): # Not in a comment diff --git a/setup.py b/setup.py index 1c7c29663700..67befe3da31b 100755 --- a/setup.py +++ b/setup.py @@ -5,8 +5,8 @@ """ DeepSpeed library -To build wheel on Windows: -1. Install pytorch, such as pytorch 1.12 + cuda 11.6. +To build wheels on Windows: +1. Install pytorch, such as pytorch 2.3 + cuda 12.1. 2. Install visual cpp build tool. 3. Include cuda toolkit. 4. Launch cmd console with Administrator privilege for creating required symlink folders. @@ -18,12 +18,16 @@ The wheel will be located at: dist/*.whl """ +import pathlib import os +import shutil import sys import subprocess from setuptools import setup, find_packages from setuptools.command import egg_info import time +import typing +import shlex torch_available = True try: @@ -34,9 +38,11 @@ 'Please visit https://pytorch.org/ to see how to properly install torch on your system.') from op_builder import get_default_compute_capabilities, OpBuilder -from op_builder.all_ops import ALL_OPS +from op_builder.all_ops import ALL_OPS, accelerator_name from op_builder.builder import installed_cuda_version +from accelerator import get_accelerator + # Fetch rocm state. is_rocm_pytorch = OpBuilder.is_rocm_pytorch() rocm_version = OpBuilder.installed_rocm_version() @@ -56,6 +62,22 @@ def fetch_requirements(path): return [r.strip() for r in fd.readlines()] +def is_env_set(key): + """ + Checks if an environment variable is set and not "". + """ + return bool(os.environ.get(key, None)) + + +def get_env_if_set(key, default: typing.Any = ""): + """ + Returns an environment variable if it is set and not "", + otherwise returns a default value. In contrast, the fallback + parameter of os.environ.get() is skipped if the variable is set to "". + """ + return os.environ.get(key, None) or default + + install_requires = fetch_requirements('requirements/requirements.txt') extras_require = { '1bit': [], # add cupy based on cuda/rocm version @@ -65,20 +87,32 @@ def fetch_requirements(path): 'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'), 'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'), 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'), + 'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'), 'inf': fetch_requirements('requirements/requirements-inf.txt'), - 'sd': fetch_requirements('requirements/requirements-sd.txt') + 'sd': fetch_requirements('requirements/requirements-sd.txt'), + 'triton': fetch_requirements('requirements/requirements-triton.txt'), + 'deepcompile': fetch_requirements('requirements/requirements-deepcompile.txt'), } +# Only install pynvml on nvidia gpus. +if torch_available and get_accelerator().device_name() == 'cuda' and not is_rocm_pytorch: + install_requires.append('nvidia-ml-py') + # Add specific cupy version to both onebit extension variants. -if torch_available and torch.cuda.is_available(): +if torch_available and get_accelerator().device_name() == 'cuda': cupy = None if is_rocm_pytorch: rocm_major, rocm_minor = rocm_version - # XXX cupy support for rocm 5 is not available yet. - if rocm_major <= 4: + # cupy support for rocm>5.0 is not available yet. + if (rocm_major == 5 and rocm_minor == 0) or rocm_major <= 4: cupy = f"cupy-rocm-{rocm_major}-{rocm_minor}" else: - cupy = f"cupy-cuda{''.join(map(str,installed_cuda_version()))}" + cuda_major_ver, cuda_minor_ver = installed_cuda_version() + if (cuda_major_ver < 11) or ((cuda_major_ver == 11) and (cuda_minor_ver < 3)): + cupy = f"cupy-cuda{cuda_major_ver}{cuda_minor_ver}" + else: + cupy = f"cupy-cuda{cuda_major_ver}x" + if cupy: extras_require['1bit'].append(cupy) extras_require['1bit_mpi'].append(cupy) @@ -94,8 +128,8 @@ def fetch_requirements(path): # For any pre-installed ops force disable ninja. if torch_available: - from accelerator import get_accelerator - cmdclass['build_ext'] = get_accelerator().build_extension().with_options(use_ninja=False) + use_ninja = is_env_set("DS_ENABLE_NINJA") + cmdclass['build_ext'] = get_accelerator().build_extension().with_options(use_ninja=use_ninja) if torch_available: TORCH_MAJOR = torch.__version__.split('.')[0] @@ -104,19 +138,19 @@ def fetch_requirements(path): TORCH_MAJOR = "0" TORCH_MINOR = "0" -if torch_available and not torch.cuda.is_available(): +if torch_available and not get_accelerator().device_name() == 'cuda': # Fix to allow docker builds, similar to https://github.com/NVIDIA/apex/issues/486. print("[WARNING] Torch did not find cuda available, if cross-compiling or running with cpu only " "you can ignore this message. Adding compute capability for Pascal, Volta, and Turing " "(compute capabilities 6.0, 6.1, 6.2)") - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + if not is_env_set("TORCH_CUDA_ARCH_LIST"): os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capabilities() ext_modules = [] # Default to pre-install kernels to false so we rely on JIT on Linux, opposite on Windows. BUILD_OP_PLATFORM = 1 if sys.platform == "win32" else 0 -BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', BUILD_OP_PLATFORM)) +BUILD_OP_DEFAULT = int(get_env_if_set('DS_BUILD_OPS', BUILD_OP_PLATFORM)) print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}") if BUILD_OP_DEFAULT: @@ -125,10 +159,12 @@ def fetch_requirements(path): def command_exists(cmd): if sys.platform == "win32": - result = subprocess.Popen(f'{cmd}', stdout=subprocess.PIPE, shell=True) + safe_cmd = shlex.split(f'{cmd}') + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) return result.wait() == 1 else: - result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + safe_cmd = shlex.split(f"bash -c type {cmd}") + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) return result.wait() == 0 @@ -140,21 +176,19 @@ def op_envvar(op_name): def op_enabled(op_name): env_var = op_envvar(op_name) - return int(os.environ.get(env_var, BUILD_OP_DEFAULT)) + return int(get_env_if_set(env_var, BUILD_OP_DEFAULT)) -compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) install_ops = dict.fromkeys(ALL_OPS.keys(), False) for op_name, builder in ALL_OPS.items(): op_compatible = builder.is_compatible() - compatible_ops[op_name] = op_compatible # If op is requested but not available, throw an error. if op_enabled(op_name) and not op_compatible: env_var = op_envvar(op_name) - if env_var not in os.environ: - builder.warning(f"One can disable {op_name} with {env_var}=0") - abort(f"Unable to pre-compile {op_name}") + if not is_env_set(env_var): + builder.warning(f"Skip pre-compile of incompatible {op_name}; One can disable {op_name} with {env_var}=0") + continue # If op is compatible but install is not enabled (JIT mode). if is_rocm_pytorch and op_compatible and not op_enabled(op_name): @@ -169,13 +203,17 @@ def op_enabled(op_name): print(f'Install Ops={install_ops}') # Write out version/git info. -git_hash_cmd = "git rev-parse --short HEAD" -git_branch_cmd = "git rev-parse --abbrev-ref HEAD" -if command_exists('git') and 'DS_BUILD_STRING' not in os.environ: +if sys.platform == "win32": + git_hash_cmd = shlex.split("git rev-parse --short HEAD") + git_branch_cmd = shlex.split("git rev-parse --abbrev-ref HEAD") +else: + git_hash_cmd = shlex.split("bash -c \"git rev-parse --short HEAD\"") + git_branch_cmd = shlex.split("bash -c \"git rev-parse --abbrev-ref HEAD\"") +if command_exists('git') and not is_env_set('DS_BUILD_STRING'): try: - result = subprocess.check_output(git_hash_cmd, shell=True) + result = subprocess.check_output(git_hash_cmd) git_hash = result.decode('utf-8').strip() - result = subprocess.check_output(git_branch_cmd, shell=True) + result = subprocess.check_output(git_branch_cmd) git_branch = result.decode('utf-8').strip() except subprocess.CalledProcessError: git_hash = "unknown" @@ -184,35 +222,30 @@ def op_enabled(op_name): git_hash = "unknown" git_branch = "unknown" - -def create_dir_symlink(src, dest): - if not os.path.islink(dest): - if os.path.exists(dest): - os.remove(dest) - assert not os.path.exists(dest) - os.symlink(src, dest) - - if sys.platform == "win32": - # This creates a symbolic links on Windows. - # It needs Administrator privilege to create symlinks on Windows. - create_dir_symlink('..\\..\\csrc', '.\\deepspeed\\ops\\csrc') - create_dir_symlink('..\\..\\op_builder', '.\\deepspeed\\ops\\op_builder') - create_dir_symlink('..\\accelerator', '.\\deepspeed\\accelerator') + shutil.rmtree('.\\deepspeed\\ops\\csrc', ignore_errors=True) + pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True) + shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True) + shutil.rmtree('.\\deepspeed\\ops\\op_builder', ignore_errors=True) + pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True) + shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True) + shutil.rmtree('.\\deepspeed\\accelerator', ignore_errors=True) + pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True) + shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True) egg_info.manifest_maker.template = 'MANIFEST_win.in' # Parse the DeepSpeed version string from version.txt. version_str = open('version.txt', 'r').read().strip() # Build specifiers like .devX can be added at install time. Otherwise, add the git hash. -# Example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel. +# Example: `DS_BUILD_STRING=".dev20201022" python -m build --no-isolation`. # Building wheel for distribution, update version file. -if 'DS_BUILD_STRING' in os.environ: +if is_env_set('DS_BUILD_STRING'): # Build string env specified, probably building for distribution. with open('build.txt', 'w') as fd: - fd.write(os.environ.get('DS_BUILD_STRING')) - version_str += os.environ.get('DS_BUILD_STRING') + fd.write(os.environ['DS_BUILD_STRING']) + version_str += os.environ['DS_BUILD_STRING'] elif os.path.isfile('build.txt'): # build.txt exists, probably installing from distribution. with open('build.txt', 'r') as fd: @@ -254,11 +287,10 @@ def create_dir_symlink(src, dest): fd.write(f"git_hash='{git_hash}'\n") fd.write(f"git_branch='{git_branch}'\n") fd.write(f"installed_ops={install_ops}\n") - fd.write(f"compatible_ops={compatible_ops}\n") + fd.write(f"accelerator_name='{accelerator_name}'\n") fd.write(f"torch_info={torch_info}\n") print(f'install_requires={install_requires}') -print(f'compatible_ops={compatible_ops}') print(f'ext_modules={ext_modules}') # Parse README.md to make long_description for PyPI page. @@ -266,6 +298,14 @@ def create_dir_symlink(src, dest): with open(os.path.join(thisdir, 'README.md'), encoding='utf-8') as fin: readme_text = fin.read() +if sys.platform == "win32": + scripts = ['bin/deepspeed.bat', 'bin/ds', 'bin/ds_report.bat', 'bin/ds_report'] +else: + scripts = [ + 'bin/deepspeed', 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', 'bin/ds_report', 'bin/ds_bench', 'bin/dsr', + 'bin/ds_elastic', 'bin/ds_nvme_tune', 'bin/ds_io' + ] + start_time = time.time() setup(name='deepspeed', @@ -274,26 +314,23 @@ def create_dir_symlink(src, dest): long_description=readme_text, long_description_content_type='text/markdown', author='DeepSpeed Team', - author_email='deepspeed-info@microsoft.com', + author_email='info@deepspeedai.com', url='http://deepspeed.ai', project_urls={ 'Documentation': 'https://deepspeed.readthedocs.io', - 'Source': 'https://github.com/microsoft/DeepSpeed', + 'Source': 'https://github.com/deepspeedai/DeepSpeed', }, install_requires=install_requires, extras_require=extras_require, packages=find_packages(include=['deepspeed', 'deepspeed.*']), include_package_data=True, - scripts=[ - 'bin/deepspeed', 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', 'bin/ds_report', 'bin/ds_bench', 'bin/dsr', - 'bin/ds_elastic' - ], + scripts=scripts, classifiers=[ - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10' + 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12' ], - license='MIT', + license='Apache Software License 2.0', ext_modules=ext_modules, cmdclass=cmdclass) diff --git a/tests/.coveragerc b/tests/.coveragerc new file mode 100644 index 000000000000..dccaba6b57a3 --- /dev/null +++ b/tests/.coveragerc @@ -0,0 +1,5 @@ +# .coveragerc to control coverage.py +[run] +parallel = True +sigterm = True +source = deepspeed diff --git a/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py b/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py new file mode 100644 index 000000000000..e242e0a3cd05 --- /dev/null +++ b/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +This script is to test the performance of the DS4Sci_EvoformerAttention op. +To run the script, +1. Clone the CUTLASS repo. E.g. git clone https://github.com/NVIDIA/cutlass.git +2. DeepSpeed will detect a local or installed CUTLASS. If needed, set CUTLASS_PATH explicitly. +3. Run the script. E.g. python DS4Sci_EvoformerAttention_bench.py +""" + +import contextlib +import torch +from typing import List +from torch.nn import functional as F +from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +from deepspeed.accelerator import get_accelerator + + +def attention_reference( + q_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + k_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + v_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + biases: List[torch.Tensor], + sm_scale: float) -> torch.Tensor: + # Original shape: [*, Dim_Q, H, C_hid] -> Transpose to: [*, H, Dim_Q, C_hid] + q = q_input.transpose(-2, -3) + k = k_input.transpose(-2, -3) + v = v_input.transpose(-2, -3) + + # Now, q, k, v are in shape: [*, H, Dim_Q, C_hid] + + # Transpose k to shape [*, H, C_hid, Dim_Q] + k_t = k.transpose(-1, -2) + + # Now, q and k_t are in shapes: [*, H, Dim_Q, C_hid] and [*, H, C_hid, Dim_Q] respectively + + # [*, H, Dim_Q, Dim_Q] + a = torch.matmul(q, k_t) * sm_scale + + for b in biases: + a += b + + a = F.softmax(a, dim=-1) + + # Now, a is in shape [*, H, Dim_Q, Dim_Q], v is in shape [*, H, Dim_Q, C_hid] + + # Matmul operation results in [*, H, Dim_Q, C_hid] + a_v = torch.matmul(a, v) + + # [*, Dim_Q, H, C_hid] + o = a_v.transpose(-2, -3) + + return o + + +dtype = torch.float16 + +N = 256 +heads = 4 +dim = 32 +seq_len = 256 + + +@contextlib.contextmanager +def cuda_timer(res_list): + start = get_accelerator().Event(enable_timing=True) + end = get_accelerator().Event(enable_timing=True) + start.record() + yield + end.record() + get_accelerator().synchronize() + res_list.append(start.elapsed_time(end)) + + +def benchmark(): + ours_fw = [] + ours_bw = [] + baseline_fw = [] + baseline_bw = [] + for batch in range(1, 17): + Q = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True) + K = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True) + V = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True) + bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=False) + bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True) + # warm up + DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2]) + with cuda_timer(ours_fw): + out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2]) + d_out = torch.rand_like(out) + with cuda_timer(ours_bw): + out.backward(d_out) + # warm up + attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5)) + with cuda_timer(baseline_fw): + ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5)) + with cuda_timer(baseline_bw): + ref_out.backward(d_out) + + print("batch size\tours (FW)\tbaseline (FW)\tours (BW)\tbaseline (BW)") + for i in range(len(ours_fw)): + print(f"{i+1}\t{ours_fw[i]}\t{baseline_fw[i]}\t{ours_bw[i]}\t{baseline_bw[i]}") + + +benchmark() diff --git a/tests/benchmarks/flatten_bench.py b/tests/benchmarks/flatten_bench.py index d404acd5c344..a09600db5fbe 100755 --- a/tests/benchmarks/flatten_bench.py +++ b/tests/benchmarks/flatten_bench.py @@ -110,15 +110,15 @@ def timeme(): def line_profileme(): print("--------------- line_profiler -----------------") print("py") - profile(py)() # noqa: F821 + profile(py)() # noqa: F821 # type: ignore gc.collect() get_accelerator().empty_cache() print("cpp") - profile(cpp)() # noqa: F821 + profile(cpp)() # noqa: F821 # type: ignore gc.collect() get_accelerator().empty_cache() print("apex") - profile(apex)() # noqa: F821 + profile(apex)() # noqa: F821 # type: ignore gc.collect() get_accelerator().empty_cache() diff --git a/tests/benchmarks/unflatten_bench.py b/tests/benchmarks/unflatten_bench.py index dade4574458a..9f2f0e1e87f5 100755 --- a/tests/benchmarks/unflatten_bench.py +++ b/tests/benchmarks/unflatten_bench.py @@ -119,15 +119,15 @@ def timeme(): def line_profileme(): print("--------------- line_profier -----------------") print("py") - profile(py)() # noqa: F821 + profile(py)() # noqa: F821 # type: ignore gc.collect() get_accelerator().empty_cache() print("cpp") - profile(cpp)() # noqa: F821 + profile(cpp)() # noqa: F821 # type: ignore gc.collect() get_accelerator().empty_cache() print("apex") - profile(apex)() # noqa: F821 + profile(apex)() # noqa: F821 # type: ignore gc.collect() get_accelerator().empty_cache() diff --git a/tests/conftest.py b/tests/conftest.py index e5a8cce45fd9..8137dfb74042 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ import sys import pytest import os -from os.path import abspath, dirname, join +from os.path import abspath, dirname import torch import warnings @@ -17,7 +17,7 @@ # allow having multiple repository checkouts and not needing to remember to rerun # 'pip install -e .[dev]' when switching between checkouts and running tests. -git_repo_path = abspath(join(dirname(dirname(__file__)), "src")) +git_repo_path = abspath(dirname(dirname(__file__))) sys.path.insert(1, git_repo_path) @@ -70,10 +70,18 @@ def pytest_runtest_call(item): item.runtest = lambda: True # Dummy function so test is not run twice +# We allow DistributedTest to reuse distributed environments. When the last +# test for a class is run, we want to make sure those distributed environments +# are destroyed. +def pytest_runtest_teardown(item, nextitem): + if getattr(item.cls, "reuse_dist_env", False) and not nextitem: + dist_test_class = item.cls() + for num_procs, pool in dist_test_class._pool_cache.items(): + dist_test_class._close_pool(pool, num_procs, force=True) + + @pytest.hookimpl(tryfirst=True) def pytest_fixture_setup(fixturedef, request): if getattr(fixturedef.func, "is_dist_fixture", False): - #for val in dir(request): - # print(val.upper(), getattr(request, val), "\n") dist_fixture_class = fixturedef.func() dist_fixture_class(request) diff --git a/tests/model/BingBertSquad/BingBertSquad_test_common.py b/tests/model/BingBertSquad/BingBertSquad_test_common.py index ef42f85cc945..b47ddfe0c649 100755 --- a/tests/model/BingBertSquad/BingBertSquad_test_common.py +++ b/tests/model/BingBertSquad/BingBertSquad_test_common.py @@ -7,6 +7,7 @@ import subprocess import os import time +import shlex class BaseTestCase(unittest.TestCase): @@ -40,9 +41,9 @@ def ensure_directory_exists(self, filename): os.makedirs(dirname) def clean_test_env(self): - cmd = "dlts_ssh pkill -9 -f /usr/bin/python" + cmd = shlex.split("dlts_ssh pkill -9 -f /usr/bin/python") print(cmd) - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') + subprocess.run(cmd, check=False, executable='/bin/bash') time.sleep(20) def run_BingBertSquad_test(self, test_config, output): @@ -50,8 +51,8 @@ def run_BingBertSquad_test(self, test_config, output): other_args = " " + test_config["other_args"] if "other_args" in test_config else " " cmd = "./run_BingBertSquad_sanity.sh -e 1 -g {0} {1} {2}".format(test_config["gpus"], other_args, ds_flag) - + cmd = shlex.split(cmd) self.ensure_directory_exists(output) with open(output, "w") as f: print(cmd) - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash', stdout=f, stderr=f) + subprocess.run(cmd, check=False, executable='/bin/bash', stdout=f, stderr=f) diff --git a/tests/model/BingBertSquad/run_BingBertSquad.sh b/tests/model/BingBertSquad/run_BingBertSquad.sh index fcfdf5e66361..4d06bb1230a4 100755 --- a/tests/model/BingBertSquad/run_BingBertSquad.sh +++ b/tests/model/BingBertSquad/run_BingBertSquad.sh @@ -93,7 +93,7 @@ done # Validate path to BingBertSquad script if [ -z "${BingBertSquad_DIR+x}" ]; then - export BingBertSquad_DIR=../../../../DeepSpeedExamples/BingBertSquad + export BingBertSquad_DIR=../../../../DeepSpeedExamples/training/BingBertSquad echo "BingBertSquad_DIR environment variable not set; trying default: ${BingBertSquad_DIR}" fi validate_folder ${BingBertSquad_DIR} "BingBertSquad_DIR" @@ -160,8 +160,11 @@ run_cmd="deepspeed.pt \ --master_port ${master_port} ${BingBertSquad_script} ${other_args} ${squad_args}" -echo ${run_cmd} -eval ${run_cmd} +# Sanitize input before running eval() +safe_cmd=$(printf '%q' "$run_cmd") + +echo ${safe_cmd} +eval ${safe_cmd} set +x diff --git a/tests/model/BingBertSquad/run_BingBertSquad_sanity.sh b/tests/model/BingBertSquad/run_BingBertSquad_sanity.sh index 1b49a37b783f..8b6ad942ba59 100755 --- a/tests/model/BingBertSquad/run_BingBertSquad_sanity.sh +++ b/tests/model/BingBertSquad/run_BingBertSquad_sanity.sh @@ -94,7 +94,7 @@ done # Validate path to BingBertSquad script if [ -z "${BingBertSquad_DIR+x}" ]; then - export BingBertSquad_DIR=../../../DeepSpeedExamples/BingBertSquad + export BingBertSquad_DIR=../../../DeepSpeedExamples/training/BingBertSquad echo "BingBertSquad_DIR environment variable not set; trying default: ${BingBertSquad_DIR}" fi validate_folder ${BingBertSquad_DIR} "BingBertSquad_DIR" diff --git a/tests/model/BingBertSquad/run_tests.sh b/tests/model/BingBertSquad/run_tests.sh index eef93ef98796..2a69fdf01c79 100755 --- a/tests/model/BingBertSquad/run_tests.sh +++ b/tests/model/BingBertSquad/run_tests.sh @@ -31,7 +31,7 @@ validate_folder() { # Validate path to BingBertSquad script if [ -z "${BingBertSquad_DIR+x}" ]; then - export BingBertSquad_DIR=../../../DeepSpeedExamples/BingBertSquad + export BingBertSquad_DIR=../../../DeepSpeedExamples/training/BingBertSquad echo "BingBertSquad_DIR environment variable not set; trying default: ${BingBertSquad_DIR}" fi validate_folder ${BingBertSquad_DIR} "BingBertSquad_DIR" diff --git a/tests/model/BingBertSquad/test_e2e_squad.py b/tests/model/BingBertSquad/test_e2e_squad.py index 9312dc67a193..9f03b89d0829 100644 --- a/tests/model/BingBertSquad/test_e2e_squad.py +++ b/tests/model/BingBertSquad/test_e2e_squad.py @@ -10,11 +10,11 @@ import pytest import json -sys.path.append("../../../DeepSpeedExamples/BingBertSquad") +sys.path.append("../../../DeepSpeedExamples/training/BingBertSquad") import evaluate as eval squad_dir = "/data/BingBertSquad" -base_dir = "../../../DeepSpeedExamples/BingBertSquad" +base_dir = "../../../DeepSpeedExamples/training/BingBertSquad" script_file_name = "run_squad_deepspeed.sh" model_file_name = "training_state_checkpoint_162.tar" diff --git a/tests/model/Megatron_GPT2/run_checkpoint_test.py b/tests/model/Megatron_GPT2/run_checkpoint_test.py index d97a28ff1ad5..ae18607f1760 100755 --- a/tests/model/Megatron_GPT2/run_checkpoint_test.py +++ b/tests/model/Megatron_GPT2/run_checkpoint_test.py @@ -10,6 +10,7 @@ import subprocess import os import re +import shlex from .test_common import BaseTestCase LAYERS = 2 @@ -18,9 +19,9 @@ def remove_file(test_id, filename): - cmd = f"if [ -f {filename} ] ; then rm -v {filename}; fi" + cmd = shlex.split(f"if [ -f {filename} ] ; then rm -v {filename}; fi") print(f"{test_id} cmd: {cmd}") - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') + subprocess.run(cmd, check=False, executable='/bin/bash') def grep_loss_from_file(file_name): @@ -451,10 +452,10 @@ def run_test(self, test_config, r_tol): checkpoint_name = test_config["checkpoint_name"] #---------------remove old checkpoint---------------# try: - cmd = f"rm -rf {checkpoint_name}" + cmd = shlex.split(f"rm -rf {checkpoint_name}") print(f"{self.id()} cmd: {cmd}") - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') - except: + subprocess.run(cmd, check=False, executable='/bin/bash') + except Exception: print("No old checkpoint") if "cpu_optimizer" in test_config and test_config["cpu_optimizer"]: @@ -474,9 +475,9 @@ def run_test(self, test_config, r_tol): # remove previous test log try: - cmd = f"rm {base_file}" - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') - except: + cmd = shlex.split(f"rm {base_file}") + subprocess.run(cmd, check=False, executable='/bin/bash') + except Exception: print(f"{self.id()} No old logs") print("{0}: Run for saving checkpoint".format(self.id())) @@ -489,10 +490,10 @@ def run_test(self, test_config, r_tol): # set checkpoint load iteration try: - cmd = f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt" + cmd = shlex.split(f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt") print(f"{self.id()} running cmd: {cmd}") - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') - except: + subprocess.run(cmd, check=False, executable='/bin/bash') + except Exception: print(f"{self.id()} Failed to update the checkpoint iteration file") return False @@ -506,9 +507,9 @@ def run_test(self, test_config, r_tol): # remove previous test log try: - cmd = f"rm {test_file}" - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') - except: + cmd = shlex.split(f"rm {test_file}") + subprocess.run(cmd, check=False, executable='/bin/bash') + except Exception: print(f"{self.id()} no previous logs for") self.run_gpt2_test(test_config, test_file) return self.check_parity(base_file, test_file, r_tol) diff --git a/tests/model/Megatron_GPT2/test_common.py b/tests/model/Megatron_GPT2/test_common.py index 1bcd891e31d5..4eb84ac7eeee 100755 --- a/tests/model/Megatron_GPT2/test_common.py +++ b/tests/model/Megatron_GPT2/test_common.py @@ -7,6 +7,7 @@ import subprocess import os import time +import shlex class BaseTestCase(unittest.TestCase): @@ -46,9 +47,9 @@ def ensure_directory_exists(self, filename): os.makedirs(dirname) def clean_test_env(self): - cmd = "dlts_ssh pkill -9 -f /usr/bin/python" + cmd = shlex.split("dlts_ssh pkill -9 -f /usr/bin/python") print(cmd) - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash') + subprocess.run(cmd, check=False, executable='/bin/bash') time.sleep(20) def run_gpt2_test(self, test_config, output): @@ -60,8 +61,8 @@ def run_gpt2_test(self, test_config, output): test_config["mp"], test_config["gpus"], test_config["nodes"], test_config["bs"], test_config["steps"], test_config["layers"], test_config["hidden_size"], test_config["seq_length"], test_config["heads"], ckpt_num, other_args, ds_flag) - + cmd = shlex.split(cmd) self.ensure_directory_exists(output) with open(output, "w") as f: print(cmd) - subprocess.run(cmd, shell=True, check=False, executable='/bin/bash', stdout=f, stderr=f) + subprocess.run(cmd, check=False, executable='/bin/bash', stdout=f, stderr=f) diff --git a/tests/onebit/README.md b/tests/onebit/README.md new file mode 100644 index 000000000000..d62c25421d00 --- /dev/null +++ b/tests/onebit/README.md @@ -0,0 +1,31 @@ +# One-Bit tests + +In this folder, you can test the functionality and performance of different backend for doing compressed allreduce, which is the main algorithm in one-bit optimizers like [One-Bit Adam](https://www.deepspeed.ai/tutorials/onebit-adam/), [One-Bit Lamb](https://www.deepspeed.ai/tutorials/onebit-lamb/) and [Zero-One Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/). + +## How to run + +### NCCL and MPI backend + +Basically it requires your environment have relative communication backend installed, the NCCL backend of PyTorch distributed or Message Passing Interface (MPI) like MVAPICH2-GDR and OpenMPI. [Detailed Pre-requisites](https://www.deepspeed.ai/tutorials/zero-one-adam/#12-pre-requisites-for-01-adam). + +To test accuracy and performance of NCCL backend: +```bash +python test_nccl_backend.py +python test_nccl_perf.py +``` +Similarly, for MPI backend: +```bash +python test_mpi_backend.py +python test_mpi_perf.py +``` + +### Compressed backend + +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend` and test it, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype. +An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. + +The test usage is same as others: +```bash +python test_compressed_backend.py +python test_compressed_perf.py +``` diff --git a/tests/onebit/test_compressed_backend.py b/tests/onebit/test_compressed_backend.py new file mode 100644 index 000000000000..f6919a09a54b --- /dev/null +++ b/tests/onebit/test_compressed_backend.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed.comm as dist +import numpy as np +import argparse +import deepspeed +import os + +from deepspeed.runtime.comm.compressed import CompressedBackend +from deepspeed.accelerator import get_accelerator + +parser = argparse.ArgumentParser() +parser.add_argument('--local_rank', type=int, default=-1) +args = parser.parse_args() + +deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) +args.local_rank = int(os.environ['LOCAL_RANK']) + +get_accelerator().set_device(args.local_rank) +device = torch.device(get_accelerator().device_name(), args.local_rank) + +size = dist.get_world_size() +rank = dist.get_rank() + +backend = CompressedBackend() +local_rank = args.local_rank + + +# A simulated compression function using deepspeed.comm +def torch_sim(a): + a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + scale = a.norm() / np.sqrt(a.numel()) + a_compressed = scale * a_sign + a_sign = None + worker_error = a - a_compressed + dist.all_reduce(a_compressed) + a_compressed.mul_(1 / dist.get_world_size()) + a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) + server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list] + a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) + a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) + rank = dist.get_rank() + server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] + get_accelerator().synchronize() + dist.barrier() + return a_server_compressed, worker_error, server_error + + +tensor_size = 300 * 2**20 +server_size = int(tensor_size / size) +if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) +else: + right_tensor_size = tensor_size +right_server_size = right_tensor_size // size + +# Adding bias to the initialization of the gradient we are communicating +# In order to get rid of the case where some elements in the gradient are too small +a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank + +worker_error = torch.zeros(right_tensor_size, device=device) +server_error = torch.zeros(right_server_size, device=device) + +a_torch, worker_error_torch, server_error_torch = torch_sim(a) +get_accelerator().empty_cache() + +a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) + +print(a_torch.cpu()) +print(a_after.cpu()) + +threshold = 1e-6 +magnitude_threshold = 1e-6 +diff_mask = (a_after - a_torch) > threshold +diff_server_mask = torch.chunk(diff_mask, size)[rank] +mpi_server = torch.chunk(a_after, size)[rank] + server_error +torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch + +test_correctness = True + +# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic +# The test would skip those numbers that are too small in compensated_server_m +if test_correctness: + if torch.sum(diff_server_mask) == 0: + print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank)) + else: + check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold + if torch.sum(check_mag_mask) == 0: + print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank)) + else: + print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) diff --git a/tests/onebit/test_compressed_perf.py b/tests/onebit/test_compressed_perf.py new file mode 100644 index 000000000000..a686af0f6b8d --- /dev/null +++ b/tests/onebit/test_compressed_perf.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed.comm as dist +import numpy as np +import argparse +import deepspeed +import os + +from deepspeed.runtime.comm.compressed import CompressedBackend +from deepspeed.utils.timer import SynchronizedWallClockTimer +from deepspeed.accelerator import get_accelerator +from statistics import mean + +timers = SynchronizedWallClockTimer() + +parser = argparse.ArgumentParser() +parser.add_argument('--local_rank', type=int, default=-1) +args = parser.parse_args() + +deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) +args.local_rank = int(os.environ['LOCAL_RANK']) + +get_accelerator().set_device(args.local_rank) +device = torch.device(get_accelerator().device_name(), args.local_rank) + +size = dist.get_world_size() +rank = dist.get_rank() + +backend = CompressedBackend() +local_rank = args.local_rank + +# Setting tensor_size (BERT-Large) +tensor_size = 300 * 2**20 +server_size = int(tensor_size / size) +if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) +else: + right_tensor_size = tensor_size +right_server_size = right_tensor_size // size + +# Adding bias to the initialization of the gradient we are communicating +# In order to get rid of the case where some elements in the gradient are too small +a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank + +worker_error = torch.zeros(right_tensor_size, device=device) +server_error = torch.zeros(right_server_size, device=device) + +warmup = 10 +iters = 10 + +# Warmup +for i in range(warmup): + backend.compressed_allreduce(a, worker_error, server_error, local_rank) + +time_list = [] + +a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) +scale = a.norm() / np.sqrt(a.numel()) +a_compressed = scale * a_sign + +print("Shape of the compressed buffer:", a_compressed.shape) if rank == 0 else None + +for i in range(iters): + timers('compressed_allreduce').start() + backend.compressed_allreduce(a, worker_error, server_error, local_rank) + #deepspeed.comm.all_reduce(a_compressed) + timers('compressed_allreduce').stop() + time_list.append(timers('compressed_allreduce').elapsed()) + +#timer_names = ['compressed_allreduce'] +#timers.log(names=timer_names, normalizer=1, memory_breakdown=None) + +places = 2 +convert = 1e3 +float_size = 4 + +if rank == 0: + for i in range(iters): + lat = time_list[i] + print("latency = ", lat * convert) + +minlat = round(min(time_list) * convert) +maxlat = round(max(time_list) * convert) +meanlat = round(mean(time_list) * convert, places) +print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat)) if rank == 0 else None +#print("tensor shape", a.shape) +duration = meanlat / 1e3 +tput = ((tensor_size * 4) / duration) +print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9)) if rank == 0 else None +size = tensor_size * 4 +n = dist.get_world_size() +busbw = (size / duration) * (2 * (n - 1) / n) +print("busbw: %f GB/s" % (busbw / 1e9)) if rank == 0 else None diff --git a/tests/perf/adam_test1.py b/tests/perf/adam_test1.py index b35477afb4fe..bde1d53e5179 100755 --- a/tests/perf/adam_test1.py +++ b/tests/perf/adam_test1.py @@ -6,12 +6,10 @@ import torch from deepspeed.ops.adam import DeepSpeedCPUAdam import time -from deepspeed.accelerator import get_accelerator device = 'cpu' model_size = 1 * 1024**3 param = torch.nn.Parameter(torch.ones(model_size, device=device)) -param_fp16 = torch.nn.Parameter(torch.ones(model_size, dtype=torch.half, device=get_accelerator().device_name(0))) optimizer = DeepSpeedCPUAdam([param]) #torch.set_num_threads(128) @@ -19,7 +17,7 @@ avg = 0 for i in range(100): start = time.time() - optimizer.step(fp16_param_groups=[param_fp16]) + optimizer.step() stop = time.time() avg += (stop - start) param.grad = torch.ones(model_size, device=device) * 2 diff --git a/tests/pytest.ini b/tests/pytest.ini index 08b666867b79..f841c47afc0c 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,8 +1,13 @@ [pytest] -addopts = -m "not sequential and not nightly and not inference and not seq_inference and not inference_ops" +addopts = -m "not sequential and not nightly and not inference and not seq_inference and not inference_ops and not inference_v2 and not inference_v2_ops and not stable_diffusion and not evaluation" markers = sequential:Tests that need to be run sequentially inference:Inference model tests inference_ops:Individual inference operator tests + inference_v2:Inference tests for the v2 stack + inference_v2_ops:Op tests for the v2 stack seq_inference:Inference model tests to run sequentially nightly:Tests that should be run nightly + world_size:Change world size of individual tests in a class + stable_diffusion:Tests that run Stable Diffusion + evaluation:Tests that evaluate model correctness diff --git a/tests/small_model_debugging/partial_offload_test.py b/tests/small_model_debugging/partial_offload_test.py new file mode 100644 index 000000000000..2094448d534d --- /dev/null +++ b/tests/small_model_debugging/partial_offload_test.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import json +import argparse +import torch +import deepspeed +from torch.utils.data.distributed import DistributedSampler +import deepspeed.comm as dist + + +class SimpleModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False): + super(SimpleModel, self).__init__() + self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim) + if empty_grad: + self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)]) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + hidden = self.linear(hidden) + hidden = self.linear2(hidden) + hidden = self.linear3(hidden) + hidden = self.linear4(hidden) + return self.cross_entropy_loss(hidden, y) + + +def create_config_from_dict(tmpdir, config_dict): + config_path = os.path.join(tmpdir, 'temp_config.json') + with open(config_path, 'w') as fd: + json.dump(config_dict, fd) + return config_path + + +def get_data_loader(model, total_samples, hidden_dim, device): + batch_size = model.train_micro_batch_size_per_gpu() + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half) + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + sampler = DistributedSampler(train_dataset) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) + return train_loader + + +def get_args(tmpdir, config_dict): + parser = argparse.ArgumentParser() + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--zero', type=int, default=0) + args = parser.parse_args() #args='' + + config_dict["zero_optimization"]["stage"] = args.zero + print('config_dict["zero_optimization"]', config_dict["zero_optimization"]) + config_path = create_config_from_dict(tmpdir, config_dict) + + args.deepspeed_config = config_path + return args + + +def print0(msg): + if dist.get_rank() == 0: + print(msg, flush=True) + + +rank = int(os.environ['RANK']) +print('seed:', 2222 + rank) +torch.random.manual_seed(2222 + rank) + +config_dict = { + "train_batch_size": 256, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 0, + "sub_group_size": 8, + "reduce_bucket_size": 20, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + "ratio": 0.3 + } + } +} +# "initial_scale_power": 15 +args = get_args('/tmp/', config_dict) +hidden_dim = 4 * 1024 + +model = SimpleModel(hidden_dim, empty_grad=False) + +model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters(), + dist_init_required=True) + + +def print_params(tag, model): + if dist.get_rank() == 0: + for n, p in model.named_parameters(): + print0("{} {}:{}".format(tag, n, p)) + + +data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device) +#print_params('pre-train', model) +#while True: +for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0: + print("LOSS:", loss.item()) + model.backward(loss) + model.step() + #print_params('step={}'.format(n), model) + if n == 2: break diff --git a/tests/small_model_debugging/stage3_test.py b/tests/small_model_debugging/stage3_test.py index 3a92d31f1b7a..5bd8e728caf1 100644 --- a/tests/small_model_debugging/stage3_test.py +++ b/tests/small_model_debugging/stage3_test.py @@ -15,9 +15,9 @@ class VerboseLinear(torch.nn.Linear): def __init__(self, **kwargs): - print(f'Begin VerboseLinear.__init__') + print('Begin VerboseLinear.__init__') super().__init__(**kwargs) - print(f'End VerboseLinear.__init__') + print('End VerboseLinear.__init__') class LinearStack(torch.nn.Module): diff --git a/tests/small_model_debugging/test_mics.sh b/tests/small_model_debugging/test_mics.sh new file mode 100755 index 000000000000..9f306a7055d3 --- /dev/null +++ b/tests/small_model_debugging/test_mics.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +deepspeed test_mics_config.py --mics_shard_size=1 + +deepspeed test_mics_config.py --mics_shard_size=2 + +# for debugging the hierarchical params gathering +export NDEV_PER_NODE=2 +deepspeed test_mics_config.py --mics_shard_size=4 --mics_hierarchical_params_gather diff --git a/tests/small_model_debugging/test_mics_config.py b/tests/small_model_debugging/test_mics_config.py new file mode 100644 index 000000000000..ccb34fadaefe --- /dev/null +++ b/tests/small_model_debugging/test_mics_config.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Testing on a 8 GPUs node +NDEV_PER_NODE=2 torchrun --nnodes 1 --nproc-per-node 8 test_mics_config.py +""" + +import os +import json +import argparse +import torch +import deepspeed +from torch.utils.data.distributed import DistributedSampler +import deepspeed.comm as dist + + +class SimpleModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False): + super(SimpleModel, self).__init__() + self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + if empty_grad: + self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)]) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + hidden = self.linear(hidden) + return self.cross_entropy_loss(hidden, y) + + +def create_config_from_dict(tmpdir, config_dict): + config_path = os.path.join(tmpdir, 'temp_config.json') + with open(config_path, 'w') as fd: + json.dump(config_dict, fd) + return config_path + + +def get_data_loader(model, total_samples, hidden_dim, device): + batch_size = model.train_micro_batch_size_per_gpu() + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.float) + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + sampler = DistributedSampler(train_dataset) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) + return train_loader + + +def get_args(tmpdir, config_dict): + parser = argparse.ArgumentParser() + parser.add_argument('--zero', type=int, default=3) + parser.add_argument('--local_rank', type=int) + + parser.add_argument('--mics_shard_size', default=2, type=int) + parser.add_argument('--mics_hierarchical_params_gather', default=False, action='store_true') + args = parser.parse_args() #args='' + + config_dict["zero_optimization"]["stage"] = args.zero + config_dict["zero_optimization"]["mics_shard_size"] = args.mics_shard_size + config_dict["zero_optimization"]["mics_hierarchical_params_gather"] = args.mics_hierarchical_params_gather + + # print('config_dict["zero_optimization"]', config_dict["zero_optimization"]) + config_path = create_config_from_dict(tmpdir, config_dict) + + args.deepspeed_config = config_path + return args + + +def print0(msg): + if dist.get_rank() == 0: + print(msg, flush=True) + + +rank = int(os.environ['RANK']) +print('seed:', 2222 + rank) +torch.random.manual_seed(2222 + rank) + +config_dict = { + "train_batch_size": 8, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": False, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 3, + "reduce_bucket_size": 20, + "mics_shard_size": 4, + "mics_hierarchical_params_gather": True, + "stage3_model_persistence_threshold": 10 + } +} +# "initial_scale_power": 15 +args = get_args('/tmp/', config_dict) +hidden_dim = 32 + +with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, empty_grad=False) +# print('------> init model with deepspeed.zero.Init()') + +model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters(), + dist_init_required=True) + + +def print_params(tag, model): + if dist.get_rank() == 0: + for n, p in model.named_parameters(): + print0("{} {}:{}".format(tag, n, p)) + + +data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device) +#print_params('pre-train', model) +for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0: + print("LOSS:", loss.item()) + model.backward(loss) + model.step() + #print_params('step={}'.format(n), model) + if n == 5: break diff --git a/tests/small_model_debugging/test_model.py b/tests/small_model_debugging/test_model.py index 66dfe149a956..2706cde980d4 100755 --- a/tests/small_model_debugging/test_model.py +++ b/tests/small_model_debugging/test_model.py @@ -16,15 +16,18 @@ class SimpleModel(torch.nn.Module): def __init__(self, hidden_dim, empty_grad=False): super(SimpleModel, self).__init__() - self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) + self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) if empty_grad: - self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)]) + self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, + hidden_dim)]) #QuantizeLinear(hidden_dim, hidden_dim) self.cross_entropy_loss = torch.nn.CrossEntropyLoss() def forward(self, x, y): hidden = x - hidden = self.linear(hidden) - return self.cross_entropy_loss(hidden, y) + hidden1 = self.linear(hidden) + hidden2 = self.linear(hidden1) + return self.cross_entropy_loss(hidden2, y) def create_config_from_dict(tmpdir, config_dict): @@ -48,9 +51,11 @@ def get_args(tmpdir, config_dict): parser = argparse.ArgumentParser() parser.add_argument("--local_rank", type=int, default=0) parser.add_argument('--zero', type=int, default=0) + parser.add_argument('--zero_hpz_partition_size', type=int, default=1) args = parser.parse_args() #args='' config_dict["zero_optimization"]["stage"] = args.zero + config_dict["zero_optimization"]["zero_hpz_partition_size"] = args.zero_hpz_partition_size print('config_dict["zero_optimization"]', config_dict["zero_optimization"]) config_path = create_config_from_dict(tmpdir, config_dict) @@ -68,7 +73,7 @@ def print0(msg): torch.random.manual_seed(2222 + rank) config_dict = { - "train_batch_size": 8, + "train_batch_size": 256, "steps_per_print": 1, "optimizer": { "type": "Adam", @@ -78,16 +83,20 @@ def print0(msg): }, "fp16": { "enabled": True, - "initial_scale_power": 15 + "initial_scale_power": 8 }, "zero_optimization": { "stage": 0, - "reduce_bucket_size": 20 + "reduce_bucket_size": 20, + "zero_hpz_partition_size": 1, + "reduce_scatter": True, + "zero_quantized_weights": False, + "zero_quantized_gradients": False } } # "initial_scale_power": 15 args = get_args('/tmp/', config_dict) -hidden_dim = 4 +hidden_dim = 4 * 1024 model = SimpleModel(hidden_dim, empty_grad=False) @@ -103,8 +112,9 @@ def print_params(tag, model): print0("{} {}:{}".format(tag, n, p)) -data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device) +data_loader = get_data_loader(model=model, total_samples=256, hidden_dim=hidden_dim, device=model.device) #print_params('pre-train', model) + for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) if dist.get_rank() == 0: @@ -112,4 +122,4 @@ def print_params(tag, model): model.backward(loss) model.step() #print_params('step={}'.format(n), model) - if n == 5: break + #if n == 5: break diff --git a/tests/torch_compile/ds_config_z2.json b/tests/torch_compile/ds_config_z2.json new file mode 100644 index 000000000000..6550c4cdedf3 --- /dev/null +++ b/tests/torch_compile/ds_config_z2.json @@ -0,0 +1,35 @@ +{ + "train_batch_size": 8, + "steps_per_print": 2000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [ + 0.8, + 0.999 + ], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000 + } + }, + "gradient_clipping": 1.0, + "prescale_gradients": false, + "bf16": { + "enabled": true + }, + "wall_clock_breakdown": false, + "zero_optimization": { + "stage": 2, + "overlap_comm": false, + "contiguous_gradients": false + } +} diff --git a/tests/torch_compile/ds_config_z3.json b/tests/torch_compile/ds_config_z3.json new file mode 100644 index 000000000000..6aed1648da23 --- /dev/null +++ b/tests/torch_compile/ds_config_z3.json @@ -0,0 +1,36 @@ +{ + "train_batch_size": 8, + "steps_per_print": 2000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [ + 0.8, + 0.999 + ], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000 + } + }, + "gradient_clipping": 1.0, + "prescale_gradients": false, + "bf16": { + "enabled": true + }, + "wall_clock_breakdown": false, + "zero_optimization": { + "stage": 3, + "reduce_scatter": true, + "overlap_comm": false, + "contiguous_gradients": false + } +} diff --git a/tests/torch_compile/test_compile.py b/tests/torch_compile/test_compile.py new file mode 100644 index 000000000000..20bd386b3f56 --- /dev/null +++ b/tests/torch_compile/test_compile.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed import comm + +import torch +from torch.utils.data import Dataset, DataLoader + +torch._dynamo.config.cache_size_limit = 100 + + +def get_dynamo_stats(): + return torch._dynamo.utils.counters["graph_break"] + + +class RandomDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size).to(torch.bfloat16) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +data_size = 1024 +data_length = 100 +rand_loader = DataLoader(dataset=RandomDataset(data_size, data_length), batch_size=1, shuffle=False) + + +class MyModule(torch.nn.Module): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.fc0 = torch.nn.Linear(1024, 256, bias=False) + self.fc1 = torch.nn.Linear(256, 256, bias=False) + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, data, residual): + output = residual + self.fc1(self.fc0(self.dropout(data))) * 0.5 + return output + + +model = MyModule() +params = model.parameters() + +parser = argparse.ArgumentParser() +parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher') +parser.add_argument('--deepspeed_config', + type=str, + default='ds_config_z3.json', + help='path to DeepSpeed configuration file') +cmd_args = parser.parse_args() + +# initialize the DeepSpeed engine +model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=params) +model_engine.compile() + +residual = torch.rand(256, 256, dtype=torch.float).to(get_accelerator().current_device_name()) + +start_stats = get_dynamo_stats() + +if comm.get_rank() == 0: + #print(dynamo_stats['graph_breaks']) + for item in start_stats.items(): + print(item) + +for step, batch in enumerate(rand_loader): + if step % 10 == 0 and comm.get_rank() == 0: + print(f'step={step}') + # forward() method + loss = model_engine(batch.to(get_accelerator().current_device_name()), residual).sum() + # runs backpropagation + model_engine.backward(loss) + # weight update + model_engine.step() + +dynamo_stats = get_dynamo_stats() + +if comm.get_rank() == 0: + # print break down of graph break stats with markdown, print in table format, start with reason, then count + # print a tag 'dynamo_output' before each line to allow post processing + print("dynamo_output | Reason | Count |") + print("dynamo_output | ------ | ----- |") + for item in dynamo_stats.items(): + # replace '|' in item[0] with a literal '|' to avoid mess with table format + item = (item[0].replace('|', r'\|'), item[1]) + print(f"dynamo_output | {item[0]} | {item[1]} |") + print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |") diff --git a/tests/unit/accelerator/test_accelerator.py b/tests/unit/accelerator/test_accelerator.py new file mode 100644 index 000000000000..964cf2b24f4e --- /dev/null +++ b/tests/unit/accelerator/test_accelerator.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +import os +import sys +import importlib +import re + +import deepspeed + +DS_ACCEL_PATH = "deepspeed.accelerator" +IGNORE_FILES = ["abstract_accelerator.py", "real_accelerator.py"] + + +@pytest.fixture +def accel_class_name(module_name): + class_list = [] + mocked_modules = [] + + # Get the accelerator class name for a given module + while True: + try: + module = importlib.import_module(module_name) + break + except ModuleNotFoundError as e: + # If the environment is missing a module, mock it so we can still + # test importing the accelerator class + missing_module = re.search(r"\'(.*)\'", e.msg).group().strip("'") + sys.modules[missing_module] = lambda x: None + mocked_modules.append(missing_module) + for name in dir(module): + if name.endswith("_Accelerator"): + class_list.append(name) + + assert len(class_list) == 1, f"Multiple accelerator classes found in {module_name}" + + yield class_list[0] + + # Clean up mocked modules so as to not impact other tests + for module in mocked_modules: + del sys.modules[module] + + +@pytest.mark.parametrize( + "module_name", + [ + DS_ACCEL_PATH + "." + f.rstrip(".py") for f in os.listdir(deepspeed.accelerator.__path__[0]) + if f.endswith("_accelerator.py") and f not in IGNORE_FILES + ], +) +def test_abstract_methods_defined(module_name, accel_class_name): + module = importlib.import_module(module_name) + accel_class = getattr(module, accel_class_name) + accel_class.__init__ = lambda self: None + _ = accel_class() diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index 7f9e37f289f0..6fe84edf4eda 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -4,14 +4,17 @@ # DeepSpeed Team import pytest +import os import torch import torch.nn as nn import torch.nn.functional as F import deepspeed import deepspeed.comm as dist import deepspeed.runtime.utils as ds_utils +from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec +from .util import no_child_process_in_deepspeed_io class AlexNet(nn.Module): @@ -81,7 +84,7 @@ def cast_to_half(x): def cifar_trainset(fp16=False): torchvision = pytest.importorskip("torchvision", minversion="0.5.0") - import torchvision.transforms as transforms + from torchvision import transforms transform_list = [ transforms.ToTensor(), @@ -98,14 +101,25 @@ def cifar_trainset(fp16=False): dist.barrier() if local_rank != 0: dist.barrier() - trainset = torchvision.datasets.CIFAR10(root='/blob/cifar10-data', train=True, download=True, transform=transform) + data_root = os.getenv("TEST_DATA_DIR", "/tmp/") + if os.getenv("CIFAR10_DATASET_PATH"): + data_root = os.getenv("CIFAR10_DATASET_PATH") + download = False + else: + data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data") + download = True + trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform) if local_rank == 0: dist.barrier() return trainset def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123): - with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()]): + if required_torch_version(min_version=2.1): + fork_kwargs = {"device_type": get_accelerator().device_name()} + else: + fork_kwargs = {} + with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs): ds_utils.set_random_seed(seed) # disable dropout @@ -114,10 +128,11 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, trainset = cifar_trainset(fp16=fp16) config['local_rank'] = dist.get_rank() - engine, _, _, _ = deepspeed.initialize(config=config, - model=model, - model_parameters=[p for p in model.parameters()], - training_data=trainset) + with no_child_process_in_deepspeed_io(): + engine, _, _, _ = deepspeed.initialize(config=config, + model=model, + model_parameters=[p for p in model.parameters()], + training_data=trainset) losses = [] for step in range(num_steps): diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py index 48b3b8017a2c..0daa1b070850 100644 --- a/tests/unit/checkpoint/common.py +++ b/tests/unit/checkpoint/common.py @@ -12,8 +12,11 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from unit.common import preferred_dtype from unit.simple_model import * +from unittest.mock import MagicMock, patch def compare_deepspeed_states(saved_model, loaded_model): @@ -25,24 +28,32 @@ def compare_deepspeed_states(saved_model, loaded_model): assert saved_model.global_steps == loaded_model.global_steps +def zero3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + + def compare_model_states(saved_model, loaded_model, compare_optimizer=True, load_module_only=False): if not load_module_only: compare_deepspeed_states(saved_model, loaded_model) - for p0, p1 in zip(saved_model.module.named_parameters(), loaded_model.module.named_parameters()): - np0, p0 = p0 - np1, p1 = p1 - if 'deepspeed_moe.gate.wg' in np0: - # these params are converted to float at runtime, cast to half for comparison - p1 = p1.half() - p0 = p0.half() - assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}' - try: - assert torch.allclose(p0, p1, - atol=1e-07), f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}" - except RuntimeError as err: - print(f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}") - raise err + params_to_fetch = zero3_params_to_fetch( + list(saved_model.module.named_parameters()) + list(loaded_model.module.named_parameters())) + enable_gather = len(params_to_fetch) > 0 + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=enable_gather): + for p0, p1 in zip(saved_model.module.named_parameters(), loaded_model.module.named_parameters()): + np0, p0 = p0 + np1, p1 = p1 + if 'deepspeed_moe.gate.wg' in np0: + # these params are converted to float at runtime, cast to half for comparison + p1 = p1.half() + p0 = p0.half() + assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}' + try: + assert torch.allclose(p0, p1, + atol=1e-07), f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}" + except RuntimeError as err: + print(f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}") + raise err if not compare_optimizer: return @@ -75,15 +86,33 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True, load def compare_state_dicts(state0, state1, expected_mismatch_keys=[]): - for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()): - assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' - if k0 in expected_mismatch_keys: + key_set0 = set(k for k in state0.keys() if k not in expected_mismatch_keys) + key_set1 = set(k for k in state1.keys() if k not in expected_mismatch_keys) + assert key_set0 == key_set1, f'failure due to key mismatch {key_set0} != {key_set1}' + + for k in key_set0: + s0 = state0[k] + s1 = state1[k] + if k in expected_mismatch_keys: continue if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' assert torch.equal(s0.to('cpu'), s1.to('cpu')) else: - assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}' + assert s0 == s1, f'failures with keys = {k}, {k}, values = {s0} and {s1}' + + +def compare_opt_state_dicts(state0, state1, expected_mismatch_keys=[]): + for param_group0, saved_param_group1 in zip(state0['param_groups'], state1['param_groups']): + compare_state_dicts(param_group0, saved_param_group1, expected_mismatch_keys) + + assert "state" in state0 + assert "state" in state1 + assert len([state0["state"].keys()]) == len([state1["state"].keys()]) + + for (k0, s0), (k1, s1) in zip(state0["state"].items(), state1["state"].items()): + assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' + compare_state_dicts(s0, s1, expected_mismatch_keys) def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True): @@ -130,6 +159,7 @@ def create_deepspeed_model(config_dict, model, base_optimizer): model=model, model_parameters=create_moe_param_groups(model), optimizer=base_optimizer) + ds_model.empty_partition_cache() return ds_model @@ -139,13 +169,15 @@ def checkpoint_correctness_verification(config_dict, tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False, - fp16=True, train_batch=False, base_optimizers=[None, None], empty_tag=False, seq_dataloader=False, - load_module_only=False): - dtype = torch.half if fp16 else torch.float32 + load_module_only=False, + dtype=None): + if dtype is None: + dtype = preferred_dtype() + ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0]) if seq_dataloader: @@ -171,6 +203,9 @@ def checkpoint_correctness_verification(config_dict, ds_model.backward(loss) ds_model.step() + # Flush zero stage 3 cache + ds_model.empty_partition_cache() + trained_model = ds_model save_folder = os.path.join(tmpdir, 'saved_checkpoint') @@ -183,7 +218,7 @@ def checkpoint_correctness_verification(config_dict, for root, _, files in os.walk(save_folder): for f in files: if "_expert_" in f and "_model_states" in f: - expert = torch.load(os.path.join(root, f)) + expert = torch.load(os.path.join(root, f), weights_only=False) needed, storages = 0, {} for name, tensor in expert.items(): needed += tensor.size().numel() @@ -196,11 +231,17 @@ def checkpoint_correctness_verification(config_dict, loaded_model = create_deepspeed_model(config_dict=config_dict, model=models[1], base_optimizer=base_optimizers[1]) assert list(trained_model.parameters())[0].dtype == list(loaded_model.parameters())[0].dtype - loaded_model.load_checkpoint(save_folder, - tag=save_tag, - load_optimizer_states=load_optimizer_states, - load_lr_scheduler_states=load_lr_scheduler_states, - load_module_only=load_module_only) + context = patch.object(loaded_model, "_get_optimizer_ckpt_name", + wraps=loaded_model._get_optimizer_ckpt_name) if not load_optimizer_states else MagicMock() + with context as optim_load_state_dict_mock: + loaded_model.load_checkpoint(save_folder, + tag=save_tag, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only) + if not load_optimizer_states: + # should not attempt to get the file name to load it + optim_load_state_dict_mock.assert_not_called() compare_model_states(trained_model, loaded_model, @@ -208,7 +249,7 @@ def checkpoint_correctness_verification(config_dict, load_module_only=load_module_only) if load_optimizer_states: - compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16) + compare_optimizer_states(trained_model, loaded_model, hidden_dim, dtype == torch.float16) if load_lr_scheduler_states: compare_lr_scheduler_states(trained_model, loaded_model) diff --git a/tests/unit/checkpoint/test_autotp_uc_checkpoint.py b/tests/unit/checkpoint/test_autotp_uc_checkpoint.py new file mode 100644 index 000000000000..4a23e5b43716 --- /dev/null +++ b/tests/unit/checkpoint/test_autotp_uc_checkpoint.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import types +from types import SimpleNamespace + +import torch + +from deepspeed.checkpoint.constants import (CAT_DIM, FP32_WEIGHT_KEY, PARAM, PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, + PARAMETER_WITH_SUB_PARAMS, SUB_PARAM_SHAPE, + TP_REPLICATED_PARAMETER_PATTERNS, UNIVERSAL_CHECKPOINT_INFO) +from deepspeed.checkpoint.universal_checkpoint import SubparamShape as CheckpointSubparamShape +from deepspeed.checkpoint.ds_to_universal import merge_tp_slices +from deepspeed.checkpoint.universal_checkpoint import (_get_param_uc_restore_meta, _resolve_autotp_partition, + load_hp_checkpoint_state) +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer +from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + + +class _DummyAddress: + + def __init__(self, start, numel): + self.start = start + self.numel = numel + + +class _DummyHPMapping: + + def __init__(self, param): + self.lp_fragment_address = _DummyAddress(0, param.numel()) + self._param = param + self.optim_fragment = {} + + def get_hp_fragment(self): + return self._param.view(-1) + + def get_optim_state_keys(self): + return [] + + +def _make_param(shape, meta=None): + param = torch.nn.Parameter(torch.zeros(shape, dtype=torch.float32)) + param._hp_mapping = _DummyHPMapping(param) + if meta is not None: + setattr(param, 'ds_autotp_universal_checkpoint_meta', meta) + return param + + +def test_resolve_autotp_partition_row_parallel_weight(): + param = _make_param( + (4, 4), { + 'partition_type': 'row', + 'partition_dim': 1, + 'logical_shape': (4, 8), + 'output_shape': (4, ), + 'sub_param_shape': None, + 'original_shape': (4, 8), + 'is_bias': False, + 'replicated': False, + }) + full_hp_param = torch.arange(32, dtype=torch.float32).view(4, 8) + + slice_flat = _resolve_autotp_partition(param, {PARAM: full_hp_param}, full_hp_param, tp_rank=1, tp_world_size=2) + + expected = full_hp_param.chunk(2, dim=1)[1].flatten() + assert torch.equal(slice_flat, expected) + + +def test_resolve_autotp_partition_subparam_column_weight(): + param = _make_param( + (3, 4), { + 'partition_type': 'column', + 'partition_dim': 0, + 'logical_shape': (6, 4), + 'output_shape': (6, ), + 'sub_param_shape': ((2, 2, 2), 4), + 'original_shape': (6, 4), + 'is_bias': False, + 'replicated': False, + }) + full_hp_param = torch.arange(24, dtype=torch.float32).view(6, 4) + + slice_flat = _resolve_autotp_partition(param, {PARAM: full_hp_param}, full_hp_param, tp_rank=0, tp_world_size=2) + + chunks = [sub.chunk(2, dim=0)[0] for sub in full_hp_param.view(3, 2, 4)] + expected = torch.cat(chunks, dim=0).flatten() + assert torch.equal(slice_flat, expected) + + +def test_resolve_autotp_partition_subparam_sizes_uneven_gqa_like(): + # Simulate a fused QKV weight where Q/K/V have uneven sizes along partition_dim=0. + # Example (GQA-like): + # Q: 8 + # K: 4 + # V: 4 + # Total: 16 + # + # With tp_world_size=2, correct slicing is: + # Q chunk -> 4 per rank + # K chunk -> 2 per rank + # V chunk -> 2 per rank + # Each rank gets 8 rows total, but importantly boundaries must align with Q/K/V. + sub_param_sizes = [8, 4, 4] + tp_world_size = 2 + tp_rank = 1 + + param = _make_param( + (8, 2), + { + "partition_type": "column", + "partition_dim": 0, + "logical_shape": (sum(sub_param_sizes), 2), # (16, 2) + "output_shape": (sum(sub_param_sizes), ), # (16,) + "sub_param_shape": (tuple(sub_param_sizes), 2), + "sub_param_sizes": sub_param_sizes, + "original_shape": (sum(sub_param_sizes), 2), + "is_bias": False, + "replicated": False, + }) + + # Full (unsharded) HP parameter: shape (16, 2) + full_hp_param = torch.arange(sum(sub_param_sizes) * 2, dtype=torch.float32).view(sum(sub_param_sizes), 2) + + slice_flat = _resolve_autotp_partition(param, {PARAM: full_hp_param}, + full_hp_param, + tp_rank=tp_rank, + tp_world_size=tp_world_size) + + # Expected: split into Q/K/V blocks, chunk each block by TP, take tp_rank slice, concat back. + q, k, v = torch.split(full_hp_param, sub_param_sizes, dim=0) + expected = torch.cat([ + q.chunk(tp_world_size, dim=0)[tp_rank], + k.chunk(tp_world_size, dim=0)[tp_rank], + v.chunk(tp_world_size, dim=0)[tp_rank] + ], + dim=0).flatten() + + assert torch.equal(slice_flat, expected) + + +def test_resolve_autotp_partition_replicated_bias(): + full_hp_param = torch.arange(8, dtype=torch.float32) + param = _make_param( + (8, ), { + 'partition_type': 'row', + 'partition_dim': None, + 'logical_shape': (8, ), + 'output_shape': (8, ), + 'sub_param_shape': None, + 'original_shape': (8, ), + 'is_bias': True, + 'replicated': True, + }) + + slice_flat = _resolve_autotp_partition(param, {PARAM: full_hp_param}, full_hp_param, tp_rank=1, tp_world_size=2) + + assert torch.equal(slice_flat, full_hp_param) + + +def test_load_hp_checkpoint_state_prefers_autotp_metadata(tmp_path, monkeypatch): + param = _make_param( + (4, 4), { + 'partition_type': 'row', + 'partition_dim': 1, + 'logical_shape': (4, 8), + 'output_shape': (4, ), + 'sub_param_shape': None, + 'original_shape': (4, 8), + 'is_bias': False, + 'replicated': False, + }) + param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param) + + import deepspeed.checkpoint.universal_checkpoint as uc + monkeypatch.setattr(uc, "current_param", param, raising=False) + + ckpt_dir = tmp_path / "weight" + ckpt_dir.mkdir(parents=True) + full_hp_param = torch.arange(32, dtype=torch.float32).view(4, 8) + torch.save({PARAM: full_hp_param}, ckpt_dir / f"{FP32_WEIGHT_KEY}.pt") + + monkeypatch.setattr( + torch, + "load", + lambda *args, **kwargs: {PARAM: full_hp_param} if str(args[0]).endswith("fp32.pt") else 0, + ) + + step = param.load_hp_checkpoint_state(str(ckpt_dir), tp_rank=1, tp_world_size=2) + + assert step is None + expected = full_hp_param.chunk(2, dim=1)[1].flatten() + assert torch.equal(param.data.flatten(), expected) + + +def _write_tp_slice(base_dir, param_name, tp_idx, state_name, tensor): + shard_dir = base_dir / param_name / str(tp_idx) + shard_dir.mkdir(parents=True, exist_ok=True) + torch.save(tensor.reshape(-1), shard_dir / f"{state_name}.00") + + +def _write_tp_states(base_dir, param_name, tp_idx, fp32_tensor): + # merge_tp_slices attempts to merge these three states, so the test must write all of them. + _write_tp_slice(base_dir, param_name, tp_idx, "fp32", fp32_tensor) + _write_tp_slice(base_dir, param_name, tp_idx, "exp_avg", torch.zeros_like(fp32_tensor)) + _write_tp_slice(base_dir, param_name, tp_idx, "exp_avg_sq", torch.zeros_like(fp32_tensor)) + + +def test_merge_tp_slices_emits_subparam_shape_metadata(tmp_path): + slice_dir = tmp_path / "slices" + output_dir = tmp_path / "out" + param_name = "module.qkv.weight" + + tp0 = torch.arange(12, dtype=torch.float32).view(3, 4) + tp1 = torch.arange(12, 24, dtype=torch.float32).view(3, 4) + _write_tp_states(slice_dir, param_name, 0, tp0) + _write_tp_states(slice_dir, param_name, 1, tp1) + + uc_info = { + PARAMETER_WITH_ROW_PARALLELISM_PATTERNS: [], + TP_REPLICATED_PARAMETER_PATTERNS: [], + PARAMETER_WITH_SUB_PARAMS: [{ + "patterns": [rf"^{param_name}$"], + "shape": [(2, 2, 2), 4], + "partition_dim": 0, + }], + } + + ds_checkpoint = SimpleNamespace( + get_checkpoint_info=lambda key: uc_info if key == UNIVERSAL_CHECKPOINT_INFO else {}) + + unmatched = merge_tp_slices(ds_checkpoint, str(output_dir), str(slice_dir), 2, (param_name, torch.Size([3, 4]))) + + ckpt = torch.load(output_dir / param_name / "fp32.pt", weights_only=False) + assert not unmatched + assert isinstance(ckpt[SUB_PARAM_SHAPE], CheckpointSubparamShape) + assert ckpt[SUB_PARAM_SHAPE].partition_dim == 0 + + +def test_merge_tp_slices_uses_row_parallel_cat_dim(tmp_path): + slice_dir = tmp_path / "slices" + output_dir = tmp_path / "out" + param_name = "module.proj.weight" + + tp0 = torch.arange(16, dtype=torch.float32).view(4, 4) + tp1 = torch.arange(16, 32, dtype=torch.float32).view(4, 4) + _write_tp_states(slice_dir, param_name, 0, tp0) + _write_tp_states(slice_dir, param_name, 1, tp1) + + uc_info = { + PARAMETER_WITH_ROW_PARALLELISM_PATTERNS: [rf"^{param_name}$"], + TP_REPLICATED_PARAMETER_PATTERNS: [], + PARAMETER_WITH_SUB_PARAMS: [], + } + + ds_checkpoint = SimpleNamespace( + get_checkpoint_info=lambda key: uc_info if key == UNIVERSAL_CHECKPOINT_INFO else {}) + + merge_tp_slices(ds_checkpoint, str(output_dir), str(slice_dir), 2, (param_name, torch.Size([4, 4]))) + + ckpt = torch.load(output_dir / param_name / "fp32.pt", weights_only=False) + assert ckpt[CAT_DIM] == 1 + assert torch.equal(ckpt[PARAM], torch.cat([tp0, tp1], dim=1)) + + +def test_zero_optimizer_uc_info_comes_from_cached_state(): + param = _make_param((2, 2)) + expected_uc_info = {"key": "value"} + setattr(param, UNIVERSAL_CHECKPOINT_INFO, expected_uc_info) + + optimizer = object.__new__(DeepSpeedZeroOptimizer) + optimizer.bit16_groups = [[param]] + optimizer._enable_universal_checkpoint() + delattr(param, UNIVERSAL_CHECKPOINT_INFO) + + assert optimizer._get_universal_checkpoint_info() == expected_uc_info + + +def test_bf16_optimizer_uc_info_comes_from_cached_state(): + param = _make_param((2, 2)) + expected_uc_info = {"key": "value"} + setattr(param, UNIVERSAL_CHECKPOINT_INFO, expected_uc_info) + + optimizer = object.__new__(BF16_Optimizer) + optimizer.bf16_groups = [[param]] + optimizer._enable_universal_checkpoint() + delattr(param, UNIVERSAL_CHECKPOINT_INFO) + + assert optimizer._get_universal_checkpoint_info() == expected_uc_info + + +def test_get_param_uc_restore_meta_returns_top_level_restore_schema(): + meta = { + "partition_dim": 1, + "logical_shape": (4, 8), + "output_shape": (4, ), + "sub_param_shape": None, + "sub_param_sizes": None, + "target_partition_shape": (4, 4), + "is_bias": False, + "replicated": False, + "conversion": { + "partition_dim": 999 + }, + } + param = _make_param((4, 4), meta) + + restore_meta = _get_param_uc_restore_meta(param) + + assert restore_meta["partition_dim"] == 1 + assert restore_meta["conversion"]["partition_dim"] == 999 diff --git a/tests/unit/checkpoint/test_convert_checkpoint.py b/tests/unit/checkpoint/test_convert_checkpoint.py new file mode 100644 index 000000000000..68fdecb32e16 --- /dev/null +++ b/tests/unit/checkpoint/test_convert_checkpoint.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn as nn + +import deepspeed +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict +from unit.common import DistributedTest + + +class ModelWithSharedWeights(nn.Module): + + def __init__(self): + super().__init__() + self.layer0 = nn.Linear(100, 100) + self.layer1 = nn.Linear(200, 200) + self.layer2 = nn.Linear(300, 300) + # tie layer 1 and layer 2 + self.layer1.weight = self.layer2.weight + + +class TestCheckpointConvert(DistributedTest): + world_size = 2 + + def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir): + config = { + "train_micro_batch_size_per_gpu": 2, + "zero_allow_untested_optimizer": True, + "zero_optimization": { + "stage": 3 + }, + } + model = ModelWithSharedWeights() + optimizer = torch.optim.Adam(model.parameters()) + + deepspeed_engine, _, _, _ = deepspeed.initialize( + config=config, + model=model, + optimizer=optimizer, + ) + ds_save_dir = tmpdir / "checkpoint_ds" + deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint") + + model = ModelWithSharedWeights() + + # save checkpoint + fp32_save_dir = tmpdir / "checkpoint_fp32" + convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir) + + # load state_dict from fp32 checkpoint + state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin') + + # check shared tensor + assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight']) + + # load state_dict into model + model.load_state_dict(state_dict, strict=True) diff --git a/tests/unit/checkpoint/test_latest_checkpoint.py b/tests/unit/checkpoint/test_latest_checkpoint.py index e2d2f9db8043..5d795c4dadcf 100644 --- a/tests/unit/checkpoint/test_latest_checkpoint.py +++ b/tests/unit/checkpoint/test_latest_checkpoint.py @@ -5,10 +5,15 @@ import deepspeed +import pytest from unit.common import DistributedTest from unit.simple_model import * from unit.checkpoint.common import checkpoint_correctness_verification +from deepspeed.ops.op_builder import FusedAdamBuilder + +if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) class TestLatestCheckpoint(DistributedTest): @@ -33,8 +38,8 @@ def test_existing_latest(self, tmpdir): tmpdir=tmpdir, load_optimizer_states=True, load_lr_scheduler_states=False, - fp16=False, - empty_tag=True) + empty_tag=True, + dtype=torch.float) def test_missing_latest(self, tmpdir): config_dict = { diff --git a/tests/unit/checkpoint/test_lr_scheduler.py b/tests/unit/checkpoint/test_lr_scheduler.py index c4c6773cd474..6dd7e3279521 100644 --- a/tests/unit/checkpoint/test_lr_scheduler.py +++ b/tests/unit/checkpoint/test_lr_scheduler.py @@ -5,6 +5,7 @@ import deepspeed from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest from unit.simple_model import * @@ -22,6 +23,8 @@ class TestLRSchedulerCheckpoint(DistributedTest): def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") + if get_accelerator().device_name() == 'cpu': + pytest.skip("CPU accelerator does not support this test.") config_dict = { "train_batch_size": 2, @@ -35,9 +38,6 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload @@ -51,12 +51,16 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): } } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: global DeepSpeedZeroOptimizer_Stage3 from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 - with deepspeed.zero.Init(): + with deepspeed.zero.Init(config_dict_or_path=config_dict): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -71,6 +75,8 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") + if get_accelerator().device_name() == 'cpu': + pytest.skip("CPU accelerator does not support this test.") config_dict = { "train_batch_size": 2, @@ -81,9 +87,6 @@ def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): "lr": 1e-5 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload @@ -97,10 +100,14 @@ def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): } }, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 if zero_stage == 3: - with deepspeed.zero.Init(): + with deepspeed.zero.Init(config_dict_or_path=config_dict): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] diff --git a/tests/unit/checkpoint/test_mics_optimizer.py b/tests/unit/checkpoint/test_mics_optimizer.py new file mode 100644 index 000000000000..9e56bf3446fa --- /dev/null +++ b/tests/unit/checkpoint/test_mics_optimizer.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import deepspeed + +from deepspeed.utils.torch import required_torch_version +from unit.common import DistributedTest +from unit.simple_model import * +from unit.checkpoint.common import * + +import pytest + +if not required_torch_version(max_version=2.0): + pytest.skip("Skipping until we resolve problems with torch 2.1", allow_module_level=True) + + +class TestMiCSCheckpoint(DistributedTest): + world_size = 4 + + def _toy_model_config(self, shard_size): + + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.00015, + "betas": [0.8, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + }, + "wall_clock_breakdown": True, + "zero_optimization": { + "stage": 3, + "mics_shard_size": shard_size + } + } + + hidden_dim = 10 + with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict): + models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] + + return config_dict, hidden_dim, models + + @pytest.mark.parametrize('shard_size', [1, 2, 4]) + def test_load_optimizer_state(self, tmpdir, shard_size): + config_dict, hidden_dim, models = self._toy_model_config(shard_size) + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True) + + @pytest.mark.parametrize('shard_size', [1, 2, 4]) + def test_not_load_optimizer_state(self, tmpdir, shard_size): + config_dict, hidden_dim, models = self._toy_model_config(shard_size) + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False) + + @pytest.mark.parametrize('shard_size', [1, 2, 4]) + def test_load_module_only(self, tmpdir, shard_size): + config_dict, hidden_dim, models = self._toy_model_config(shard_size) + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True) + + @pytest.mark.parametrize('shard_size', [1, 2, 4]) + def test_save_checkpoint_on_first_partition_group(self, tmpdir, shard_size): + config_dict, _, models = self._toy_model_config(shard_size) + ds_engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=models[0], + model_parameters=models[0].parameters(), + optimizer=None) + + ds_engine.save_checkpoint(tmpdir) + if ds_engine.global_rank < shard_size: + assert ds_engine.save_non_zero_checkpoint == True + else: + assert ds_engine.save_non_zero_checkpoint == False diff --git a/tests/unit/checkpoint/test_moe_checkpoint.py b/tests/unit/checkpoint/test_moe_checkpoint.py index 470a0d51f579..89878b5d8fa9 100644 --- a/tests/unit/checkpoint/test_moe_checkpoint.py +++ b/tests/unit/checkpoint/test_moe_checkpoint.py @@ -4,10 +4,10 @@ # DeepSpeed Team from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer +from deepspeed.utils.torch import required_torch_version from unit.common import DistributedTest from unit.simple_model import * -from unit.util import required_torch_version from unit.checkpoint.common import checkpoint_correctness_verification @@ -19,7 +19,7 @@ class TestMoECheckpoint(DistributedTest): @pytest.mark.parametrize("ep_size", [4]) def test_checkpoint_moe(self, tmpdir, ep_size): - if not required_torch_version(): + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}} @@ -33,14 +33,14 @@ def test_checkpoint_moe(self, tmpdir, ep_size): tmpdir=tmpdir, load_optimizer_states=True, load_lr_scheduler_states=False, - fp16=config_dict["fp16"]["enabled"], empty_tag=True, base_optimizers=optimizers, - seq_dataloader=True) + seq_dataloader=True, + dtype=torch.float16) @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)]) def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states): - if not required_torch_version(): + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = { @@ -77,7 +77,7 @@ def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states): tmpdir=tmpdir, load_optimizer_states=load_optim_states, load_lr_scheduler_states=False, - fp16=config_dict["fp16"]["enabled"], empty_tag=True, base_optimizers=optimizers, - seq_dataloader=True) + seq_dataloader=True, + dtype=torch.float16) diff --git a/tests/unit/checkpoint/test_other_optimizer.py b/tests/unit/checkpoint/test_other_optimizer.py index 9cb8c4286880..9d623260f1dd 100644 --- a/tests/unit/checkpoint/test_other_optimizer.py +++ b/tests/unit/checkpoint/test_other_optimizer.py @@ -19,6 +19,8 @@ class TestOtherOptimizerCheckpoint(DistributedTest): @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") def test_checkpoint_unfused_optimizer(self, tmpdir): + #if not get_accelerator().is_fp16_supported(): + # pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -29,9 +31,6 @@ def test_checkpoint_unfused_optimizer(self, tmpdir): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True - }, "scheduler": { "type": "OneCycle", "params": { @@ -49,6 +48,14 @@ def test_checkpoint_unfused_optimizer(self, tmpdir): } } } + dtype = torch.float + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + dtype = torch.float16 + + # with bf16 fails with: DeepSpeed lamb optimizer requires dynamic loss scaling + # if get_accelerator().is_bf16_supported(): + # config_dict["bf16"] = {"enabled": True} args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -59,16 +66,20 @@ def test_checkpoint_unfused_optimizer(self, tmpdir): models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - load_optimizer_states=True) + load_optimizer_states=True, + dtype=dtype) # Ignore optimizer states checkpoint_correctness_verification(config_dict, models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - load_optimizer_states=False) + load_optimizer_states=False, + dtype=dtype) def test_checkpoint_fused_optimizer(self, tmpdir): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -81,10 +92,11 @@ def test_checkpoint_fused_optimizer(self, tmpdir): "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - } } + dtype = torch.float + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + dtype = torch.float16 args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -95,14 +107,16 @@ def test_checkpoint_fused_optimizer(self, tmpdir): models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - load_optimizer_states=True) + load_optimizer_states=True, + dtype=dtype) # Ignore optimizer states checkpoint_correctness_verification(config_dict, models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - load_optimizer_states=False) + load_optimizer_states=False, + dtype=dtype) def test_checkpoint_fp32_optimizer(self, tmpdir): config_dict = { @@ -129,4 +143,4 @@ def test_checkpoint_fp32_optimizer(self, tmpdir): models=models, hidden_dim=hidden_dim, tmpdir=tmpdir, - fp16=False) + dtype=torch.float32) diff --git a/tests/unit/checkpoint/test_pipeline.py b/tests/unit/checkpoint/test_pipeline.py index 99f1ba2ec433..c6c228ccada7 100644 --- a/tests/unit/checkpoint/test_pipeline.py +++ b/tests/unit/checkpoint/test_pipeline.py @@ -58,10 +58,10 @@ def test_checkpoint_pipe_engine(self, zero_stage, tmpdir): models=models, hidden_dim=models[0].hidden_dim, tmpdir=tmpdir, - fp16=config_dict['fp16']['enabled'], load_optimizer_states=True, load_lr_scheduler_states=True, - train_batch=True) + train_batch=True, + dtype=torch.float16 if zero_stage > 0 else torch.float32) @pytest.mark.parametrize( "base_topo,test_topo", diff --git a/tests/unit/checkpoint/test_shared_weights.py b/tests/unit/checkpoint/test_shared_weights.py new file mode 100644 index 000000000000..ed69073fb81c --- /dev/null +++ b/tests/unit/checkpoint/test_shared_weights.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn as nn + +import deepspeed +from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint +from unit.common import DistributedTest + + +class ModelWithSharedWeights(nn.Module): + + def __init__(self): + super().__init__() + self.layer0 = nn.Linear(100, 100) + self.layer1 = nn.Linear(200, 200) + self.layer2 = nn.Linear(300, 300) + # tie layer 1 and layer 2 + self.layer1.weight = self.layer2.weight + + +class TestCheckpointSharedWeights(DistributedTest): + world_size = 2 + + def test_checkpoint_shared_weights(self, tmp_path): + config = { + "train_micro_batch_size_per_gpu": 2, + "zero_allow_untested_optimizer": True, + "zero_optimization": { + "stage": 2 + }, + } + model = ModelWithSharedWeights() + optimizer = torch.optim.Adam(model.parameters()) + + deepspeed_engine, _, _, _ = deepspeed.initialize( + config=config, + model=model, + optimizer=optimizer, + ) + filename = tmp_path / "checkpoint.pt" + deepspeed_engine.save_checkpoint(filename, tag="checkpoint") + + model = ModelWithSharedWeights() + state_dict = get_fp32_state_dict_from_zero_checkpoint(filename, tag="checkpoint") + model.load_state_dict(state_dict, strict=True) diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py new file mode 100644 index 000000000000..27e151103cc4 --- /dev/null +++ b/tests/unit/checkpoint/test_universal_checkpoint.py @@ -0,0 +1,272 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import math + +import deepspeed +from types import SimpleNamespace +from torch.utils._pytree import tree_map + +from deepspeed.utils.torch import required_torch_version +from deepspeed.checkpoint import UNIVERSAL_CHECKPOINT_INFO +from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal + +from unit.common import DistributedTest, DistributedFixture +from unit.simple_model import * +from unit.util import bf16_required_version_check + +from unit.checkpoint.common import compare_opt_state_dicts, compare_state_dicts + +import pytest +import deepspeed.comm as dist + + +def get_expected_mismatch_keys(): + # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to + # false positive mismatches in checkpoint state comparisons. + # Newer torch versions store tensor ids as 0, 1, 2, ... + return [] if required_torch_version(min_version=1.4) else ['params'] + + +def maybe_step(t): + return not torch.is_tensor(t) or (t.device.type == 'cpu' and t.numel() == 1) + + +def gather_opt_state(optimizer_state): + + def gather_tensor(t): + + if maybe_step(t): + return t + else: + buffer = [torch.zeros_like(t.flatten()) for _ in range(dist.get_world_size())] + dist.all_gather(buffer, t.flatten()) + return torch.cat(buffer) + + return tree_map(gather_tensor, optimizer_state) + + +def remove_pad_in_opt_state(optimizer_state, num_params): + + def remove_pad(t): + if maybe_step(t): + return t + else: + return t[:num_params] + + return tree_map(remove_pad, optimizer_state) + + +CP_TAG = "test_tag" + + +def init_ds_engine(model, ds_config, use_torch_adam): + + if use_torch_adam: + ds_optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + del ds_config["optimizer"] + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=ds_optimizer) + else: + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + + return model + + +def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, world_size): + if dtype == torch.bfloat16 and not bf16_required_version_check(): + return + + test_step = 8 + + model = SimpleModel(hidden_dim, nlayers=2) + model = init_ds_engine(model, ds_config, use_torch_adam) + data_loader = random_dataloader(model=model, + total_samples=test_step, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + if ds_config["zero_optimization"]["stage"] == 3: + model.optimizer._set_fp32_optimizer_param_groups() + sd = model.optimizer.optimizer.state_dict() if load_optim else None + model.optimizer._clear_fp32_optimizer_param_groups() + else: + sd = model.optimizer.optimizer.state_dict() if load_optim else None + + client_state = {} + client_state[UNIVERSAL_CHECKPOINT_INFO] = {} + client_state['iteration'] = test_step + model.save_checkpoint(tmpdir, tag=CP_TAG, client_state=client_state) + + cp_dir = os.path.join(tmpdir, CP_TAG) + univ_cp_dir = f"{cp_dir}_universal" + + args = SimpleNamespace(input_folder=cp_dir, + output_folder=univ_cp_dir, + num_extract_workers=1, + num_merge_workers=1, + keep_temp_folder=False, + strict=True, + inject_missing_state=False) + + dist.barrier() + if dist.get_rank() == 0: + convert_to_universal(args) + + model_state = model.state_dict() + optimizer_state = None + if load_optim: + if ds_config["zero_optimization"]["stage"] == 3: + model.optimizer._set_fp32_optimizer_param_groups() + optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict()) + model.optimizer._clear_fp32_optimizer_param_groups() + update_gathered_stage3_optimizer(optimizer_state, model._get_zero_param_shapes(), world_size) + else: + optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict()) + + if dist.get_rank() == 0: + torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt")) + + dist.barrier() + model.destroy() + + +@pytest.fixture +def ds_config(zero_stage, dtype, sub_group_size): + ds_config = { + "train_batch_size": 8, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if dtype == torch.float16: + ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + ds_config["bf16"] = {"enabled": True} + if sub_group_size > 0: + ds_config["zero_optimization"]["sub_group_size"] = sub_group_size + return ds_config + + +class _baseline(DistributedFixture): + world_size = None + + def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam): + hidden_dim = 10 + train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, self.world_size) + + +class baseline_ws2(_baseline): + world_size = 2 + + +class baseline_ws4(_baseline): + world_size = 4 + + +# Stage3 use shard parameter, need to reorganize the optimizer parameters. +def update_gathered_stage3_optimizer(optimizer_state, param_shapes, world_size): + for sub_group_id, group in enumerate(optimizer_state["param_groups"]): + group["params"] = None + + new_state = {} + for sub_group_id, sub_group_param_shape in enumerate(param_shapes): + total_numel = optimizer_state['state'][sub_group_id]['exp_avg'].numel() + assert total_numel % world_size == 0 + numel_per_rank = total_numel // world_size + param_offset_in_current_rank = 0 + for param_name, param_shape in sub_group_param_shape.items(): + param_numel = param_shape.numel() + param_partition_numel = math.ceil(param_numel / world_size) + param_optimizer_tensor = { + "exp_avg": torch.zeros(param_numel), + "exp_avg_sq": torch.zeros(param_numel), + "step": optimizer_state['state'][sub_group_id]['step'], + } + for key in ["exp_avg", "exp_avg_sq"]: + write_offset = 0 + for rank in range(world_size): + offset = param_offset_in_current_rank + rank * numel_per_rank + length = min(param_partition_numel, param_numel - rank * param_partition_numel) + tmp = optimizer_state['state'][sub_group_id][key].narrow(0, offset, length) + param_optimizer_tensor[key].narrow(0, write_offset, length).copy_(tmp) + write_offset += length + param_offset_in_current_rank += param_partition_numel + new_state[param_name] = param_optimizer_tensor + optimizer_state["state"] = new_state + + +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("zero_stage", [1, 3]) +@pytest.mark.parametrize("use_torch_adam", [False, True]) +@pytest.mark.parametrize("load_optim", [False, True]) +@pytest.mark.parametrize("sub_group_size", [-1, 100]) +class TestZeROUniversalCheckpointDP(DistributedTest): + + def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam, world_size): + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + hidden_dim = 10 + loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False) + + ds_config["checkpoint"] = {"load_universal": True} + univ_model = SimpleModel(hidden_dim, nlayers=2) + univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam) + univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim) + + model_state = univ_model.state_dict() + compare_state_dicts(model_state, loaded_model_state) + + if load_optim: + if ds_config["zero_optimization"]["stage"] == 3: + univ_model.optimizer._set_fp32_optimizer_param_groups() + optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict()) + univ_model.optimizer._clear_fp32_optimizer_param_groups() + update_gathered_stage3_optimizer(optimizer_state, univ_model._get_zero_param_shapes(), world_size) + else: + optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict()) + # padding sizes may differ when dp sizes are different + param_count = sum(p.numel() for p in univ_model.parameters()) + optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count) + loaded_optimizer_state = remove_pad_in_opt_state(loaded_optimizer_state, param_count) + + compare_opt_state_dicts(optimizer_state, loaded_optimizer_state, get_expected_mismatch_keys()) + + # Run training again to verify that the optimizer has necessary states + test_step = 8 + data_loader = random_dataloader(model=univ_model, + total_samples=test_step, + hidden_dim=hidden_dim, + device=univ_model.device, + dtype=dtype) + for batch in data_loader: + loss = univ_model(batch[0], batch[1]) + univ_model.backward(loss) + univ_model.step() + + univ_model.destroy() + + @pytest.mark.world_size(2) + def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2) + + @pytest.mark.world_size(2) + def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2) + + @pytest.mark.world_size(4) + def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 4) diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py index 6a59da1546c8..9ad785071602 100644 --- a/tests/unit/checkpoint/test_zero_optimizer.py +++ b/tests/unit/checkpoint/test_zero_optimizer.py @@ -4,11 +4,15 @@ # DeepSpeed Team import deepspeed +from types import SimpleNamespace from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.checkpoint.utils import clone_tensors_for_torch_save, get_model_ckpt_name_for_rank +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero import ZeroParamStatus +from deepspeed.utils.torch import required_torch_version from unit.common import DistributedTest, DistributedFixture from unit.simple_model import * -from unit.util import required_minimum_torch_version from unit.checkpoint.common import * @@ -18,7 +22,31 @@ class TestZeROCheckpoint(DistributedTest): world_size = 2 - @pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer', [(1, False, 'Adam'), (2, False, 'Adam'), + @pytest.mark.parametrize('zero_stage', [3]) + def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage): + config_dict = { + "train_batch_size": 2, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": zero_stage, + "pipeline_loading_checkpoint": True, + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + hidden_dim = 10 + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] + + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True) + + @pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer', [(0, False, 'Adam'), (1, False, 'Adam'), + (2, False, 'Adam'), (2, True, 'deepspeed_adam'), (3, False, 'Adam'), (3, True, 'deepspeed_adam')]) @@ -38,20 +66,20 @@ def test_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_op "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "wall_clock_breakdown": True, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 if zero_stage == 3: - with deepspeed.zero.Init(): + with deepspeed.zero.Init(config_dict_or_path=config_dict): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -78,20 +106,21 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, ada "weight_decay": 3e-7 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": zero_stage, "cpu_offload": use_cpu_offload } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 if zero_stage == 3: global DeepSpeedZeroOptimizer_Stage3 from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 - with deepspeed.zero.Init(): + with deepspeed.zero.Init(config_dict_or_path=config_dict): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -108,11 +137,11 @@ def test_hybrid_optimizer_state(self, tmpdir, zero_stage): "stage": zero_stage }, "zero_allow_untested_optimizer": True, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)] optimizers = [HybridStateOptimizer(model.parameters()) for model in models] @@ -126,23 +155,25 @@ def test_hybrid_optimizer_state(self, tmpdir, zero_stage): @pytest.mark.parametrize('zero_stage', [0, 1, 2, 3]) def test_load_module_only(self, tmpdir, zero_stage): + if zero_stage == 0 and get_accelerator().device_name() == "cpu": + pytest.skip("CPU Accelerator does not support this test") config_dict = { "train_batch_size": 2, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 if zero_stage == 3: - with deepspeed.zero.Init(): + with deepspeed.zero.Init(config_dict_or_path=config_dict): models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] else: models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)] @@ -154,24 +185,24 @@ class ws4_model_checkpoint(DistributedFixture): world_size = 4 def run(self, class_tmpdir, elastic_save, load_optim): - ds_config = { + config_dict = { "train_batch_size": 4, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 2, "elastic_checkpoint": elastic_save } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) @@ -183,6 +214,33 @@ def run(self, class_tmpdir, elastic_save, load_optim): model.save_checkpoint(class_tmpdir) +class ws4_model_checkpoint_zeropp(DistributedFixture): + + world_size = 4 + + def run(self, class_tmpdir): + config_dict = { + "train_batch_size": 4, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": 3, + "zero_hpz_partition_size": 2, + } + } + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + for param in model.parameters(): + param.data = torch.ones_like(param.data, device=param.data.device, requires_grad=False) + + # save model and zero checkpoint + torch.save(model.state_dict(), os.path.join(class_tmpdir, "model.pt")) + ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) + ds_model.save_checkpoint(class_tmpdir) + + @pytest.mark.parametrize("elastic_save", [True, False]) @pytest.mark.parametrize("elastic_load", [True, False]) @pytest.mark.parametrize("load_optim", [True, False]) @@ -190,50 +248,54 @@ class TestZeROElasticCheckpoint(DistributedTest): world_size = 2 def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, load_optim): - ds_config = { + config_dict = { "train_batch_size": 2, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 2, "elastic_checkpoint": elastic_save } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to # false positive mismatches in checkpoint state comparisons. # Newer torch versions store tensor ids as 0, 1, 2, ... - expected_mismatch_keys = [] if required_minimum_torch_version(1, 4) else ['params'] + expected_mismatch_keys = [] if required_torch_version(min_version=1.4) else ['params'] models = [SimpleModel(hidden_dim) for _ in range(2)] - model, _, _, _ = deepspeed.initialize(config=ds_config, + model, _, _, _ = deepspeed.initialize(config=config_dict, model=models[0], model_parameters=models[0].parameters()) - data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device) + run_steps = 8 + data_loader = random_dataloader(model=model, + total_samples=run_steps, + hidden_dim=hidden_dim, + device=model.device) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() if load_optim: - torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, 'opt-state-dict')) + opt_state_dict_file = f'opt-state-dict_rank{dist.get_rank()}' + torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, opt_state_dict_file)) model.save_checkpoint(tmpdir) - ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load - model, _, _, _ = deepspeed.initialize(config=ds_config, + config_dict["zero_optimization"]["elastic_checkpoint"] = elastic_load + model, _, _, _ = deepspeed.initialize(config=config_dict, model=models[1], model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) if load_optim: - saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False) curr_sd = model.optimizer.optimizer.state_dict() - for curr_param_group, saved_param_group in zip(curr_sd['param_groups'], saved_sd['param_groups']): - compare_state_dicts(curr_param_group, saved_param_group, expected_mismatch_keys) + compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys) data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device) for n, batch in enumerate(data_loader): @@ -243,25 +305,25 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l def test_elastic_checkpoint_change_dp(self, ws4_model_checkpoint, class_tmpdir, elastic_save, elastic_load, load_optim): - ds_config = { + config_dict = { "train_batch_size": 4, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 2, "elastic_checkpoint": elastic_load } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 model = SimpleModel(hidden_dim) # Load checkpoint with dp world size = 2 - model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) if load_optim: with pytest.raises(deepspeed.runtime.zero.utils.ZeRORuntimeException): model.load_checkpoint(class_tmpdir, load_optimizer_states=load_optim) @@ -279,14 +341,14 @@ def test_immediate_save_load(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -299,34 +361,33 @@ def test_immediate_save_load(self, tmpdir, zero_stage): @pytest.mark.parametrize('zero_stage', [0, 1, 2, 3]) def test_load_immediate_save(self, tmpdir, zero_stage): + if zero_stage == 0 and get_accelerator().device_name() == "cpu": + pytest.skip("CPU Accelerator does not support this test") config_dict = { "train_batch_size": 4, "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 model = SimpleModel(hidden_dim) # 1. pretrain a model and save it - dtype = torch.half ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) - data_loader = random_dataloader(model=ds_model, - total_samples=1, - hidden_dim=hidden_dim, - device=ds_model.device, - dtype=dtype) + data_loader = random_dataloader(model=ds_model, total_samples=1, hidden_dim=hidden_dim, device=ds_model.device) for _, batch in enumerate(data_loader): loss = ds_model(batch[0], batch[1]) ds_model.backward(loss) ds_model.step() + + ds_model.empty_partition_cache() ds_model.save_checkpoint(tmpdir) # 2. load and immediately save a model with a fresh ds engine @@ -343,10 +404,6 @@ def test_save_before_accum_grad_is_done(self, tmpdir, zero_stage): "optimizer": { "type": 'Adam' }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": zero_stage, "stage3_gather_fp16_weights_on_model_save": True, @@ -355,6 +412,10 @@ def test_save_before_accum_grad_is_done(self, tmpdir, zero_stage): "train_micro_batch_size_per_gpu": 1, "train_batch_size": 4, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -363,19 +424,337 @@ def test_save_before_accum_grad_is_done(self, tmpdir, zero_stage): # So we config grad_accum=2 and step only once and save_16bit_model ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) - data_loader = random_dataloader(model=ds_model, - total_samples=2, - hidden_dim=hidden_dim, - device=ds_model.device, - dtype=torch.half) + data_loader = random_dataloader(model=ds_model, total_samples=2, hidden_dim=hidden_dim, device=ds_model.device) batch = next(iter(data_loader)) loss = ds_model(batch[0], batch[1]) ds_model.backward(loss) ds_model.step() + ds_model.empty_partition_cache() + # we stepped only once, and now save 16bit model before gradient_accumulation_steps=2 is complete ds_model.save_16bit_model(tmpdir, "model.pt") # let's test just as well that we can save the checkpoint too ds_model.save_checkpoint(tmpdir) + + +class TestZeROCheckpointFrozenWeights(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + def test_load_optimizer_state(self, tmpdir, zero_stage): + + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.00015, + "betas": [0.8, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "wall_clock_breakdown": True, + "zero_optimization": { + "stage": zero_stage + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + hidden_dim = 10 + + with deepspeed.zero.Init(enabled=zero_stage == 3, config_dict_or_path=config_dict): + models = [SimpleFrozenModel(hidden_dim, empty_grad=False) for _ in range(2)] + + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True) + + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + def test_not_load_optimizer_state(self, tmpdir, zero_stage): + + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.00015, + "betas": [0.8, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "zero_optimization": { + "stage": zero_stage + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + hidden_dim = 10 + + with deepspeed.zero.Init(enabled=zero_stage == 3, config_dict_or_path=config_dict): + models = [SimpleFrozenModel(hidden_dim, empty_grad=False) for _ in range(2)] + + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False) + + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + def test_load_module_only(self, tmpdir, zero_stage): + config_dict = { + "train_batch_size": 2, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + hidden_dim = 10 + + with deepspeed.zero.Init(enabled=zero_stage == 3, config_dict_or_path=config_dict): + models = [SimpleFrozenModel(hidden_dim, empty_grad=False) for _ in range(2)] + + checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True) + + @pytest.mark.parametrize('zero_stage', [1, 2]) + def test_save_exclude_frozen_weights(self, tmpdir, zero_stage): + world_size = 1 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + hidden_dim = 10 + + model = SimpleFrozenModel(hidden_dim, empty_grad=False) + + ds_engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + # Validate backwards-compatibility of including frozen parameters in checkpoint + all_ckpt_folder = os.path.join(tmpdir, 'all_params') + ds_engine.save_checkpoint(all_ckpt_folder) + all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00') + loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module'] + all_param_names = set([n for n, p in model.named_parameters()]) + assert set(loaded_all_param_model.keys()) == all_param_names + + # Validate exclusion of frozen parameters + trainable_ckpt_folder = os.path.join(tmpdir, 'no_frozen_params') + ds_engine.save_checkpoint(trainable_ckpt_folder, exclude_frozen_parameters=True) + + trainable_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(trainable_ckpt_folder, 'global_step0'), '00') + + # Excluding frozen parameters should reduce checkpoint size + assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file) + + loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module'] + frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad]) + loaded_trainable_param_names = set(loaded_trainable_param_model.keys()) + overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names) + assert len(overlap_names) == 0 + + trainable_param_names = set([n for n, p in model.named_parameters() if p.requires_grad]) + assert loaded_trainable_param_names == trainable_param_names + + @pytest.mark.parametrize('zero_stage', [1, 2]) + def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage): + world_size = 1 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + hidden_dim = 10 + + model = SimpleFrozenModel(hidden_dim, empty_grad=False) + + ds_engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + # Validate custom state_dict model + state_dict_bk = model.state_dict + model.state_dict = model.custom_state_dict + custom_state_dict_ckpt_folder = os.path.join(tmpdir, 'custom_state_dict') + ds_engine.save_checkpoint(custom_state_dict_ckpt_folder, exclude_frozen_parameters=True) + + custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank( + os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00') + loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module'] + loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys()) + + custom_state_dict_param_names = set([k for k, v in model.state_dict().items()]) + trainable_param_names = set([n for n, p in model.named_parameters() if p.requires_grad]) + overlap_names = set.intersection(custom_state_dict_param_names, trainable_param_names) + + assert loaded_custom_state_dict_param_names == overlap_names + + model.state_dict = state_dict_bk + + +class TestSaveTensorClone(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('zero_stage', [1, 2]) + @pytest.mark.parametrize('use_cpu_device', [True, False]) + def test_save_tensor_clone(self, tmpdir, zero_stage, use_cpu_device): + + config_dict = { + "optimizer": { + "type": "AdamW", + }, + "zero_optimization": { + "stage": zero_stage + }, + "train_batch_size": 1, + "train_micro_batch_size_per_gpu": 1 + } + hidden_dim = 1024 + model = SimpleModel(hidden_dim, nlayers=4).half() + ref_model_state_dict = model.state_dict() + + ds_engine, _, _, _ = deepspeed.initialize(model=model, config_params=config_dict) + clone_device = torch.device('cpu') if use_cpu_device else get_accelerator().current_device() + clone_state_dict = clone_tensors_for_torch_save(ds_engine.module.state_dict()) + compare_state_dicts(ref_model_state_dict, clone_state_dict) + + ref_ckpt_file = os.path.join(tmpdir, 'ref_ckpt.pt') + torch.save(ref_model_state_dict, ref_ckpt_file) + clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt') + torch.save(clone_state_dict, clone_ckpt_file) + + compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False), + torch.load(clone_ckpt_file, weights_only=False)) + + +class TestZeRONonDistributed(DistributedTest): + world_size = 1 + init_distributed = False + + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + def test_chmod_exception_handling(self, monkeypatch, zero_stage): + + config_dict = { + "optimizer": { + "type": "AdamW" + }, + "train_batch_size": 1, + "zero_optimization": { + "stage": zero_stage + } + } + args = SimpleNamespace(local_rank=0) + net = SimpleModel(hidden_dim=4) + engine, _, _, _ = deepspeed.initialize(args=args, + config=config_dict, + model=net, + model_parameters=net.parameters()) + + log_called = False + + def mock_logger_info(message, *args, **kwargs): + nonlocal log_called + log_called = True + + monkeypatch.setattr("deepspeed.utils.logger.info", mock_logger_info) + """ + This is presented for use-cases like Azure Storage File Share (where permissions are not allowed) + We use a fake file for this test (file not existing would present a similar issue as not being able to chmod) + """ + fake_recovery_script_dst = os.path.join("tmp", "zero_to_fp32.py") + engine._change_recovery_script_permissions(fake_recovery_script_dst) + + assert log_called, "Expected deepspeed.utils.logger.info to be called." + + +class TestZeROPPLoadCheckpoint(DistributedTest): + + world_size = 4 + + def test_load_zeropp_model(self, ws4_model_checkpoint_zeropp, class_tmpdir): + config_dict = { + "train_batch_size": 4, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": 3, + "zero_hpz_partition_size": 2, + "stage3_param_persistence_threshold": 1 + } + } + + # Init model and load saved model + hidden_dim = 10 + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim) + ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) + + with deepspeed.zero.GatheredParameters(ds_model.module.parameters(), modifier_rank=0): + if dist.get_rank() == 0: + state_dict = torch.load(os.path.join(class_tmpdir, "model.pt")) + ds_model.module.load_state_dict(state_dict) + + # Check the parameters after gather + params_to_gather = [p for p in ds_model.module.parameters() if p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + if len(params_to_gather) > 0: + handle = params_to_gather[0].all_gather_coalesced(params_to_gather) + handle.wait() + for ds_param in params_to_gather: + for v in ds_param.data.cpu().flatten().numpy(): + assert v == 1.0 + + def test_load_zeropp_checkpoint(self, ws4_model_checkpoint_zeropp, class_tmpdir): + config_dict = { + "train_batch_size": 4, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + "stage": 3, + "zero_hpz_partition_size": 2, + "stage3_param_persistence_threshold": 1 + } + } + + # Init model and load zero checkpoint + hidden_dim = 10 + model = SimpleModel(hidden_dim) + ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None) + ds_model.load_checkpoint(class_tmpdir, + load_optimizer_states=True, + load_lr_scheduler_states=False, + load_module_only=False) + + # Check the parameters after gather + params_to_gather = [p for p in ds_model.module.parameters() if p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + if len(params_to_gather) > 0: + handle = params_to_gather[0].all_gather_coalesced(params_to_gather) + handle.wait() + for ds_param in params_to_gather: + for v in ds_param.data.cpu().flatten().numpy(): + assert v == 1.0 diff --git a/tests/unit/comm/test_dist.py b/tests/unit/comm/test_dist.py index b98a1f52e874..0cbc611dc38c 100644 --- a/tests/unit/comm/test_dist.py +++ b/tests/unit/comm/test_dist.py @@ -13,6 +13,10 @@ from deepspeed.accelerator import get_accelerator import pytest +from deepspeed.ops.op_builder import FusedAdamBuilder + +if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) class TestInit(DistributedTest): @@ -106,17 +110,45 @@ def test(self, distributed_fixture, class_tmpdir, val1, val2): assert int(os.environ["WORLD_SIZE"]) == 1 +@pytest.mark.parametrize("num_elements", [128, 3]) class TestDistAllReduce(DistributedTest): - world_size = [1, 2, 4] - - def test(self): - x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1) + device_count = get_accelerator().device_count() + if device_count >= 4: + world_size = [1, 2, 4] + elif device_count >= 2: + world_size = [1, 2] + else: + world_size = [1] + + def test(self, num_elements): + x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1) sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2 - result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks + result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks dist.all_reduce(x) assert torch.all(x == result) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("num_elements", [128, 3]) +class TestDistInferenceAllReduce(DistributedTest): + device_count = get_accelerator().device_count() + if device_count >= 4: + world_size = [1, 2, 4] + elif device_count >= 2: + world_size = [1, 2] + else: + world_size = [1] + + def test(self, dtype, num_elements): + x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1) + sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2 + result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks + result = result.to(dtype) + x = x.to(dtype) + dist.inference_all_reduce(x) + assert torch.all(x == result) + + @pytest.mark.parametrize("dist_init_required", [True, False, None]) class TestDistInit(DistributedTest): init_distributed = False diff --git a/tests/unit/common.py b/tests/unit/common.py index 5eca38cc83f8..426b974eefb4 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -4,27 +4,36 @@ # DeepSpeed Team import os +import re import time import inspect +import socket +import subprocess from abc import ABC, abstractmethod from pathlib import Path +import random +import tempfile +import numpy as np +from typing import Callable, Any import torch import torch.multiprocessing as mp import deepspeed from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist -from torch.multiprocessing import Process + +from .util import torch_assert_close import pytest from _pytest.outcomes import Skipped from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker -# Worker timeout *after* the first worker has completed. -DEEPSPEED_UNIT_WORKER_TIMEOUT = 120 - # Worker timeout for tests that hang -DEEPSPEED_TEST_TIMEOUT = 600 +DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600')) + + +def is_rocm_pytorch(): + return hasattr(torch.version, 'hip') and torch.version.hip is not None def get_xdist_worker_id(): @@ -35,12 +44,38 @@ def get_xdist_worker_id(): return None -def get_master_port(): - master_port = os.environ.get('DS_TEST_PORT', '29503') +def get_master_port(base_port=29500, port_range_size=1000): xdist_worker_id = get_xdist_worker_id() if xdist_worker_id is not None: - master_port = str(int(master_port) + xdist_worker_id) - return master_port + # Make xdist workers use different port ranges to avoid race conditions + base_port += port_range_size * xdist_worker_id + + # Select first open port in range + port = base_port + max_port = base_port + port_range_size + sock = socket.socket() + while port < max_port: + try: + sock.bind(('', port)) + sock.close() + return str(port) + except OSError: + port += 1 + raise IOError('no free ports') + + +def _get_cpu_socket_count(): + import shlex + p1 = subprocess.Popen(shlex.split("cat /proc/cpuinfo"), stdout=subprocess.PIPE) + p2 = subprocess.Popen(["grep", "physical id"], stdin=p1.stdout, stdout=subprocess.PIPE) + p1.stdout.close() + p3 = subprocess.Popen(shlex.split("sort -u"), stdin=p2.stdout, stdout=subprocess.PIPE) + p2.stdout.close() + p4 = subprocess.Popen(shlex.split("wc -l"), stdin=p3.stdout, stdout=subprocess.PIPE) + p3.stdout.close() + r = int(p4.communicate()[0]) + p4.stdout.close() + return r def set_accelerator_visible(): @@ -50,28 +85,42 @@ def set_accelerator_visible(): xdist_worker_id = 0 if cuda_visible is None: # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead - import subprocess if get_accelerator().device_name() == 'cuda': - is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None - if is_rocm_pytorch: + if is_rocm_pytorch(): rocm_smi = subprocess.check_output(['rocm-smi', '--showid']) gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n')) - num_gpus = len(list(gpu_ids)) + num_accelerators = len(list(gpu_ids)) else: nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus']) - num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n')) - else: - assert get_accelerator().device_name() == 'xpu' - import re + num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n')) + elif get_accelerator().device_name() == 'xpu': clinfo = subprocess.check_output(['clinfo']) lines = clinfo.decode('utf-8').strip().split('\n') - num_gpus = 0 + num_accelerators = 0 for line in lines: match = re.search('Device Type.*GPU', line) if match: - num_gpus += 1 + num_accelerators += 1 + elif get_accelerator().device_name() == 'hpu': + try: + hl_smi = subprocess.check_output(['hl-smi', "-L"]) + num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode()) + except FileNotFoundError: + sim_list = subprocess.check_output(['ls', '-1', '/dev/accel']) + num_accelerators = re.findall(r"accel(\d+)", sim_list.decode()) + num_accelerators = sorted(num_accelerators, key=int) + os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators) + elif get_accelerator().device_name() == 'npu': + npu_smi = subprocess.check_output(['npu-smi', 'info', '-l']) + num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip()) + else: + assert get_accelerator().device_name() == 'cpu' + num_accelerators = _get_cpu_socket_count() - cuda_visible = ",".join(map(str, range(num_gpus))) + if isinstance(num_accelerators, list): + cuda_visible = ",".join(num_accelerators) + else: + cuda_visible = ",".join(map(str, range(num_accelerators))) # rotate list based on xdist worker id, example below # wid=0 -> ['0', '1', '2', '3'] @@ -93,22 +142,22 @@ class DistributedExec(ABC): init_distributed = True set_dist_env = True requires_cuda_env = True + reuse_dist_env = False + non_daemonic_procs = False + _pool_cache = {} + exec_timeout = DEEPSPEED_TEST_TIMEOUT @abstractmethod def run(self): ... - def __call__(self, request=None): + def __call__(self, request): self._fixture_kwargs = self._get_fixture_kwargs(request, self.run) world_size = self.world_size if self.requires_cuda_env and not get_accelerator().is_available(): pytest.skip("only supported in accelerator environments.") - if isinstance(world_size, int): - world_size = [world_size] - for procs in world_size: - self._launch_procs(procs) - time.sleep(0.5) + self._launch_with_file_store(request, world_size) def _get_fixture_kwargs(self, request, func): if not request: @@ -124,24 +173,64 @@ def _get_fixture_kwargs(self, request, func): pass # test methods can have kwargs that are not fixtures return fixture_kwargs - def _launch_procs(self, num_procs): - if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: - pytest.skip( - f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" - ) - mp.set_start_method('forkserver', force=True) + def _launch_daemonic_procs(self, num_procs, init_method): + # Create process pool or use cached one + master_port = None + + if get_accelerator().device_name() == 'hpu': + if self.reuse_dist_env: + print("Ignoring reuse_dist_env for hpu") + self.reuse_dist_env = False + + if self.reuse_dist_env: + if num_procs not in self._pool_cache: + self._pool_cache[num_procs] = mp.Pool(processes=num_procs) + master_port = get_master_port() + pool = self._pool_cache[num_procs] + else: + pool = mp.Pool(processes=num_procs) + master_port = get_master_port() + + # Run the test + args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)] + skip_msgs_async = pool.starmap_async(self._dist_run, args) + + try: + skip_msgs = skip_msgs_async.get(self.exec_timeout) + except mp.TimeoutError: + # Shortcut to exit pytest in the case of a hanged test. This + # usually means an environment error and the rest of tests will + # hang (causing super long unit test runtimes) + pytest.exit("Test hanged, exiting", returncode=1) + finally: + # Regardless of the outcome, ensure proper teardown + # Tear down distributed environment and close process pools + self._close_pool(pool, num_procs) + + # If we skipped a test, propagate that to this process + if any(skip_msgs): + assert len(set(skip_msgs)) == 1, "Multiple different skip messages received" + pytest.skip(skip_msgs[0]) + + def _launch_non_daemonic_procs(self, num_procs, init_method): + assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes" + + master_port = get_master_port() skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason processes = [] + prev_start_method = mp.get_start_method() + mp.set_start_method('spawn', force=True) for local_rank in range(num_procs): - p = Process(target=self._dist_init, args=(local_rank, num_procs, skip_msg)) + p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg)) p.start() processes.append(p) + mp.set_start_method(prev_start_method, force=True) # Now loop and wait for a test to complete. The spin-wait here isn't a big # deal because the number of processes will be O(#GPUs) << O(#CPUs). any_done = False start = time.time() - while (not any_done) and ((time.time() - start) < DEEPSPEED_TEST_TIMEOUT): + while (not any_done) and ((time.time() - start) < self.exec_timeout): for p in processes: if not p.is_alive(): any_done = True @@ -152,11 +241,11 @@ def _launch_procs(self, num_procs): if not any_done: for p in processes: p.terminate() - pytest.exit("Test hanged, exiting", returncode=0) + pytest.exit("Test hanged, exiting", returncode=1) # Wait for all other processes to complete for p in processes: - p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT) + p.join(self.exec_timeout) failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0] for rank, p in failed: @@ -174,43 +263,100 @@ def _launch_procs(self, num_procs): # add a check here to assert all exit messages are equal pytest.skip(skip_msg.get()) - def _dist_init(self, local_rank, num_procs, skip_msg): - """Initialize deepspeed.comm and execute the user function. """ - if self.set_dist_env: - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = get_master_port() - os.environ['LOCAL_RANK'] = str(local_rank) - # NOTE: unit tests don't support multi-node so local_rank == global rank - os.environ['RANK'] = str(local_rank) - os.environ['WORLD_SIZE'] = str(num_procs) + def _launch_procs(self, num_procs, init_method): + # Verify we have enough accelerator devices to run this test + if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: + pytest.skip( + f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" + ) - # turn off NCCL logging if set - os.environ.pop('NCCL_DEBUG', None) + if get_accelerator().device_name() == 'xpu': + self.non_daemonic_procs = True + self.reuse_dist_env = False - if get_accelerator().is_available(): - set_accelerator_visible() + # Allow disabling reuse_dist_env via environment variable. + # This is useful for CI full test runs where reusing distributed environment + # can cause pool worker cleanup to hang after tests complete. + if os.environ.get('DS_DISABLE_REUSE_DIST_ENV', '0') == '1': + self.reuse_dist_env = False - if self.init_distributed: - deepspeed.init_distributed(dist_backend=self.backend) - dist.barrier() + # Set start method to `forkserver` (or `fork`) + mp.set_start_method('forkserver', force=True) + + if self.non_daemonic_procs: + self._launch_non_daemonic_procs(num_procs, init_method) + else: + self._launch_daemonic_procs(num_procs, init_method) - if get_accelerator().is_available(): - get_accelerator().set_device(local_rank) + def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""): + if dist.is_initialized(): + if get_accelerator().is_available(): + # local_rank might not match the rank in the previous run if you are reusing the environment + get_accelerator().set_device(dist.get_rank()) + else: + """ Initialize deepspeed.comm and execute the user function. """ + if self.set_dist_env: + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(master_port) + os.environ['LOCAL_RANK'] = str(local_rank) + # NOTE: unit tests don't support multi-node so local_rank == global rank + os.environ['RANK'] = str(local_rank) + # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE + # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly + os.environ['LOCAL_SIZE'] = str(num_procs) + os.environ['WORLD_SIZE'] = str(num_procs) + + # turn off NCCL logging if set + os.environ.pop('NCCL_DEBUG', None) + + if get_accelerator().is_available(): + set_accelerator_visible() + + if get_accelerator().is_available(): + get_accelerator().set_device(local_rank) + + if self.init_distributed: + deepspeed.init_distributed(dist_backend=self.backend, + init_method=init_method, + rank=local_rank, + world_size=num_procs) + dist.barrier() try: self.run(**self._fixture_kwargs) except BaseException as e: if isinstance(e, Skipped): - skip_msg.put(e.msg) + if self.non_daemonic_procs: + skip_msg.put(e.msg) + else: + skip_msg = e.msg else: raise e - if self.init_distributed or dist.is_initialized(): - # make sure all ranks finish at the same time + return skip_msg + + def _launch_with_file_store(self, request, world_size): + tmpdir = request.getfixturevalue("tmpdir") + + if isinstance(world_size, int): + world_size = [world_size] + for procs in world_size: + with tempfile.NamedTemporaryFile(delete=False, dir=str(tmpdir), suffix='_filestore') as fp: + init_method = f"file://{fp.name}" + self._launch_procs(procs, init_method) + time.sleep(0.5) + + def _dist_destroy(self): + if (dist is not None) and dist.is_initialized(): dist.barrier() - # tear down after test completes dist.destroy_process_group() + def _close_pool(self, pool, num_procs, force=False): + if force or not self.reuse_dist_env: + pool.starmap(self._dist_destroy, [() for _ in range(num_procs)]) + pool.close() + pool.join() + class DistributedFixture(DistributedExec): """ @@ -343,13 +489,9 @@ def __call__(self, request): world_size = mark.args[0] break else: - world_size = self.world_size + world_size = self._fixture_kwargs.get("world_size", self.world_size) - if isinstance(world_size, int): - world_size = [world_size] - for procs in world_size: - self._launch_procs(procs) - time.sleep(0.5) + self._launch_with_file_store(request, world_size) def _get_current_test_func(self, request): # DistributedTest subclasses may have multiple test methods @@ -360,3 +502,97 @@ def _get_current_test_func(self, request): def get_test_path(filename): curr_path = Path(__file__).parent return str(curr_path.joinpath(filename)) + + +# bf16 > fp16 > fp32 +def preferred_dtype(): + if get_accelerator().is_bf16_supported(): + return torch.bfloat16 + elif get_accelerator().is_fp16_supported(): + return torch.float16 + else: + return torch.float32 + + +class EnableDeterminism: + + def __init__(self, seed: int): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + self.seed = seed + local_rank + self.saved_random_state = None + self.saved_np_random_state = None + self.saved_cuda_launch_blocking = None + self.saved_cublas_workspace_config = None + self.saved_deterministic_algorithms = None + + def __enter__(self): + self.saved_random_state = random.getstate() + self.saved_np_random_state = np.random.get_state() + self.saved_acc_rng_state = get_accelerator().get_rng_state() + self.saved_cuda_launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "") + self.saved_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "") + self.saved_deterministic_algorithms = torch.are_deterministic_algorithms_enabled() + + random.seed(self.seed) + np.random.seed(self.seed) + get_accelerator().manual_seed(self.seed) + get_accelerator().manual_seed_all(self.seed) + + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + def __exit__(self, type, value, traceback): + random.setstate(self.saved_random_state) + np.random.set_state(self.saved_np_random_state) + get_accelerator().set_rng_state(self.saved_acc_rng_state) + os.environ["CUDA_LAUNCH_BLOCKING"] = self.saved_cuda_launch_blocking + os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.saved_cublas_workspace_config + torch.use_deterministic_algorithms(self.saved_deterministic_algorithms) + + +def enable_determinism(seed: int): + + def decorator(func: Callable) -> Callable: + + def wrapper(*args: Any, **kwargs: Any): + with EnableDeterminism(seed): + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def reduce_boolean_flags(flag: bool, op=all) -> bool: + if not dist.is_initialized(): + return flag + device = get_accelerator().current_device() + tensor_flag = torch.tensor(1 if flag else 0, dtype=torch.int, device=device) + world_size = dist.get_world_size() + tensor_flag_buf = torch.zeros(world_size, dtype=torch.int, device=device) + dist.all_gather_into_tensor(tensor_flag_buf, tensor_flag) + list_flags = [bool(f) for f in tensor_flag_buf.tolist()] + return op(list_flags) + + +def allclose_on_all_ranks(actual, expected, assert_message=None, **kwargs) -> None: + """ + Compare two tensors across all ranks. + We want to make sure that all ranks succeed or fail together. + """ + allclose_local = False + allclose_global = False + mismatch_msg = "" + try: + torch_assert_close(actual, expected, **kwargs) + allclose_local = True + allclose_global = reduce_boolean_flags(allclose_local, all) + except AssertionError: + allclose_global = reduce_boolean_flags(allclose_local, all) + mismatch_msg = f"Tensors are not close: {actual=}, {expected=} {kwargs=}" + + if not allclose_global: + message = "Tensors are not close on all ranks." if assert_message is None else assert_message + raise AssertionError(f"{message} {mismatch_msg}") diff --git a/tests/unit/compression/test_compression.py b/tests/unit/compression/test_compression.py index bf953c76a051..1802c09f33b5 100644 --- a/tests/unit/compression/test_compression.py +++ b/tests/unit/compression/test_compression.py @@ -14,10 +14,10 @@ from deepspeed.compression.basic_layer import LinearLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress from deepspeed.compression.helper import convert_conv1d_to_linear from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version from unit.common import DistributedTest -from unit.util import required_minimum_torch_version, required_maximum_torch_version -pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=5), +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.5), reason='Megatron-LM package requires Pytorch version 1.5 or above') @@ -224,8 +224,9 @@ def test_linear_layer_compress(self, tmpdir): assert isinstance(compressed_model.layer[0].attention.self.key, LinearLayer_Compress) assert isinstance(compressed_model.layer[0].attention.self.value, LinearLayer_Compress) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_mpu_compress(self, tmpdir): - if not required_maximum_torch_version(major_version=1, minor_version=13): + if not required_torch_version(max_version=1.13): pytest.skip("megatron not compatible with torch >1.13") from megatron import mpu args_defaults = { diff --git a/tests/unit/compression/test_dequantization.py b/tests/unit/compression/test_dequantization.py new file mode 100644 index 000000000000..8446904754b3 --- /dev/null +++ b/tests/unit/compression/test_dequantization.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright (c) 2023, 2023, Oracle and/or its affiliates. + +import os +import torch +import pytest +from unit.common import DistributedTest +import deepspeed +from deepspeed.accelerator import get_accelerator + + +class TestDequantization(DistributedTest): + + def init(self): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.device = torch.device(get_accelerator().device_name(local_rank)) + + from deepspeed.ops.op_builder import InferenceBuilder + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("InferenceBuilder is not implemented") + else: + self.dequantize_func = InferenceBuilder().load().dequantize_fp16 + + def run_dequantize_test(self, M, N, num_groups): + weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device) + scale = torch.rand(num_groups, 1).to(device=self.device) + + weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous() + weight_deq_backend = self.dequantize_func(weight, scale, num_groups) + + assert torch.allclose(weight_deq, weight_deq_backend) + + def test_dequantize(self): + self.init() + + self.run_dequantize_test(14336, 7168, 32) + self.run_dequantize_test(14336, 1792, 32) + self.run_dequantize_test(768, 768, 32) + self.run_dequantize_test(768, 768, 48) diff --git a/tests/unit/elasticity/test_elastic.py b/tests/unit/elasticity/test_elastic.py index 2cd76c8c4fce..1f7cbbbca214 100644 --- a/tests/unit/elasticity/test_elastic.py +++ b/tests/unit/elasticity/test_elastic.py @@ -9,6 +9,10 @@ from deepspeed.git_version_info import version as ds_version import os from unit.simple_model import SimpleModel +from deepspeed.ops.op_builder import FusedAdamBuilder, FusedLambBuilder + +if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]: + pytest.skip("This op has not been implemented on this system.", allow_module_level=True) @pytest.fixture @@ -146,6 +150,8 @@ def test_proper_mbsz(ds_config): class TestNonElasticBatchParams(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): config_dict = { "train_batch_size": 2, @@ -178,6 +184,8 @@ def test(self): class TestNonElasticBatchParamsWithOverride(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): config_dict = { "train_batch_size": 2, @@ -209,6 +217,8 @@ def test(self): class TestElasticConfigChanged(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): config_dict = { "train_batch_size": 2, diff --git a/tests/unit/hybrid_engine/test_he_all.py b/tests/unit/hybrid_engine/test_he_all.py new file mode 100644 index 000000000000..aa1f120645b1 --- /dev/null +++ b/tests/unit/hybrid_engine/test_he_all.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch +import pytest +import deepspeed +from deepspeed.ops.op_builder import OpBuilder +from unit.common import DistributedTest +from deepspeed.accelerator import get_accelerator + +from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM) +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + +rocm_version = OpBuilder.installed_rocm_version() +if rocm_version != (0, 0): + pytest.skip("skip inference tests on rocm for now", allow_module_level=True) + + +@pytest.mark.seq_inference +@pytest.mark.parametrize("batch_size", [1, 2], ids=["bsz=1", "bsz=2"]) +@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neo-1.3B", "facebook/opt-1.3b"]) +class TestHybridEngineTextGen(DistributedTest): + world_size = 1 + + def _generate(self, model, tokenizer, prompt): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + tokens = tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True) + for t in tokens: + if torch.is_tensor(tokens[t]): + tokens[t] = tokens[t].to(f'{get_accelerator().device_name()}:{local_rank}') + output = model.generate(**tokens, do_sample=False, max_length=100) + outputs = tokenizer.batch_decode(output, skip_special_tokens=True) + return outputs + + def get_model(self, model_name): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + model_config = AutoConfig.from_pretrained(model_name) + model_config.dropout = 0.0 + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + model = model.half() + model = model.to(f'{get_accelerator().device_name()}:{local_rank}') + return model + + def get_tokenizer(self, model_name): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def get_prompt(self, batch_size): + if batch_size == 1: + prompt = ["Microsoft is in Washington"] + elif batch_size == 2: + prompt = ["DeepSpeed is", "Microsoft is in Washington"] + else: + raise NotImplementedError(f"batch_size {batch_size} not implemented") + return prompt + + def test_correctness(self, batch_size, model_name): + pytest.skip("skip test for now, will fix in follow-up PR") + model = self.get_model(model_name) + tokenizer = self.get_tokenizer(model_name) + prompt = self.get_prompt(batch_size) + + base_out = self._generate(model, tokenizer, prompt) + + ds_config = {"train_batch_size": 1, "fp16": {"enabled": True}, "hybrid_engine": {"enabled": True}} + model, *_ = deepspeed.initialize(model=model, config=ds_config) + + model.eval() + ds1_out = self._generate(model, tokenizer, prompt) + assert base_out == ds1_out, f"base_out: {base_out}, ds1_out: {ds1_out}" + + model.train() + model.eval() + ds2_out = self._generate(model, tokenizer, prompt) + assert base_out == ds2_out + + def test_functionality(self, batch_size, model_name): + model = self.get_model(model_name) + tokenizer = self.get_tokenizer(model_name) + prompt = self.get_prompt(batch_size) + + ds_config = {"train_batch_size": 1, "fp16": {"enabled": True}, "hybrid_engine": {"enabled": True}} + model, *_ = deepspeed.initialize(model=model, config=ds_config) + + model.eval() + ds1_out = self._generate(model, tokenizer, prompt) + + model.train() + model.eval() + ds2_out = self._generate(model, tokenizer, prompt) + + assert ds1_out == ds2_out, f"ds1_out: {ds1_out}, ds2_out: {ds2_out}" diff --git a/tests/unit/hybrid_engine/test_he_llama.py b/tests/unit/hybrid_engine/test_he_llama.py new file mode 100644 index 000000000000..fcf5b8ffb89b --- /dev/null +++ b/tests/unit/hybrid_engine/test_he_llama.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch +import pytest +import deepspeed +from deepspeed.ops.op_builder import OpBuilder +from unit.common import DistributedTest +from deepspeed.accelerator import get_accelerator + +from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM) +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + +rocm_version = OpBuilder.installed_rocm_version() +if rocm_version != (0, 0): + pytest.skip("skip inference tests on rocm for now", allow_module_level=True) + + +@pytest.mark.seq_inference +@pytest.mark.parametrize("batch_size", [1, 2], ids=["bsz=1", "bsz=2"]) +@pytest.mark.parametrize("model_name", ["huggyllama/llama-7b"]) +class TestHybridEngineLlama(DistributedTest): + world_size = 1 + + def _generate(self, model, tokenizer, prompt): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + tokens = tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True) + for t in tokens: + if torch.is_tensor(tokens[t]): + tokens[t] = tokens[t].to(f'{get_accelerator().device_name()}:{local_rank}') + #output = model.generate(**tokens, do_sample=False, max_length=100) + output = model.generate(tokens.input_ids, do_sample=False, max_length=100) + outputs = tokenizer.batch_decode(output, skip_special_tokens=True) + return outputs + + def get_model(self, model_name): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + model_config = AutoConfig.from_pretrained(model_name) + model_config.dropout = 0.0 + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + # Make the model smaller so we can run it on a single GPU in CI + _ = [model.model.layers.pop(-1) for _ in range(8)] + model = model.half() + model = model.to(f'{get_accelerator().device_name()}:{local_rank}') + return model + + def get_tokenizer(self, model_name): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def get_prompt(self, batch_size): + if batch_size == 1: + prompt = ["Microsoft is in Washington"] + elif batch_size == 2: + prompt = ["DeepSpeed is", "Microsoft is in Washington"] + else: + raise NotImplementedError(f"batch_size {batch_size} not implemented") + return prompt + + def test_correctness(self, batch_size, model_name): + pytest.skip("skip test for now, will fix in follow-up PR") + model = self.get_model(model_name) + tokenizer = self.get_tokenizer(model_name) + prompt = self.get_prompt(batch_size) + + base_out = self._generate(model, tokenizer, prompt) + + ds_config = {"train_batch_size": 1, "fp16": {"enabled": True}, "hybrid_engine": {"enabled": True}} + model, *_ = deepspeed.initialize(model=model, config=ds_config) + + model.eval() + ds1_out = self._generate(model, tokenizer, prompt) + assert base_out == ds1_out, f"base_out: {base_out}, ds1_out: {ds1_out}" + + model.train() + model.eval() + ds2_out = self._generate(model, tokenizer, prompt) + assert base_out == ds2_out + + def test_functionality(self, batch_size, model_name): + model = self.get_model(model_name) + tokenizer = self.get_tokenizer(model_name) + prompt = self.get_prompt(batch_size) + + ds_config = {"train_batch_size": 1, "fp16": {"enabled": True}, "hybrid_engine": {"enabled": True}} + model, *_ = deepspeed.initialize(model=model, config=ds_config) + + model.eval() + ds1_out = self._generate(model, tokenizer, prompt) + + model.train() + model.eval() + ds2_out = self._generate(model, tokenizer, prompt) + + assert ds1_out == ds2_out, f"ds1_out: {ds1_out}, ds2_out: {ds2_out}" diff --git a/tests/unit/hybrid_engine/test_he_lora.py b/tests/unit/hybrid_engine/test_he_lora.py new file mode 100644 index 000000000000..1d9405d8f742 --- /dev/null +++ b/tests/unit/hybrid_engine/test_he_lora.py @@ -0,0 +1,234 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import math +import torch +import torch.nn.functional as F +import pytest +import deepspeed +from deepspeed.runtime.zero import GatheredParameters +from deepspeed.ops.op_builder import OpBuilder +from deepspeed.utils import safe_get_full_grad +import numpy.testing as npt +from unit.common import DistributedTest +from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + +from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM) + +rocm_version = OpBuilder.installed_rocm_version() +if rocm_version != (0, 0): + pytest.skip("skip inference tests on rocm for now", allow_module_level=True) + + +def to_device(batch, device): + output = {} + for k, v in batch.items(): + try: + output[k] = v.to(device) + except Exception: + output[k] = v + return output + + +def convert_linear_layer_to_lora(model, part_module_name, lora_dim=0, lora_scaling=1, lora_droppout=0): + from deepspeed.compression.helper import recursive_getattr, recursive_setattr + + repalce_name = [] + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and part_module_name in name: + repalce_name.append(name) + for name in repalce_name: + module = recursive_getattr(model, name) + tmp = LinearLayer_LoRA(module.weight, lora_dim, lora_scaling, lora_droppout, + module.bias).to(module.weight.device).to(module.weight.dtype) + recursive_setattr(model, name, tmp) + return model + + +class LinearLayer_LoRA(torch.nn.Module): + # an simple implementation of LoRA + # for now only support Linear Layer + def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_droppout=0, bias=None): + super(LinearLayer_LoRA, self).__init__() + self.weight = weight + self.bias = bias + + if lora_dim <= 0: + raise ValueError("You are training to use LoRA, whose reduced dim should be larger than 1") + + try: + # for zero stage 3 + rows, columns = weight.ds_shape + except Exception: + rows, columns = weight.shape + self.lora_right_weight = torch.nn.Parameter(torch.zeros( + columns, lora_dim)) # apply transpose so in forward we do not need to transpose again + self.lora_left_weight = torch.nn.Parameter(torch.zeros(lora_dim, rows)) + self.lora_scaling = lora_scaling / lora_dim + + if lora_droppout > 0: + self.lora_dropout = torch.nn.Dropout(lora_droppout) + else: + self.lora_dropout = torch.nn.Identity() + + self.reset_parameters() + # disable the original weight gradient + self.weight.requires_grad = False + # fuse LoRA to the original weight + self.fuse_lora = False + + def eval(self): + self.lora_dropout.eval() + + def train(self, mode=True): + self.lora_dropout.train(mode) + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_left_weight) + + def forward(self, input): + if self.fuse_lora: + return F.linear(input, self.weight, self.bias) + else: + return F.linear(input, self.weight, self.bias) + ( + self.lora_dropout(input) @ self.lora_right_weight @ self.lora_left_weight) * self.lora_scaling + + +def only_optimize_lora_parameters(model): + # turn off the gradient of all the parameters except the LoRA parameters + for name, param in model.named_parameters(): + if "lora_right_weight" in name or "lora_left_weight" in name: + param.requires_grad = True + else: + param.requires_grad = False + return model + + +@pytest.mark.seq_inference +@pytest.mark.parametrize("batch_size", [1], ids=["bsz=1"]) +@pytest.mark.parametrize("zero_stage", [2, 3], ids=["zero_stage=2", "zero_stage=3"]) +@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neo-125m", "facebook/opt-350m", "bigscience/bloom-560m"]) +@pytest.mark.parametrize("offload_device", ["none", "cpu"]) +class TestHybridEngineLoRA(DistributedTest): + world_size = 1 + + def get_model(self, model_name): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + model_config = AutoConfig.from_pretrained(model_name) + model_config.dropout = 0.0 + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + model = model.half() + device = get_accelerator().device_name() + model = model.to(f'{device}:{local_rank}') + return model + + def get_tokenizer(self, model_name): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def get_train_sentences(self, batch_size): + sentences = [ + r"\n\nHuman: I am trying to write a fairy tale. What is the most popular plot?\n\n" + r"Assistant: The most popular plot might be a princess goes to a faraway land, falls in love", + r"\n\nHuman: What flowers should I grow to attract bees?\n\nAssistant: The reason you want bees " + r"in your garden is to attract pollinators and get more fruit or vegetable production." + ] + if batch_size <= 2: + return sentences[:batch_size] + else: + raise NotImplementedError(f"batch_size {batch_size} not implemented") + + def test_lora(self, batch_size, model_name, zero_stage, offload_device): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + model = self.get_model(model_name) + tokenizer = self.get_tokenizer(model_name) + train_sentences = self.get_train_sentences(batch_size) + + # Inject LoRA + model = convert_linear_layer_to_lora(model, "", 8) + model = only_optimize_lora_parameters(model) + + ds_config = { + "optimizer": { + "type": "Adam", + "params": { + "lr": 1.0, + "betas": [0.9, 0.95] + } + }, + "train_batch_size": batch_size, + "fp16": { + "enabled": True, + "initial_scale_power": 12 + }, + "hybrid_engine": { + "enabled": True, + "pin_parameters": True + }, + "zero_optimization": { + "stage": zero_stage, + "offload_optimizer": { + "device": offload_device + } + } + } + + model, *_ = deepspeed.initialize(model=model, config=ds_config) + + # Verify gradient norm is larger than 0 + before_grad_update_layer0_params = [ + ele.detach().cpu().float().numpy() for ele in model.layer_params[0] + if ele is not None and len(ele.shape) > 1 + ] + + model.train() + batch = tokenizer(train_sentences, max_length=16, padding="max_length", truncation=True, return_tensors="pt") + device = get_accelerator().device_name() + batch = to_device(batch, f'{device}:{local_rank}') + batch["labels"] = batch["input_ids"] + outputs = model(**batch, use_cache=False) + loss = outputs.loss + model.backward(loss) + + grad_norm_dict = dict() + for name, param in model.named_parameters(): + if param.requires_grad is True: + grad_norm_dict[name] = torch.linalg.norm(safe_get_full_grad(param)) + + model.step() + grad_norm = sum([ele.detach().cpu().numpy() for ele in grad_norm_dict.values()]) + assert grad_norm > 1E-5 + + # Verify parameter remains the same + after_grad_update_layer0_params = [ + ele.detach().cpu().float().numpy() for ele in model.layer_params[0] + if ele is not None and len(ele.shape) > 1 + ] + for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params): + npt.assert_allclose(lhs, rhs, 1E-5, 1E-5) + + # Verify fuse will mutate layer_params + model.eval() + with GatheredParameters(model.parameters()): + model.fuse_lora_weight() + + after_grad_update_layer0_params_lora_fused = [ + ele.detach().cpu().float().numpy() for ele in model.layer_params[0] + if ele is not None and len(ele.shape) > 1 + ] + + for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params_lora_fused): + with pytest.raises(AssertionError): + npt.assert_allclose(lhs, rhs, 1E-5, 1E-5) + + with GatheredParameters(model.parameters()): + model.unfuse_lora_weight() diff --git a/tests/unit/inference/__init__.py b/tests/unit/inference/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/quantization/test_intX_quantization.py b/tests/unit/inference/quantization/test_intX_quantization.py new file mode 100644 index 000000000000..8169912ae487 --- /dev/null +++ b/tests/unit/inference/quantization/test_intX_quantization.py @@ -0,0 +1,413 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy as np +import torch +import torch.nn as nn +from unit.common import DistributedTest +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization +from deepspeed.inference.quantization.utils import Quantizer, DeQuantizer +from deepspeed.inference.quantization.layers import QuantizedLinear +from deepspeed.utils.torch import required_torch_version +from transformers.models.opt.modeling_opt import OPTDecoderLayer +from transformers import AutoConfig, OPTConfig, AutoModel +import pytest +from collections import OrderedDict +from typing import Dict +from deepspeed.ops.aio import AsyncIOBuilder + +device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu' + +if not required_torch_version(min_version=1.11): + pytest.skip("torch.Tensor.bitwise_left_shift in INT4 quantizer needs torch 1.11 or above.", + allow_module_level=True) + + +def reset_random(seed=1234): + np.random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + get_accelerator().manual_seed_all(seed) + + +def quantization_test_helper(pre_quant_type: torch.dtype, num_bits: int): + reset_random() + num_group = 1024 * 32 + group_size = 64 + quantization_config = {'num_bits': num_bits, 'group_size': group_size, 'group_dim': 1, 'symmetric': False} + + quantizer = Quantizer(config=quantization_config) + dequantizer = DeQuantizer(config=quantization_config, dtype=pre_quant_type) + + data = torch.randn(num_group, group_size, dtype=pre_quant_type, device=device) + + quantized_data, scale_buf, min_vals = quantizer.quantize(data) + dequantized_data = dequantizer.dequantize(quantized_data, scale_buf, min_vals) + + max_diff = torch.max(torch.abs(data - dequantized_data)) + mean_diff = torch.mean(torch.abs(data - dequantized_data)) + + # This threshold value is emperically selected. + assert mean_diff < 0.15 and max_diff < 0.5, f'Numeric error exceed threshold, mean diff {mean_diff} (threshold 0.15), max diff {max_diff} (threshold 0.5)' + + +def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int): + import deepspeed + from transformers.integrations.deepspeed import HfDeepSpeedConfig + + if nvme_offload and not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: + GB = 1 << 30 + + ds_config = { + "fp16": { + "enabled": True, + }, + "zero_optimization": { + "stage": 3, + "stage3_prefetch_bucket_size": 2 * hf_config.hidden_size * hf_config.hidden_size, + "stage3_param_persistence_threshold": hf_config.hidden_size, + "stage3_max_live_parameters": 2 * hf_config.hidden_size * hf_config.hidden_size + }, + "steps_per_print": 2000, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": False, + 'weight_quantization': { + 'post_init_quant': { + 'fc': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + 'self_attn.q_proj': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + 'self_attn.k_proj': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + 'self_attn.v_proj': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + 'self_attn.out_proj': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + 'lm_head': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + 'embed_tokens': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + } + } + } + + if cpu_offload: + ds_config["zero_optimization"]["offload_param"] = dict(device="cpu", pin_memory=1) + if nvme_offload: + ds_config["zero_optimization"]["offload_param"] = dict( + device="nvme", + pin_memory=True, + nvme_path='~/tmp_offload_dir', + buffer_count=5, + buffer_size=1 * GB, + ) + ds_config["aio"] = { + "block_size": 1048576, + "queue_depth": 8, + "thread_count": 1, + "single_submit": False, + "overlap_events": True, + } + + return ds_config + + hf_config = AutoConfig.from_pretrained('facebook/opt-125m') + ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits) + + input_ids = torch.ones(1, 16, dtype=torch.int32, device=device) + attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device) + + with torch.no_grad(): + ref_model = AutoModel.from_pretrained('facebook/opt-125m', torch_dtype=torch.float16).to(device) + ref_model.eval() + ref_output = ref_model(input_ids=input_ids, attention_mask=attention_mask) + + with torch.no_grad(): + dschf = HfDeepSpeedConfig(ds_config) + model = AutoModel.from_pretrained('facebook/opt-125m', torch_dtype=torch.float16) + model = model.eval() + + model = _init_group_wise_weight_quantization(model, ds_config) + ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0] + ds_engine.module.eval() + model = ds_engine.module + + output = model(input_ids=input_ids, attention_mask=attention_mask) + + mean_diff = torch.mean(torch.abs(output.last_hidden_state - ref_output.last_hidden_state)) + + # This threshold value is emperically selected. + assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)' + + +def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int): + import deepspeed + from transformers.integrations.deepspeed import HfDeepSpeedConfig + + if nvme_offload and not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: + GB = 1 << 30 + + ds_config = { + "fp16": { + "enabled": True, + }, + "zero_optimization": { + "stage": 3, + "stage3_prefetch_bucket_size": 2 * hf_config.hidden_size * hf_config.hidden_size, + "stage3_param_persistence_threshold": hf_config.hidden_size, + "stage3_max_live_parameters": 2 * hf_config.hidden_size * hf_config.hidden_size + }, + "steps_per_print": 2000, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": False, + 'weight_quantization': { + 'quantized_initialization': { + 'num_bits': bits, + 'group_size': 32, + 'group_dim': 1, + 'symmetric': False + }, + } + } + + if cpu_offload: + ds_config["zero_optimization"]["offload_param"] = dict(device="cpu", pin_memory=1) + if nvme_offload: + ds_config["zero_optimization"]["offload_param"] = dict( + device="nvme", + pin_memory=True, + nvme_path='~/tmp_offload_dir', + buffer_count=5, + buffer_size=1 * GB, + ) + ds_config["aio"] = { + "block_size": 1048576, + "queue_depth": 8, + "thread_count": 1, + "single_submit": False, + "overlap_events": True, + } + + return ds_config + + hf_config = AutoConfig.from_pretrained('facebook/opt-125m') + ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits) + + input_ids = torch.ones(1, 16, dtype=torch.int32, device=device) + attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device) + + with torch.no_grad(): + ref_model = AutoModel.from_pretrained('facebook/opt-125m', torch_dtype=torch.float16).to(device) + ref_model.eval() + ref_output = ref_model(input_ids=input_ids, attention_mask=attention_mask) + + with torch.no_grad(): + dschf = HfDeepSpeedConfig(ds_config) + model = AutoModel.from_pretrained('facebook/opt-125m', torch_dtype=torch.float16) + model = model.eval() + ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0] + ds_engine.module.eval() + model = ds_engine.module + + output = model(input_ids=input_ids, attention_mask=attention_mask) + + mean_diff = torch.mean(torch.abs(output.last_hidden_state - ref_output.last_hidden_state)) + + # This threshold value is emperically selected. + assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)' + + +@pytest.fixture(params=[4, 8], ids=["4bits", "8bits"]) +def quantization_bits(request): + return request.param + + +@pytest.fixture(params=[0, 1], ids=["0", "1"]) +def group_dim(request): + return request.param + + +class TestQuantizedInt(DistributedTest): + + def test_model_quantization(self, quantization_bits): + reset_random() + + config = AutoConfig.from_pretrained('facebook/opt-125m') + + with torch.no_grad(): + model = OPTDecoderLayer(config).half().to(device) + bits = quantization_bits + + ds_config = { + 'weight_quantization': { + 'post_init_quant': { + 'fc': { + 'num_bits': bits, + 'group_size': 64, + 'group_dim': 0, + 'symmetric': False + }, + 'self_attn.q_proj': { + 'num_bits': bits, + 'group_size': 64, + 'group_dim': 0, + 'symmetric': False + }, + 'self_attn.k_proj': { + 'num_bits': bits, + 'group_size': 64, + 'group_dim': 0, + 'symmetric': False + }, + 'self_attn.v_proj': { + 'num_bits': bits, + 'group_size': 64, + 'group_dim': 0, + 'symmetric': False + }, + 'self_attn.out_proj': { + 'num_bits': bits, + 'group_size': 64, + 'group_dim': 0, + 'symmetric': False + } + } + } + } + + model = _init_group_wise_weight_quantization(model, ds_config) + + assert type(model.fc1) is QuantizedLinear + assert type(model.fc2) is QuantizedLinear + assert type(model.self_attn.q_proj) is QuantizedLinear + assert type(model.self_attn.k_proj) is QuantizedLinear + assert type(model.self_attn.v_proj) is QuantizedLinear + assert type(model.self_attn.out_proj) is QuantizedLinear + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_quantized_linear(self, quantization_bits, group_dim): + reset_random() + + layers = [] + + for idx in range(5): + layers.append( + (f'layer_{idx}', nn.Linear(in_features=128, out_features=128, dtype=torch.float16, device=device))) + + input_tensor = torch.randn(32, 128, dtype=torch.float16, device=device) + with torch.no_grad(): + model = nn.Sequential(OrderedDict(layers)) + + ref_output = model(input_tensor) + + ds_config = { + 'weight_quantization': { + 'post_init_quant': { + 'layer': { + 'num_bits': quantization_bits, + 'group_size': 64, + 'group_dim': group_dim, + 'symmetric': False + } + } + } + } + + model = _init_group_wise_weight_quantization(model, ds_config) + + assert type(model.layer_0) is QuantizedLinear + assert type(model.layer_1) is QuantizedLinear + assert type(model.layer_2) is QuantizedLinear + assert type(model.layer_3) is QuantizedLinear + assert type(model.layer_4) is QuantizedLinear + + output = model(input_tensor) + + mean_diff = torch.mean(torch.abs(ref_output - output)) + + # This threshold value is emperically selected. + assert mean_diff < 0.15, f'Numeric error exceed threshold, mean diff {mean_diff}' + + def test_float_int4_quantization(self): + reset_random() + quantization_test_helper(torch.float32, 4) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_half_int4_quantization(self): + reset_random() + quantization_test_helper(torch.float16, 4) + + def test_float_int8_quantization(self): + reset_random() + quantization_test_helper(torch.float32, 8) + + def test_half_int8_quantization(self): + reset_random() + quantization_test_helper(torch.float16, 8) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_zero3_int4_post_init_quant(self, quantization_bits): + reset_random() + zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_zero3_int4_post_init_quant_cpu_offload(self, quantization_bits): + reset_random() + zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_zero3_int4_post_init_quant_nvme_offload(self): + reset_random() + zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True, bits=4) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_zero3_int4_quantized_initialization(self, quantization_bits): + reset_random() + zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_zero3_int4_quantized_initialization_cpu_offload(self, quantization_bits): + reset_random() + zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits) + + @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM') + def test_zero3_int4_quantized_initialization_nvme_offload(self): + reset_random() + zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True, bits=4) diff --git a/tests/unit/inference/test_checkpoint_sharding.py b/tests/unit/inference/test_checkpoint_sharding.py index 611e6fc69edf..877f59ccf745 100644 --- a/tests/unit/inference/test_checkpoint_sharding.py +++ b/tests/unit/inference/test_checkpoint_sharding.py @@ -10,6 +10,22 @@ from deepspeed.model_implementations import DeepSpeedTransformerInference from unit.common import DistributedTest, DistributedFixture from transformers import AutoConfig, AutoModelForCausalLM +import deepspeed.comm as dist +from huggingface_hub import snapshot_download +from deepspeed.ops.op_builder import InferenceBuilder + +# Handle different versions of transformers +try: + from transformers.utils import is_offline_mode +except ImportError: + # For transformers >= 5.0, is_offline_mode was removed + # transformers >= 5.0 requires huggingface_hub >= 1.2.1 which has is_offline_mode + from huggingface_hub import is_offline_mode + +from deepspeed.accelerator import get_accelerator + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) def check_dtype(model, expected_dtype): @@ -28,14 +44,17 @@ def find_dtype(module): assert (found_dtype == expected_dtype), f"Expected transformer dtype {expected_dtype}, but found {found_dtype}" -@pytest.fixture( - params=["bigscience/bloom-560m", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-125M", "facebook/opt-125m"]) +@pytest.fixture(params=[ + "bigscience/bloom-560m", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-125M", "facebook/opt-350m", "facebook/opt-125m" +]) def model_name(request): return request.param @pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"]) def dtype(request): + if request.param not in get_accelerator().supported_dtypes(): + pytest.skip(f"{request.param} not supported by {get_accelerator().device_name()}.") return request.param @@ -85,3 +104,47 @@ def test(self, model_name, dtype, class_tmpdir, save_shard): model = model.eval() model = deepspeed.init_inference(model, config=inf_config) check_dtype(model, dtype) + + +@pytest.mark.seq_inference +class TestCheckpointShardinAutoTP(DistributedTest): + world_size = 2 + + def test(self, model_name, class_tmpdir): + + def write_checkpoints_json(model_name, class_tmpdir): + import json + from pathlib import Path + local_rank = int(os.getenv("LOCAL_RANK", "0")) + if local_rank == 0: + # download only on first process + cached_repo_dir = snapshot_download( + model_name, + local_files_only=is_offline_mode(), + cache_dir=os.getenv("HF_HOME", None), + ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"], + ) + file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] + data = {"type": "ds_model", "checkpoints": file_list, "version": 1.0} + os.makedirs(os.path.join(class_tmpdir, model_name), exist_ok=True) + json.dump(data, open(os.path.join(class_tmpdir, model_name, "ds_inference_config.json"), "w")) + dist.barrier() + + world_size = int(os.getenv("WORLD_SIZE", "1")) + inf_config = { + "replace_with_kernel_inject": False, + "tensor_parallel": { + "tp_size": world_size + }, + "checkpoint": os.path.join(class_tmpdir, model_name, "ds_inference_config.json"), + } + + write_checkpoints_json(model_name, class_tmpdir) + + # Load model on meta tensors + model_config = AutoConfig.from_pretrained(model_name) + # Note that we use half precision to load initially, even for int8 + with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"): + model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16) + model = model.eval() + model = deepspeed.init_inference(model, config=inf_config) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index adff8e074974..9337eb67ff1e 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -3,84 +3,172 @@ # DeepSpeed Team -import os -import time -import torch import pytest + import itertools +import pickle +import os +import time +import requests +import fcntl + +from dataclasses import dataclass +from typing import List + import deepspeed -from deepspeed.git_version_info import torch_info -from unit.common import DistributedTest +import torch + +from huggingface_hub import HfApi from packaging import version as pkg_version -from deepspeed.ops.op_builder import OpBuilder +from torch import nn from transformers import pipeline from transformers.models.t5.modeling_t5 import T5Block from transformers.models.roberta.modeling_roberta import RobertaLayer -from huggingface_hub import HfApi -from deepspeed.model_implementations import DeepSpeedTransformerInference -from torch import nn + from deepspeed.accelerator import get_accelerator +from deepspeed.git_version_info import torch_info +from deepspeed.model_implementations import DeepSpeedTransformerInference +from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.op_builder import OpBuilder + +from unit.common import DistributedTest rocm_version = OpBuilder.installed_rocm_version() if rocm_version != (0, 0): pytest.skip("skip inference tests on rocm for now", allow_module_level=True) _bert_models = [ - "bert-base-cased", - "bert-base-uncased", - "bert-large-cased", - "bert-large-uncased", - "bert-base-multilingual-cased", - "bert-base-multilingual-uncased", + "google-bert/bert-base-cased", + "google-bert/bert-base-uncased", + "google-bert/bert-large-cased", + "google-bert/bert-large-uncased", + "google-bert/bert-base-multilingual-cased", + "google-bert/bert-base-multilingual-uncased", "deepset/minilm-uncased-squad2", "cross-encoder/ms-marco-MiniLM-L-12-v2", "dslim/bert-base-NER", - "bert-large-uncased-whole-word-masking-finetuned-squad", - "distilbert-base-cased-distilled-squad", + "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad", + "distilbert/distilbert-base-cased-distilled-squad", ] _roberta_models = [ - "roberta-large", - "roberta-base", + "FacebookAI/roberta-large", + "FacebookAI/roberta-base", "deepset/roberta-base-squad2", "j-hartmann/emotion-english-distilroberta-base", "Jean-Baptiste/roberta-large-ner-english", ] _gpt_models = [ - "gpt2", - "distilgpt2", + "openai-community/gpt2", + "distilbert/distilgpt2", "Norod78/hebrew-bad_wiki-gpt_neo-tiny", - #"EleutherAI/gpt-j-6B", # Removed as this is causing OOM errors randomly + "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m-deduped", "bigscience/bloom-560m", ] _opt_models = [ "facebook/opt-125m", # 125m, 1.7B, ..., 175B variants have the same model architecture. - "facebook/opt-350m", # 350m applies layer norm after attnention layer which is different than other variants. + "facebook/opt-350m", # 350m applies layer norm after attention layer which is different than other variants. ] -_all_models = HfApi().list_models() - -test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models) -test_tasks = [ +_test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models) +_test_tasks = [ "fill-mask", "question-answering", "text-classification", "token-classification", "text-generation", "text2text-generation", "summarization", "translation" ] -pytest.all_models = {task: [m.modelId for m in _all_models if m.pipeline_tag == task] for task in test_tasks} - -_model_w_tasks = itertools.product(*[test_models, test_tasks]) - - -def _valid_model_task(model_task): - m, t = model_task - return m in pytest.all_models[t] - - -pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks)) -pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks] -""" -These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph -""" -@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names) +@dataclass +class ModelInfo: + id: str + pipeline_tag: str + tags: List[str] + + +def _hf_model_list() -> List[ModelInfo]: + """ Caches HF model list to avoid repeated API calls """ + + cache_dir = os.getenv("HF_HOME", "~/.cache/huggingface") + cache_file_path = os.path.join(cache_dir, "DS_model_cache.pkl") + num_days = os.getenv("HF_CACHE_EXPIRY_DAYS", 1) + cache_expiration_seconds = num_days * 60 * 60 * 24 + + # Load or initialize the cache + model_data = {"cache_time": 0, "model_list": []} + if os.path.isfile(cache_file_path): + with open(cache_file_path, 'rb') as f: + try: + fcntl.flock(f, fcntl.LOCK_SH) + model_data = pickle.load(f) + except Exception as e: + print(f"Error loading cache file {cache_file_path}: {e}") + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + current_time = time.time() + + # Update the cache if it has expired + if ((model_data["cache_time"] + cache_expiration_seconds) < current_time) or os.getenv("FORCE_UPDATE_HF_CACHE", + default=False): + api = HfApi() + while True: + try: + model_list = [] + for model in _test_models: + model_list.extend(api.list_models(model_name=model)) + model_data["model_list"] = [ + ModelInfo(id=m.id, pipeline_tag=m.pipeline_tag, tags=m.tags) for m in model_list + ] + break # Exit the loop if the operation is successful + except requests.exceptions.HTTPError as e: + if e.response.status_code == 429: + print("Rate limit exceeded. Retrying in 60 seconds...") + time.sleep(60) + else: + raise # Re-raise the exception if it's not a 429 error + model_data["cache_time"] = current_time + + # Save the updated cache + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file_path, 'wb') as f: + try: + fcntl.flock(f, fcntl.LOCK_EX) + pickle.dump(model_data, f) + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + return model_data["model_list"] + + +# Get a list of all models and mapping from task to supported models +_hf_models = _hf_model_list() +_hf_model_names = [m.id for m in _hf_models] +_hf_task_to_models = {task: [m.id for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks} + +# Get all combinations of task:model to test +_model_w_tasks = [(m, t) for m, t in itertools.product(*[_test_models, _test_tasks]) if m in _hf_task_to_models[t]] + +# Assign to pytest variables for testing +pytest.model_w_tasks = _model_w_tasks +pytest.mt_names = [f"{m}-{t}" for m, t in pytest.model_w_tasks] + + +@pytest.fixture(scope="module", autouse=True) +def verify_models(): + # Verify all test models are registered in HF + _test_models_not_found = [m for m in _test_models if m not in _hf_model_names] + if _test_models_not_found: + pytest.fail(f"Model(s) not found in HuggingFace: {_test_models_not_found}") + + # Verify all models are assigned to at least one task + _models_to_be_tested = set(m for m, t in _model_w_tasks) + _missing_task_models = _models_to_be_tested.difference(_test_models) + if _missing_task_models: + pytest.fail(f"Model(s) do not have an assigned task: {_missing_task_models}") + + +""" Fixtures for inference config """ + + +@pytest.fixture(params=pytest.model_w_tasks, ids=pytest.mt_names) def model_w_task(request): return request.param @@ -95,43 +183,17 @@ def enable_cuda_graph(request): return request.param -""" -This fixture will validate the configuration -""" +@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"]) +def enable_triton(request): + return request.param -@pytest.fixture() -def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph): - model, task = model_w_task - msg = "" - if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"): - msg = "DS inference injection doesn't work well on older torch versions" - elif model not in pytest.all_models[task]: - msg = f"Not a valid model / task combination: {model} / {task}" - elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"): - msg = "CUDA not detected, cannot use CUDA Graph" - elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"): - msg = "CUDA Graph is only available in torch versions >= 1.10" - elif "gpt-j-6B" in model: - if dtype != torch.half: - msg = f"Not enough GPU memory to run {model} with dtype {dtype}" - elif enable_cuda_graph: - msg = f"Not enough GPU memory to run {model} with CUDA Graph enabled" - elif "gpt-neox-20b" in model: # TODO: remove this when neox issues resolved - msg = "Skipping gpt-neox-20b for now" - elif ("gpt-neox-20b" in model) and (dtype != torch.half): - msg = f"Not enough GPU memory to run {model} with dtype {dtype}" - elif ("bloom" in model) and (dtype != torch.half): - msg = f"Bloom models only support half precision, cannot use dtype {dtype}" - elif ("bert" not in model.lower()) and enable_cuda_graph: - msg = "Non bert/roberta models do no support CUDA Graph" - return msg +@pytest.fixture(params=[1, 2], ids=["ws1", "ws2"]) +def world_size(request): + return request.param -""" -These fixtures can be used to customize the query, inference args, and assert -statement for each combination of model /task -""" +""" Fixtures for running query """ @pytest.fixture @@ -167,14 +229,17 @@ def query(model_w_task): def inf_kwargs(model_w_task): model, task = model_w_task if task == "text-generation": - if model == "EleutherAI/gpt-j-6B": + if model == "EleutherAI/gpt-j-6b": # This model on V100 is hitting memory problems that limit the number of output tokens - return {"do_sample": False, "max_length": 12} - return {"do_sample": False, "max_length": 20} + return {"do_sample": False, "temperature": 1.0, "max_length": 12} + return {"do_sample": False, "temperature": 1.0, "max_length": 20} else: return {} +""" Assertion fixture for verifying model outputs """ + + def fill_mask_assert(x, y): return set(res["token_str"] for res in x) == set(res["token_str"] for res in y) @@ -226,6 +291,7 @@ def assert_fn(model_w_task): return assert_fn +# Used to verify DeepSpeed kernel injection worked with a model def check_injection(model): def verify_injection(module): @@ -240,9 +306,46 @@ def verify_injection(module): verify_injection(model) -""" -Tests -""" +# Used to Get Device name +def getDeviceId(local_rank): + device = torch.device(f"{get_accelerator().device_name(local_rank)}") + return device + + +# Verify that test is valid +def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton): + model, task = model_w_task + msg = "" + if enable_cuda_graph and (torch_info["cuda_version"] == "0.0"): + msg = "CUDA not detected, cannot use CUDA Graph" + elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"): + msg = "CUDA Graph is only available in torch versions >= 1.10" + elif "gpt-j-6b" in model: + if dtype != torch.half: + msg = f"Not enough GPU memory to run {model} with dtype {dtype}" + elif enable_cuda_graph: + msg = f"Not enough GPU memory to run {model} with CUDA Graph enabled" + elif "gpt-neox-20b" in model: # TODO: remove this when neox issues resolved + msg = "Skipping gpt-neox-20b for now" + elif ("gpt-neox-20b" in model) and (dtype != torch.half): + msg = f"Not enough GPU memory to run {model} with dtype {dtype}" + elif ("bloom" in model) and (dtype != torch.half): + msg = f"Bloom models only support half precision, cannot use dtype {dtype}" + elif (model not in _bert_models + _roberta_models) and enable_cuda_graph: + msg = "Non bert/roberta models do no support CUDA Graph" + elif enable_triton and not (dtype in [torch.half]): + msg = "Triton is for fp16" + elif enable_triton and not deepspeed.HAS_TRITON: + msg = "triton needs to be installed for the test" + elif (model not in _bert_models + _roberta_models) and enable_triton: + msg = "Triton kernels do not support Non bert/roberta models yet" + + # These should be removed once we fix several inference tests failing + if model in [ + "EleutherAI/pythia-70m-deduped", "distilbert/distilbert-base-cased-distilled-squad", "EleutherAI/gpt-j-6b" + ]: + msg = "Test is currently broken" + return msg @pytest.mark.inference @@ -254,13 +357,21 @@ def test( model_w_task, dtype, enable_cuda_graph, + enable_triton, query, inf_kwargs, assert_fn, - invalid_model_task_config, + perf_meas=True, ): - if invalid_model_task_config: - pytest.skip(invalid_model_task_config) + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton) + if invalid_test_msg: + pytest.skip(invalid_test_msg) + + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.") + + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -284,13 +395,18 @@ def test( get_accelerator().synchronize() bs_time = time.time() - start - pipe.model = deepspeed.init_inference( - pipe.model, - mp_size=1, - dtype=dtype, - replace_with_kernel_inject=True, - enable_cuda_graph=enable_cuda_graph, - ) + args = { + 'mp_size': 1, + 'dtype': dtype, + 'replace_with_kernel_inject': True, + 'enable_cuda_graph': enable_cuda_graph, + 'use_triton': enable_triton, + 'triton_autotune': False, + } + if pipe.tokenizer.model_max_length < deepspeed.ops.transformer.inference.config.DeepSpeedInferenceConfig( + ).max_out_tokens: + args.update({'max_out_tokens': pipe.tokenizer.model_max_length}) + pipe.model = deepspeed.init_inference(pipe.model, **args) check_injection(pipe.model) # Warm-up queries for perf measurement #for i in range(10): @@ -301,6 +417,11 @@ def test( get_accelerator().synchronize() ds_time = time.time() - start + if perf_meas: + print( + f"model={model}, task={task}, dtype={dtype}, cuda_graph={enable_cuda_graph}, triton={enable_triton}, bs_time={bs_time}, ds_time={ds_time}" + ) + # facebook/opt* and some bigscient/bloom* models are not matching # baseline exactly, adding an exception to them for now if ("opt" in model) or ("bloom" in model): @@ -309,6 +430,7 @@ def test( # These performance tests are only measuring the time for a single # inference request, we just want to check that performance isn't terrible #assert ds_time <= (bs_time * 1.1) + assert assert_fn(bs_output, ds_output) @@ -316,10 +438,10 @@ def test( @pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"), ("EleutherAI/gpt-neox-20b", "text-generation"), ("bigscience/bloom-3b", "text-generation"), - ("EleutherAI/gpt-j-6B", "text-generation")], + ("EleutherAI/gpt-j-6b", "text-generation")], ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"]) class TestMPSize(DistributedTest): - world_size = 4 + world_size = 2 def test( self, @@ -328,10 +450,13 @@ def test( query, inf_kwargs, assert_fn, - invalid_model_task_config, ): - if invalid_model_task_config: - pytest.skip(invalid_model_task_config) + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) + + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -355,6 +480,38 @@ def test( assert assert_fn(bs_output, ds_output) +@pytest.mark.inference +@pytest.mark.parametrize("model_w_task", [("openai-community/gpt2", "text-generation")], ids=["gpt2"]) +class TestLowCpuMemUsage(DistributedTest): + world_size = 1 + + def test( + self, + model_w_task, + query, + inf_kwargs, + assert_fn, + ): + model, task = model_w_task + dtype = torch.float16 + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.") + + local_rank = int(os.getenv("LOCAL_RANK", "0")) + device = getDeviceId(local_rank) + pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=device, framework="pt") + bs_output = pipe(query, **inf_kwargs) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=self.world_size, + dtype=dtype, + replace_method="auto", + replace_with_kernel_inject=True) + + ds_output = pipe(query, **inf_kwargs) + + assert assert_fn(bs_output, ds_output) + + @pytest.mark.seq_inference @pytest.mark.parametrize( "model_w_task, injection_policy", @@ -362,46 +519,33 @@ def test( (("google/t5-v1_1-small", "text2text-generation"), { T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo') }), - (("roberta-large", "fill-mask"), { + (("FacebookAI/roberta-large", "fill-mask"), { RobertaLayer: ('output.dense') }), ], ids=["t5", "roberta"], ) @pytest.mark.parametrize("dtype", [torch.float], ids=["fp32"]) -@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"]) class TestInjectionPolicy(DistributedTest): - world_size = [1, 2] - def test( - self, - model_w_task, - injection_policy, - query, - inf_kwargs, - assert_fn, - invalid_model_task_config, - dtype, - enable_cuda_graph, - ): - if invalid_model_task_config: - pytest.skip(invalid_model_task_config) + def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dtype, world_size): + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "2")) - # We have to load these large models on CPU with pipeline because not - # enough GPU memory - pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt") + pipe = pipeline(task, + model=model, + device=torch.device(get_accelerator().device_name(local_rank)), + framework="pt") bs_output = pipe(query, **inf_kwargs) pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype, injection_policy=injection_policy) - # Switch device to GPU so that input tensors are not on CPU - pipe.device = torch.device(get_accelerator().device_name(local_rank)) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) @@ -410,17 +554,68 @@ def test( @pytest.mark.seq_inference +@pytest.mark.parametrize("model_w_task", [("Felladrin/Llama-160M-Chat-v1", "text-generation")], ids=["llama"]) +@pytest.mark.parametrize("dtype", [torch.half], ids=["fp16"]) +class TestLlamaInjection(DistributedTest): + world_size = 1 + + def test(self, model_w_task, dtype, query, inf_kwargs, assert_fn): + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) + + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"Accelerator {get_accelerator().device_name()} does not support {dtype}.") + + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + + model, task = model_w_task + + local_rank = int(os.getenv("LOCAL_RANK", "0")) + device = torch.device(get_accelerator().device_name(local_rank)) + + pipe = pipeline(task, + model=model, + device=torch.device("cpu"), + model_kwargs={"low_cpu_mem_usage": True}, + framework="pt") + + if dtype == torch.half: + pipe.model.half() + + pipe.device = device + pipe.model.to(device) + bs_output = pipe(query, **inf_kwargs) + + try: + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=self.world_size, + dtype=dtype, + replace_with_kernel_inject=True) + check_injection(pipe.model) + except AttributeError as e: + if "'LlamaAttention' object has no attribute 'num_heads'" in str(e): + pytest.skip("Skipping due to transformers version compatibility issue with self-attention") + raise e + + ds_output = pipe(query, **inf_kwargs) + + print(local_rank, "baseline", bs_output) + print(local_rank, "deepspeed", ds_output) + # Llama models are not matching baseline exactly + # We skip the result check for now, since this is irrelevant to this test + # assert assert_fn(bs_output, ds_output) + + +@pytest.mark.seq_inference +@pytest.mark.parametrize('keep_module_on_host', [True, False]) @pytest.mark.parametrize( "model_w_task", - [ - ("Helsinki-NLP/opus-mt-en-de", "translation"), - ], - ids=[ - "marian", - ], + [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")], + ids=["marian", "codegen"], #codegen has fusedqkv weight. ) -@pytest.mark.parametrize("dtype", [torch.float16], ids=["fp16"]) -@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) class TestAutoTensorParallelism(DistributedTest): world_size = [2] @@ -430,44 +625,98 @@ def test( query, inf_kwargs, assert_fn, - invalid_model_task_config, dtype, - enable_cuda_graph, + keep_module_on_host, ): - if invalid_model_task_config: - pytest.skip(invalid_model_task_config) + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "2")) - # We have to load these large models on CPU with pipeline because not - # enough GPU memory - pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt") + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.") + + if model == "Salesforce/codegen-350M-mono": + pytest.skip("Disable Codegen model due to slight result difference") + #TODO: re-enable this test once we have a fix for the slight result difference + + pipe = pipeline(task, + model=model, + device=torch.device(get_accelerator().device_name(local_rank)), + framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) - # Switch device to GPU so that input tensors are not on CPU - pipe.device = torch.device(get_accelerator().device_name(local_rank)) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) + ds_output = pipe(query, **inf_kwargs) + + print(local_rank, "baseline", bs_output) + print(local_rank, "deepspeed", ds_output) + assert assert_fn(bs_output, ds_output) + + if keep_module_on_host: + for name, param in model.named_parameters(): + assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu" + + @pytest.mark.world_size(3) + def test_odd_world_size( + self, + model_w_task, + query, + inf_kwargs, + assert_fn, + dtype, + keep_module_on_host, + ): + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) + + model, task = model_w_task + if model == "Salesforce/codegen-350M-mono": + pytest.skip("codegen does not supported by odd world_size") + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "3")) + + pipe = pipeline(task, + model=model, + device=torch.device(get_accelerator().device_name(local_rank)), + framework="pt") + bs_output = pipe(query, **inf_kwargs) + + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) print(local_rank, "deepspeed", ds_output) assert assert_fn(bs_output, ds_output) + if keep_module_on_host: + for name, param in model.named_parameters(): + assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu" + @pytest.mark.nightly @pytest.mark.parametrize( "model_family, model_name", ( ["gpt2", "EleutherAI/gpt-neo-2.7B"], - ["gpt2", "EleutherAI/gpt-j-6B"], - ["gpt2", "gpt2-xl"], + #["gpt2", "EleutherAI/gpt-j-6b"], # Causing OOM for this test + ["gpt2", "openai-community/gpt2-xl"], ), ) @pytest.mark.parametrize("task", ["lambada_standard"]) class TestLMCorrectness(DistributedTest): world_size = 1 + exec_timeout = 1200 # Give these tests longer to complete def test(self, model_family, model_name, task): # imports here to avoid import errors when pytest collects tests @@ -476,20 +725,42 @@ def test(self, model_family, model_name, task): import lm_eval.tasks import lm_eval.evaluator + # The bootstrap_stderr function in lm_eval.metrics uses a + # multiprocessing Pool to increase performance. Since we use a Pool for + # our distributed tests and cannot nest Pools, we must redefine and + # patch this function with a version that does not use Pool. + def no_pool_bootstrap_stderr(f, xs, iters): + from lm_eval.metrics import _bootstrap_internal + from lm_eval.metrics import sample_stddev + res = [] + chunk_size = min(1000, iters) + for i in range(iters // chunk_size): + res.extend(_bootstrap_internal(f, chunk_size)((i, xs))) + return sample_stddev(res) + + lm_eval.metrics.bootstrap_stderr = no_pool_bootstrap_stderr + local_rank = os.getenv("LOCAL_RANK", "0") device = torch.device(get_accelerator().device_name(local_rank)) dtype = torch.float task_dict = lm_eval.tasks.get_task_dict([task]) - if 'gpt-j-6B' in model_name: + if 'gpt-j-6b' in model_name: dtype = torch.half lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}", {"device": "cpu"}) setattr(lm, model_family, getattr(lm, model_family).half().to(device)) lm._device = device else: - lm = lm_eval.models.get_model(model_family).create_from_arg_string( - f"pretrained={model_name}", {"device": get_accelerator().device_name()}) + if get_accelerator().device_name() == 'hpu': + #lm_eval not supporting HPU device, so get model with CPU and move it to HPU. + lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}", + {"device": "cpu"}) + setattr(lm, model_family, getattr(lm, model_family).to(device)) + lm._device = device + else: + lm = lm_eval.models.get_model(model_family).create_from_arg_string( + f"pretrained={model_name}", {"device": get_accelerator().device_name()}) get_accelerator().synchronize() start = time.time() @@ -497,6 +768,7 @@ def test(self, model_family, model_name, task): get_accelerator().synchronize() bs_time = time.time() - start + getattr(lm, model_family).to("cpu") ds_model = deepspeed.init_inference( getattr(lm, model_family), mp_size=1, diff --git a/tests/unit/inference/test_inference_config.py b/tests/unit/inference/test_inference_config.py index 375563abf65b..39d62d17372c 100644 --- a/tests/unit/inference/test_inference_config.py +++ b/tests/unit/inference/test_inference_config.py @@ -15,7 +15,7 @@ class TestInferenceConfig(DistributedTest): world_size = 1 def test_overlap_kwargs(self): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": torch.float32} kwargs = {"replace_with_kernel_inject": True} engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs) @@ -37,7 +37,7 @@ def test_kwargs_and_config(self): assert engine._config.dtype == kwargs["dtype"] def test_json_config(self, tmpdir): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"} config_json = create_config_from_dict(tmpdir, config) engine = deepspeed.init_inference(torch.nn.Module(), config=config_json) diff --git a/tests/unit/inference/test_model_profiling.py b/tests/unit/inference/test_model_profiling.py index 626bfd11f2cc..319055d0ea55 100644 --- a/tests/unit/inference/test_model_profiling.py +++ b/tests/unit/inference/test_model_profiling.py @@ -11,55 +11,36 @@ from transformers import pipeline from unit.common import DistributedTest from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import InferenceBuilder +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) -@pytest.fixture -def query(model, task): - if task == "text-generation": - return "DeepSpeed is" - elif task == "fill-mask": - if "roberta" in model: - return "I am a model" - else: - return "I am a [MASK] model" - else: - raise NotImplementedError - - -@pytest.fixture -def inf_kwargs(task): - if task == "text-generation": - return {"do_sample": False, "min_length": 50, "max_length": 50} - else: - return {} +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) @pytest.mark.inference -@pytest.mark.parametrize("model,task", [ - ("bert-base-cased", "fill-mask"), - ("roberta-base", "fill-mask"), - ("gpt2", "text-generation"), - ("facebook/opt-125m", "text-generation"), - ("bigscience/bloom-560m", "text-generation"), -]) -@pytest.mark.parametrize("cuda_graphs", [True, False]) @pytest.mark.parametrize("use_cuda_events", [True, False]) +@pytest.mark.parametrize("enable_cuda_graph", [True, False]) class TestModelProfiling(DistributedTest): world_size = 1 - def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dtype=torch.float16): - if cuda_graphs and "bert" not in model: - pytest.skip(f"CUDA Graph not supported for {model}") + def test(self, enable_cuda_graph, use_cuda_events): + task = "fill-mask" + model = "bert-base-cased" + dtype = torch.float16 + query = "I am a [MASK] model" local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) - pipe = pipeline(task, model, framework="pt", device=local_rank) + pipe = pipeline(task, model, framework="pt", device=get_accelerator().device_name(local_rank)) pipe.model = deepspeed.init_inference(pipe.model, dtype=dtype, mp_size=world_size, replace_with_kernel_inject=True, - enable_cuda_graph=cuda_graphs) + enable_cuda_graph=enable_cuda_graph) pipe.model.profile_model_time(use_cuda_events=use_cuda_events) e2e_times = [] @@ -68,7 +49,7 @@ def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dty get_accelerator().synchronize() start = time.perf_counter_ns() - r = pipe(query, **inf_kwargs) + r = pipe(query) get_accelerator().synchronize() end = time.perf_counter_ns() diff --git a/tests/unit/inference/test_stable_diffusion.py b/tests/unit/inference/test_stable_diffusion.py new file mode 100644 index 000000000000..775a02c2e878 --- /dev/null +++ b/tests/unit/inference/test_stable_diffusion.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch +import pytest +import deepspeed +import numpy +from unit.common import DistributedTest +from deepspeed.accelerator import get_accelerator + + +# Setup for these models is different from other pipelines, so we add a separate test +@pytest.mark.stable_diffusion +class TestStableDiffusion(DistributedTest): + world_size = 1 + + def test(self): + from diffusers import DiffusionPipeline + from image_similarity_measures.quality_metrics import rmse + dev = get_accelerator().device_name() + generator = torch.Generator(device=dev) + seed = 0xABEDABE7 + generator.manual_seed(seed) + prompt = "a dog on a rocket" + model = "prompthero/midjourney-v4-diffusion" + local_rank = int(os.getenv("LOCAL_RANK", "0")) + device = torch.device(f"{dev}:{local_rank}") + pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half) + pipe = pipe.to(device) + baseline_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0] + + pipe = deepspeed.init_inference( + pipe, + mp_size=1, + dtype=torch.half, + replace_with_kernel_inject=True, + enable_cuda_graph=True, + ) + generator.manual_seed(seed) + deepspeed_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0] + + rmse_value = rmse(org_img=numpy.asarray(baseline_image), pred_img=numpy.asarray(deepspeed_image)) + + # RMSE threshold value is arbitrary, may need to adjust as needed + assert rmse_value <= 0.01 diff --git a/tests/unit/inference/v2/__init__.py b/tests/unit/inference/v2/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/inference_test_utils.py b/tests/unit/inference/v2/inference_test_utils.py new file mode 100644 index 000000000000..d63c51267e51 --- /dev/null +++ b/tests/unit/inference/v2/inference_test_utils.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch +from deepspeed.accelerator import get_accelerator + +TOLERANCES = None + + +def get_tolerances(): + global TOLERANCES + if TOLERANCES is None: + TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)} + if get_accelerator().is_bf16_supported(): + # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs + # 10 (+1) bits) + TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2) + return TOLERANCES + + +DTYPES = None + + +def get_dtypes(include_float=True): + global DTYPES + if DTYPES is None: + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass + return DTYPES + + +def allclose(x, y, tolerances: Tuple[int, int] = None): + assert x.dtype == y.dtype + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances + return torch.allclose(x, y, rtol=rtol, atol=atol) diff --git a/tests/unit/inference/v2/kernels/__init__.py b/tests/unit/inference/v2/kernels/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/core_ops/__init__.py b/tests/unit/inference/v2/kernels/core_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/core_ops/test_bias_activation.py b/tests/unit/inference/v2/kernels/core_ops/test_bias_activation.py new file mode 100644 index 000000000000..376188b92565 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_bias_activation.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.kernels.core_ops import CUDABiasActivation +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_bias_act_implementation(input: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + bias_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + dtype = input.dtype + input_f = input.to(torch.float32) + if bias is not None: + bias_f = bias.to(torch.float32) + output_f = input_f + bias_f + else: + output_f = input_f + output_f = bias_func_map[act_type](output_f) + + return output_f.to(dtype) + + +def _bias_activation_test_helper(tokens: int, + channels: int, + act_fn: ActivationType, + dtype: DtypeEnum, + use_bias: bool = True) -> None: + """ + Fully parameterized testing entry point. + """ + # Input vals + input_tensor = torch.randn((tokens, channels), dtype=dtype.value, device=get_accelerator().current_device_name()) + if use_bias: + bias = torch.randn((channels), dtype=dtype.value, device=get_accelerator().current_device_name()) + else: + bias = None + + # Reference output + ref_output = reference_bias_act_implementation(input_tensor, bias, act_fn) + + bias_act = CUDABiasActivation(channels, dtype, act_fn) + + # New output + ds_tensor = input_tensor.clone() + bias_act(ds_tensor, bias) + + # Check + assert allclose(ds_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_token_channels_permutations(tokens: int, channels: int, dtype: torch.dtype) -> None: + """ + Validate bias activation kernel with different token and channel permutations when using the RELU + activation function. + """ + act_fn = ActivationType.RELU + dtype = DtypeEnum(dtype) + _bias_activation_test_helper(tokens, channels, act_fn, dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", + [ActivationType.RELU, ActivationType.GELU, ActivationType.SILU, ActivationType.IDENTITY]) +def test_act_fns(act_fn: ActivationType) -> None: + """ + Validate bias activation kernel with different activation functions. + """ + tokens = 223 + channels = 4096 + dtype = DtypeEnum.fp16 + _bias_activation_test_helper(tokens, channels, act_fn, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_bias() -> None: + """ + Validate bias activation kernel with no bias. + """ + tokens = 223 + channels = 4096 + dtype = DtypeEnum.fp16 + act_fn = ActivationType.IDENTITY + _bias_activation_test_helper(tokens, channels, act_fn, dtype, use_bias=False) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py b/tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py new file mode 100644 index 000000000000..864db6204a16 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear +from ....v2.inference_test_utils import allclose + +# Note: only testing with FP16 and BF16 because we use TF32 on Ampere and we don't have a good +# set of tolerances. Since this is just on top of BLAS though, the test is more about +# making sure the stride/contiguity is correct and that's data type agnostic. + + +def reference_implementation(hidden_states, weights): + return hidden_states @ weights.t() + + +problem_shapes = [ + (1, 1, 1024, 1024), + (1, 1024, 1024, 1024), + (2, 1024, 1024, 1024), + (1, 128, 768, 3072), + (1, 128, 3072, 768), + (1, 1024, 8192, 8192), + (1, 733, 8192, 32768), + (1, 13, 32768, 8192), +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("problem_shape", problem_shapes) +def test_blas_linear(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]): + batch, seq_len, in_features, out_features = problem_shape + hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype, + device=get_accelerator().current_device()) * 0.1 + weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01 + ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device()) + + ds_kernel = BlasLibLinear(fp_dtype) + + ds_output = ds_kernel(ds_output, hidden_states, weights) + ref_output = reference_implementation(hidden_states, weights) + + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("problem_shape", problem_shapes) +def test_blas_linear_t(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]): + batch, seq_len, in_features, out_features = problem_shape + hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype, + device=get_accelerator().current_device()) * 0.1 + weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01 + ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device()) + + ds_kernel = BlasLibLinear(fp_dtype) + + # Transpose the weights then revert to the format we expect. + weights = weights.t().contiguous() + weights = weights.t() + ds_output = ds_kernel(ds_output, hidden_states, weights) + + ref_output = reference_implementation(hidden_states, weights) + + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py b/tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py new file mode 100644 index 000000000000..8cb95a6cdcba --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAGatedActivation +from deepspeed.inference.v2.inference_utils import ActivationType +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_geglu_implementation(input: torch.Tensor, + bias: Optional[torch.Tensor] = None, + act_fn: Optional[ActivationType] = ActivationType.GEGLU) -> torch.Tensor: + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + dtype = input.dtype + input = input.to(torch.float32) + + if bias is not None: + bias = bias.to(torch.float32) + input = input + bias + + act_act = input[..., ::2] + act_linear = input[..., 1::2] + + act_act = act_func_map[act_fn](act_act) + + return (act_act * act_linear).to(dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("shape", [(1372, 16384), (2, 743, 22016)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_dtypes(shape: Iterable[int], dtype: torch.dtype) -> None: + input_tensor = torch.randn(shape, dtype=dtype, device=get_accelerator().current_device_name()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU) + + # Build kernel + geglu = CUDAGatedActivation(input_tensor.size(-1), input_tensor.dtype, ActivationType.GEGLU) + + # New output + output_shape = list(input_tensor.shape) + output_shape[-1] //= 2 + output_tensor = torch.empty(output_shape, dtype=input_tensor.dtype, device=get_accelerator().current_device_name()) + geglu(output_tensor, input_tensor) + + # Check + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]) +def test_act_fn(act_fn: ActivationType) -> None: + input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, act_fn=act_fn) + + cuda_act = CUDAGatedActivation(4096, torch.float16, act_fn) + + # New output + output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device()) + cuda_act(output_tensor, input_tensor) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_act_with_bias(): + input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device()) + bias = torch.randn(4096, dtype=torch.float16, device=get_accelerator().current_device()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, bias=bias, act_fn=ActivationType.GEGLU) + + cuda_act = CUDAGatedActivation(4096, torch.float16, ActivationType.GEGLU) + + # New output + output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device()) + + cuda_act(output_tensor, input_tensor, bias) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_max_channels(): + input_tensor = torch.randn(832, 48152, dtype=torch.float16, device=get_accelerator().current_device()) + + ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU) + + cuda_act = CUDAGatedActivation(48152, torch.float16, ActivationType.GEGLU) + + output_tensor = torch.empty(832, 24076, dtype=torch.float16, device=get_accelerator().current_device()) + cuda_act(output_tensor, input_tensor) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_bad_dtype() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(128, torch.int8, ActivationType.GEGLU) + + +@pytest.mark.inference_v2_ops +def test_bad_act_fn() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(128, torch.float16, ActivationType.RELU) + + +@pytest.mark.inference_v2_ops +def test_bad_alignment() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(127, torch.float16, ActivationType.GEGLU) + + +@pytest.mark.inference_v2_ops +def test_too_many_channels() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(49160, torch.float16, ActivationType.GEGLU) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_post_ln.py b/tests/unit/inference/v2/kernels/core_ops/test_post_ln.py new file mode 100644 index 000000000000..0b489894bb9b --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_post_ln.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAFPPostLN +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_post_ln(tokens: int, channels: int, dtype: torch.dtype) -> None: + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_kernel = CUDAFPPostLN(hidden_states.size(-1), residual.dtype) + ds_output = torch.empty_like(residual) + post_ln_kernel(ds_output, residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_pre_ln.py b/tests/unit/inference/v2/kernels/core_ops/test_pre_ln.py new file mode 100644 index 000000000000..ffb748e57af2 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_pre_ln.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAFPPreLN +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + residual_out = residual_f + hidden_states_f + hidden_out = torch.nn.functional.layer_norm(residual_out, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon) + return residual_out.to(hidden_states.dtype), hidden_out.to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_pre_ln(tokens: int, channels: int, dtype: torch.dtype) -> None: + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output_res, ref_output_hid = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + pre_ln_kernel = CUDAFPPreLN(hidden_states.size(-1), residual.dtype) + ds_output_res = torch.empty_like(residual) + ds_output_hid = torch.empty_like(hidden_states) + pre_ln_kernel(ds_output_res, ds_output_hid, residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output_res, ref_output_res) + assert allclose(ds_output_hid, ref_output_hid) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py b/tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py new file mode 100644 index 000000000000..63b16da171c9 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.core_ops import CUDARMSNorm, CUDARMSPreNorm +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_rms_norm(vals: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor: + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals + + +def reference_rms_pre_norm(vals: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + epsilon: float = 1e-5) -> torch.Tensor: + residual = residual + vals + return residual, reference_rms_norm(residual, gamma, epsilon) + + +def _rms_norm_testing_helper(rows: int, channels: int, do_residual: bool, dtype: DtypeEnum) -> None: + device = get_accelerator().current_device_name() + t_dtype = dtype.value + + vals = torch.randn((rows, channels), dtype=t_dtype, device=device) + gamma = torch.randn((channels), dtype=t_dtype, device=device) + epsilon = 1e-5 + + if do_residual: + residual_in = torch.randn((rows, channels), dtype=t_dtype, device=device) + ds_residual = residual_in.clone() + + ref_residual, ref_output = reference_rms_pre_norm(vals, residual_in, gamma, epsilon) + + kernel = CUDARMSPreNorm(channels, t_dtype, epsilon=epsilon) + ds_out = torch.empty_like(ds_residual) + + kernel(ds_residual, ds_out, residual_in, vals, gamma) + + assert allclose(ds_out, ref_output) + assert allclose(ds_residual, ref_residual) + else: + + ref_output = reference_rms_norm(vals, gamma, epsilon) + + kernel = CUDARMSNorm(channels, t_dtype, epsilon=epsilon) + ds_out = torch.empty_like(vals) + + kernel(ds_out, vals, gamma) + + assert allclose(ds_out, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes()) +@pytest.mark.parametrize("do_residual", [True, False]) +def test_rms_dtypes(dtype: DtypeEnum, do_residual: bool) -> None: + _rms_norm_testing_helper(883, 1024, do_residual, DtypeEnum(dtype)) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("rows, cols", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("do_residual", [True, False]) +def test_rms_shapes(rows: int, cols: int, do_residual: bool) -> None: + _rms_norm_testing_helper(rows, cols, do_residual, DtypeEnum.fp16) diff --git a/tests/unit/inference/v2/kernels/cutlass_ops/__init__.py b/tests/unit/inference/v2/kernels/cutlass_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/cutlass_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/cutlass_ops/test_moe_gemm.py b/tests/unit/inference/v2/kernels/cutlass_ops/test_moe_gemm.py new file mode 100644 index 000000000000..ed76dabe1f4c --- /dev/null +++ b/tests/unit/inference/v2/kernels/cutlass_ops/test_moe_gemm.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.kernels.cutlass_ops import MoEGEMM +from ....v2.inference_test_utils import allclose + +SINGLE_EXPERT_CASES = [(13, 2048, 2048), (256, 1024, 4096), (278, 5120, 2048), (893, 5120, 2560)] + +PYTORCH_ACT_FN_MAP = { + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.RELU: torch.nn.functional.relu +} + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, in_neurons, out_neurons", SINGLE_EXPERT_CASES) +def test_single_expert(n_tokens: int, in_neurons: int, out_neurons: int) -> None: + """ + Validate that the GEMM kernel produces identical results for a single GEMM instance. + """ + device = get_accelerator().current_device() + + activations = torch.rand((n_tokens, in_neurons), device=device, dtype=torch.float16) - 0.5 + weights = torch.rand((1, in_neurons, out_neurons), device=device, dtype=torch.float16) - 0.5 + biases = torch.randn((1, out_neurons), device=device, dtype=torch.float16) + + weights_ref = weights.reshape(in_neurons, out_neurons) + biases_ref = biases.reshape(out_neurons) + ref_output = torch.matmul(activations, weights_ref) + biases_ref + + moe_gemm = MoEGEMM(DtypeEnum.fp16, ActivationType.IDENTITY) + output = torch.empty((n_tokens, out_neurons), device=device, dtype=torch.float16) + cumsum_rows = torch.tensor([n_tokens], dtype=torch.int64, device=device) + + moe_gemm(output, activations, weights, cumsum_rows, biases) + assert allclose(output, ref_output, tolerances=(1e-2, 1e-2)) + get_accelerator().synchronize() + + +def moe_test_helper(in_neurons: int, out_neurons: int, n_experts: int, max_tokens_per_expert: int, + act_fn: ActivationType, dtype: DtypeEnum) -> None: + """ + Helper function for validating the GEMM kernel for a single expert. + """ + device = get_accelerator().current_device() + + expert_allocations = torch.randint(0, max_tokens_per_expert, (n_experts, ), device=device, dtype=torch.int32) + cumsum_rows = expert_allocations.cumsum(dim=0) + print(cumsum_rows.dtype) + + activations = torch.rand((cumsum_rows[-1], in_neurons), device=device, dtype=dtype.value) - 0.5 + weights = torch.rand((n_experts, in_neurons, out_neurons), device=device, dtype=dtype.value) - 0.5 + biases = torch.randn((n_experts, out_neurons), device=device, dtype=dtype.value) + + out_ref = torch.empty((cumsum_rows[-1], out_neurons), device=device, dtype=dtype.value) + + for expert_idx in range(n_experts): + start = cumsum_rows[expert_idx - 1] if expert_idx > 0 else 0 + end = cumsum_rows[expert_idx] + activations_slice = activations[start:end] + weights_slice = weights[expert_idx] + biases_slice = biases[expert_idx] + out_ref[start:end] = torch.matmul(activations_slice, weights_slice) + biases_slice + + if act_fn != ActivationType.IDENTITY: + act_fn_fn = PYTORCH_ACT_FN_MAP[act_fn] + out_ref = act_fn_fn(out_ref) + + moe_gemm = MoEGEMM(DtypeEnum.fp16, act_fn) + output = torch.empty((cumsum_rows[-1], out_neurons), device=device, dtype=dtype.value) + + moe_gemm(output, activations, weights, cumsum_rows, biases) + + if dtype == DtypeEnum.bf16: + assert allclose(output, out_ref, tolerances=(1e-1, 1e-1)) + else: + assert allclose(output, out_ref, tolerances=(1e-2, 1e-2)) + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("max_tokens_per_expert", [1, 4, 16, 64, 128]) +def test_multi_expert(max_tokens_per_expert: int) -> None: + """ + Validate for multi-expert GEMM instances that the output is identical to the reference. + """ + moe_test_helper(5120, 2048, 64, max_tokens_per_expert, ActivationType.IDENTITY, DtypeEnum.fp16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU]) +def test_act_fns(act_fn: ActivationType) -> None: + """ + Validate activation function behavior. + """ + moe_test_helper(5120, 2048, 64, 32, act_fn, DtypeEnum.fp16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", [DtypeEnum.fp16, DtypeEnum.bf16]) +def test_dtypes(dtype: DtypeEnum) -> None: + """ + Validate data type behavior. + """ + moe_test_helper(5120, 2048, 64, 32, ActivationType.IDENTITY, dtype) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/__init__.py b/tests/unit/inference/v2/kernels/ragged_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py b/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py new file mode 100644 index 000000000000..be7454fee4aa --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py @@ -0,0 +1,300 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +from typing import List, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.ragged import ( + AllocationMode, + DSSequenceDescriptor, + DSStateManager, + DSStateManagerConfig, + KVCacheConfig, + MemoryConfig, + PlaceholderSequenceDescriptor, + RaggedBatchWrapper, +) +from ....v2.inference_test_utils import allclose + + +def build_simple_batch(seq_lens: List[int], + vocab_range: Optional[int] = 100, + padding: Optional[bool] = False) -> RaggedBatchWrapper: + """ + Construct a simple batch with the given sequence lengths. This method should not + be used for for testing scenarios that require information about KV or sequence + history. + """ + total_tokens = max(sum(seq_lens), 1024) + n_seqs = max(len(seq_lens), 128) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens) + batch = RaggedBatchWrapper(config) + + batch.clear() + + for seq_len in seq_lens: + seq_desc = PlaceholderSequenceDescriptor() + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + + batch.finalize(padding=padding) + + return batch + + +def build_complex_batch(seq_params: List[Tuple[int, int, int]], + kv_block_size: int, + vocab_range: Optional[int] = 100, + padding: Optional[bool] = False) -> Tuple[RaggedBatchWrapper, int]: + """ + Construct a fully paramtrized batch with the given sequence lengths. This method + can be used to construct more realistic inputs for testing scenarios that will interact + with all the members of the RaggedBatchWrapper. + """ + seq_lens = [seq_param[0] for seq_param in seq_params] + total_tokens = max(sum(seq_lens), 1024) + n_seqs = max(len(seq_lens), 128) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens) + batch = RaggedBatchWrapper(config) + + batch.clear() + + total_kv_blocks = 0 + + for seq_len, n_seen_tokens, kv_ptr in seq_params: + n_kv_blocks = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size + seq_desc = PlaceholderSequenceDescriptor(seen_tokens=n_seen_tokens, + cur_allocated_blocks=n_kv_blocks, + kv_blocks_ptr=kv_ptr) + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + total_kv_blocks += n_kv_blocks + + batch.finalize(padding=padding) + + return batch, total_kv_blocks + + +def build_batch_and_manager( + seq_params: List[Tuple[int, int]], + head_size: int, + n_heads_kv: int, + kv_block_size: int, + vocab_range: Optional[int] = 100, + padding: Optional[bool] = False, + kv_fill: Optional[List[torch.Tensor]] = None +) -> Tuple[RaggedBatchWrapper, DSStateManager, List[DSSequenceDescriptor]]: + """ + Will construct and populate a batch and KVCache with the given sequence parameters. + + Arguments: + seq_params (List[Tuple[int, int]]): A list of tuples containing the sequence length and + the number of tokens that have already been seen for that sequence. + head_size (int): The size of each attention head. + n_heads_kv (int): The number of attention heads for the KV-cache. + kv_block_size (int): The size of each block in the KV-cache. + vocab_range (Optional[int]): The range of the vocabulary. Defaults to 100. + padding (Optional[bool]): Whether to pad the batch. Defaults to False. + kv_fill (Optional[List[torch.Tensor]]): A list of tensors to use to populate the KV-cache. + If this is not provided, the KV-cache will be treated as empty and the contents should + not be relied upon. NOTE(cmikeh2): This functionality relies on the functionality + of LinearBlockedKVCopy. If tests relying on this feature are failing, make sure that + LinearBlockedKVCopy is working correctly. + """ + seq_lens = [seq_param[0] for seq_param in seq_params] + fill_lens = [seq_param[1] for seq_param in seq_params] + max_created_batch_len = max(sum(seq_lens), sum(fill_lens)) + total_tokens = max(max_created_batch_len, 1024) + n_seqs = max(len(seq_lens), 128) + + req_kv_blocks = [None] * n_seqs + total_kv_blocks = 0 + for i, (seq_len, n_seen_tokens) in enumerate(seq_params): + req_kv_blocks[i] = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size + total_kv_blocks += req_kv_blocks[i] + + kv_config = KVCacheConfig(block_size=kv_block_size, + num_allocation_groups=1, + cache_shape=(1, n_heads_kv, head_size)) + memory_config = MemoryConfig(mode=AllocationMode.ALLOCATE, size=total_kv_blocks) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens, + memory_config=memory_config) + + batch = RaggedBatchWrapper(config) + state_manager = DSStateManager(config, (kv_config, )) + + # At the beginning of operation, the design of the allocator is such that it will return + # linear blocks of memory. The following will "warm up" the allocator so that we can be + # more certain that code is not dependent on this behavior. + all_allocs = [] + for _ in range(20): + decision = random.randint(0, 1) + + if decision == 0: + blocks_to_allocate = random.randint(0, total_kv_blocks) + if blocks_to_allocate <= state_manager.free_blocks[0] and blocks_to_allocate > 0: + all_allocs.append(state_manager.allocate_blocks(blocks_to_allocate)) + else: + if len(all_allocs) > 0: + idx = random.randint(0, len(all_allocs) - 1) + state_manager._kv_cache.free(all_allocs[idx]) + + del all_allocs[idx] + + for alloc in all_allocs: + state_manager._kv_cache.free(alloc) + + assert state_manager.free_blocks[0] == total_kv_blocks + + batch.clear() + seq_descs = [] + + if kv_fill is None or sum(fill_lens) == 0: + for i, (seq_len, n_seen_tokens) in enumerate(seq_params): + # Create empty descriptor + seq_desc = state_manager.get_or_create_sequence(i) + + # Update `seen_tokens` in the descriptor + seq_desc.pre_forward(n_seen_tokens) + seq_desc.post_forward() + + # Ensure there's enough KV-cache for the sequence + kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i]) + print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}") + seq_desc.extend_kv_cache(kv_block_ids) + + # Insert sequence into batch + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + seq_desc.pre_forward(seq_len) + seq_descs.append(seq_desc) + else: + qkv = torch.empty((total_tokens, (n_heads_kv * 3) * head_size), + dtype=torch.float16, + device=get_accelerator().current_device()) + fills_as_tensor = torch.tensor(fill_lens, dtype=torch.int32) + fill_cumsum = torch.cat((torch.tensor([0], dtype=torch.int32), torch.cumsum(fills_as_tensor, dim=0))) + + for i, (_, n_seen_tokens) in enumerate(seq_params): + # Create empty descriptor + seq_desc = state_manager.get_or_create_sequence(i) + + # Update `seen_tokens` in the descriptor + if n_seen_tokens > 0: + dummy_fill_toks = torch.randint(0, vocab_range, (n_seen_tokens, )) + batch.insert_sequence(seq_desc, dummy_fill_toks) + seq_desc.pre_forward(n_seen_tokens) + + # Ensure there's enough KV-cache for the sequence + kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i]) + print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}") + seq_desc.extend_kv_cache(kv_block_ids) + seq_descs.append(seq_desc) + + if n_seen_tokens == 0: + continue + + assert kv_fill[i].shape[0] == n_seen_tokens + assert kv_fill[i].shape[1] == n_heads_kv * head_size * 2 + + local_q = torch.randn((n_seen_tokens, n_heads_kv * head_size), dtype=torch.float16, device=qkv.device) + local_qkv = torch.cat((local_q, kv_fill[i]), dim=1) + qkv[fill_cumsum[i]:fill_cumsum[i + 1]] = local_qkv + + batch.finalize(padding=padding) + + from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy + kv_copy = LinearBlockedKVCopy(head_size, n_heads_kv, n_heads_kv, torch.float16) + kv_cache = state_manager.get_cache(0) + kv_copy(kv_cache, qkv, batch) + + for seq_desc in seq_descs: + if seq_desc.in_flight_tokens > 0: + seq_desc.post_forward() + + batch.clear() + + for i, (seq_len, _) in enumerate(seq_params): + seq_desc = state_manager.get_or_create_sequence(i) + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + seq_desc.pre_forward(seq_len) + + # We will skip KV cache allocation here because we did a lump allocation above + # for both the fill and the sequence itself. + + batch.finalize(padding=padding) + + return batch, state_manager, seq_descs + + +def validate_kv_cache(kv_cache: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_descs: List[DSSequenceDescriptor], + batch: RaggedBatchWrapper, + exact: bool = True) -> None: + """ + Given a QKV tensor and a KV cache, validate that the cache contains the correct values. + """ + block_size = kv_cache.shape[1] + n_kv_heads = kv_cache.shape[3] + head_size = kv_cache.shape[4] + + inflight_descs = batch.inflight_seq_descriptors(on_device=False)[:batch.current_sequences] + + if inflight_descs.shape[0] != len(seq_descs): + raise ValueError("The number of sequence descriptors does not match the number of sequences in the batch.") + + for seq_desc, inflight_seq in zip(seq_descs, inflight_descs): + start_idx = inflight_seq[0] + assigned_kv_blocks = seq_desc.kv_cache_ids(on_device=False) + + real_k_values = k[start_idx:start_idx + seq_desc.in_flight_tokens] + real_v_values = v[start_idx:start_idx + seq_desc.in_flight_tokens] + + start_block_idx = seq_desc.seen_tokens // block_size + local_start_idx = 0 + cur_start_idx = seq_desc.seen_tokens + + for block_idx in range(start_block_idx, seq_desc.cur_allocated_blocks): + block = kv_cache[assigned_kv_blocks[0, block_idx].item()] + block_start_idx = cur_start_idx % block_size + n_tokens_to_check = min(block_size - block_start_idx, seq_desc.in_flight_tokens - local_start_idx) + block_end_idx = block_start_idx + n_tokens_to_check + + if exact: + assert torch.equal( + block[block_start_idx:block_end_idx, 0, :, :], + real_k_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + assert torch.equal( + block[block_start_idx:block_end_idx, 1, :, :], + real_v_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + else: + assert allclose( + block[block_start_idx:block_end_idx, 0, :, :], + real_k_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + assert allclose( + block[block_start_idx:block_end_idx, 1, :, :], + real_v_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + + local_start_idx += n_tokens_to_check + cur_start_idx += n_tokens_to_check diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_atom_builder.py b/tests/unit/inference/v2/kernels/ragged_ops/test_atom_builder.py new file mode 100644 index 000000000000..a33c938a0608 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_atom_builder.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.kernels.ragged_ops import AtomBuilder +from .ragged_testing_utils import build_complex_batch + +Q_BLOCK_SIZE = 128 +KV_BLOCK_SIZE = 128 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_params', [(1, 0, 0), (1, 228, 0), (383, 0, 0), (1, 494, 0)]) +def test_single_sequence(seq_params) -> None: + seq_len, n_seen_tokens, _ = seq_params + + batch, _ = build_complex_batch([seq_params], kv_block_size=KV_BLOCK_SIZE, padding=False) + atom_builder = AtomBuilder() + + atoms = torch.empty((8, 8), dtype=torch.int32, device=torch.device("cpu")) + atoms, n_atoms = atom_builder(atoms, batch, Q_BLOCK_SIZE, KV_BLOCK_SIZE) + + calc_n_atoms = (seq_len + 127) // 128 + + assert n_atoms == calc_n_atoms + + for i, atom in enumerate(atoms[:n_atoms]): + # Since the ptr was 0, first 2 elements should be 0 + assert atom[0] == 0 + assert atom[1] == 0 + + # Since we have a single sequence, the q_start_idx should always be + # whichever atom we're on multiplied by the block size + assert atom[2] == i * Q_BLOCK_SIZE + assert atom[3] == min(Q_BLOCK_SIZE, seq_len - i * Q_BLOCK_SIZE) + total_toks = i * Q_BLOCK_SIZE + min(Q_BLOCK_SIZE, seq_len - i * Q_BLOCK_SIZE) + + assert atom[4] == (total_toks + n_seen_tokens + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE + assert atom[5] == (total_toks + n_seen_tokens) + + assert atom[6] == n_seen_tokens + i * Q_BLOCK_SIZE diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_flash.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_flash.py new file mode 100644 index 000000000000..ce5a178c9548 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_flash.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import itertools + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import ( + AtomBuilder, + BlockedFlashAttn, + get_q_block_size, + get_kv_block_size, + LinearBlockedKVCopy, +) +from deepspeed.inference.v2.ragged import split_kv +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from .ragged_testing_utils import build_batch_and_manager +from ....v2.inference_test_utils import allclose + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + validate_accuracy = True +except ImportError: + validate_accuracy = False +""" +NOTE(cmikeh2): These tests depend on atom construction and KV-cache copying to behave correctly. +If one or the other of those is not working, then these tests will fail. Before debugging here, +make sure that the atom construction and KV-cache copying tests are passing. +""" + + +def _blocked_flash_testing_helper(head_size: int, n_heads_q: int, n_heads_kv: int, + seq_params: List[Tuple[int, int]]) -> None: + """ + Helper function for testing blocked flash attention. Used to enable parametrize to only set up + a subset of parameters before being passed to the unified test function. + """ + q_block_size = get_q_block_size(head_size) + kv_block_size = get_kv_block_size(head_size) + + kvs = [] + for _, history_len in seq_params: + if history_len > 0: + kvs.append( + torch.randn((history_len, 2 * n_heads_kv * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16)) + else: + kvs.append(None) + + batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs) + + atom_builder = AtomBuilder() + kv_copy = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, DtypeEnum.fp16) + atom_flash = BlockedFlashAttn(head_size, DtypeEnum.fp16) + + total_atoms = sum((seq[0] + q_block_size - 1) // q_block_size for seq in seq_params) + atoms = torch.empty((total_atoms, 8), dtype=torch.int32, device=get_accelerator().current_device()) + alloc_func = RaggedUtilsBuilder().load().allocate_fast_host_buffer + atoms_host = alloc_func(atoms) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16) + + atoms_host, n_atoms = atom_builder(atoms_host, batch, q_block_size, kv_block_size) + atoms.copy_(atoms_host[:n_atoms]) + + kv_cache = state_manager.get_cache(0) + kv_copy(kv_cache, qkv, batch) + + out = torch.empty((batch.current_tokens, head_size * n_heads_q), + device=get_accelerator().current_device(), + dtype=torch.float16) + k_cache, v_cache = split_kv(kv_cache) + q = qkv[:, :head_size * n_heads_q] + + atom_flash(out, q, k_cache, v_cache, atoms, 1.0) + + if validate_accuracy: + cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + + inflight_kv = qkv[:, head_size * n_heads_q:] + full_kvs = [] + for i, kv in enumerate(kvs): + if kv is not None: + full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0)) + else: + full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) + run_kvs = torch.cat(full_kvs, dim=0) + k = run_kvs[:, :head_size * n_heads_kv] + v = run_kvs[:, head_size * n_heads_kv:] + + q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size)) + k_ref = k.reshape((k.shape[0], n_heads_kv, head_size)) + v_ref = v.reshape((v.shape[0], n_heads_kv, head_size)) + + max_seqlen_q = max([seq[0] for seq in seq_params]) + max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params]) + + ref_o = flash_attn_varlen_func(q_ref, + k_ref, + v_ref, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + softmax_scale=1.0, + causal=True) + + ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q) + + assert allclose(out, ref_o) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens", [2, 33, 65, 128, 256, 2037]) +def test_single_prompt(n_tokens: int) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(n_tokens, 0)] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("prompt_lengths", [(128, 128), (192, 38), (514, 713), (83, 312, 610)]) +def test_multiple_prompts(prompt_lengths: Tuple[int, int]) -> None: + """ + Test multiple prompts in a single batch. + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(prompt_lengths[i], 0) for i in range(len(prompt_lengths))] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_params", [(1, 34), (43, 40), (1, 144), (64, 128), (332, 628)]) +def test_continuation(seq_params: Tuple[int, int]) -> None: + """ + Test continued generation/prompt processing. + """ + head_size = 64 + n_heads_q = 32 + n_heads_kv = 32 + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, [seq_params]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 128]) +def test_head_size(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_config", [(32, 8), (64, 16), (40, 8)]) +def test_gqa(head_config: Tuple[int, int]) -> None: + head_size = 128 + n_heads_q = head_config[0] + n_heads_kv = head_config[1] + + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +def test_fully_composed() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py new file mode 100644 index 000000000000..5a99422ba9ff --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy +from .ragged_testing_utils import build_batch_and_manager, validate_kv_cache + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 8), (63, 1)]) +@pytest.mark.parametrize("head_size", [64, 80, 96, 128]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, head_size: int): + """ + Validate that the copy works correctly + """ + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) +@pytest.mark.parametrize("head_size", [64, 80, 96, 128]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, head_size: int): + """ + Validate that the copy works correctly + """ + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 80, 96, 128]) +def test_multi_sequence(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py new file mode 100644 index 000000000000..33dd0a4c2700 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import BlockedRotaryEmbeddings, BlockedTrainedRotaryEmbeddings +from deepspeed.inference.v2.ragged import RaggedBatchWrapper, DSSequenceDescriptor +from .ragged_testing_utils import build_batch_and_manager, validate_kv_cache +from ....v2.inference_test_utils import allclose +""" +NOTE(cmikeh2): It is very possible to see unit test failures (even on FP16) depending on when +certain values are casted up to or down from float32. If we are seeing accuracy issues, we should +make sure we are aligning on the training implementation's cast pattern here, given these tolerances +tend to be sufficient elsewhere. +""" + + +def rotary_pos_embs(q: torch.Tensor, + k: torch.Tensor, + seq_descs: List[DSSequenceDescriptor], + batch: RaggedBatchWrapper, + head_size: int, + rotary_dim: int = -1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + rotary_dim = rotary_dim if rotary_dim >= 0 else head_size + + def make_cos_sin_emb(seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: + t = torch.arange(seq_len, dtype=torch.float32, device=get_accelerator().current_device()) + inv_freq = (1.0 / (10000.0**(torch.arange( + 0, rotary_dim, 2, dtype=torch.float32, device=get_accelerator().current_device()) / rotary_dim))).half() + + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + return torch.cos(emb)[:, None, :], torch.sin(emb)[:, None, :], inv_freq + + def rotate_half(x: torch.Tensor) -> torch.Tensor: + return torch.cat((-x[..., x.shape[-1] // 2:], x[..., :x.shape[-1] // 2]), dim=-1) + + cos, sin, freqs = make_cos_sin_emb(1024) + + q_out = torch.empty_like(q) + k_out = torch.empty_like(k) + n_heads_q = q.shape[1] // head_size + n_heads_kv = k.shape[1] // head_size + + inflight_descs = batch.inflight_seq_descriptors(on_device=False)[:batch.current_sequences] + + if inflight_descs.shape[0] != len(seq_descs): + raise ValueError("The number of sequence descriptors does not match the number of sequences in the batch.") + + for seq_desc, inflight_seq in zip(seq_descs, inflight_descs): + start_idx = inflight_seq[0] + n_tokens = seq_desc.in_flight_tokens + + q_src = q[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_q, head_size).float() + k_src = k[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_kv, head_size).float() + freq_start_offset = seq_desc.seen_tokens + + q_src_rot = q_src[:, :, :rotary_dim] + k_src_rot = k_src[:, :, :rotary_dim] + + cos_chunk = cos[range(freq_start_offset, freq_start_offset + n_tokens)] + sin_chunk = sin[range(freq_start_offset, freq_start_offset + n_tokens)] + + q_rot = q_src_rot * cos_chunk + rotate_half(q_src_rot) * sin_chunk + k_rot = k_src_rot * cos_chunk + rotate_half(k_src_rot) * sin_chunk + + q_emb = torch.cat((q_rot, q_src[:, :, rotary_dim:]), dim=-1) + k_emb = torch.cat((k_rot, k_src[:, :, rotary_dim:]), dim=-1) + + q_out[start_idx:start_idx + n_tokens] = q_emb.reshape(n_tokens, n_heads_q * head_size).to(q_out.dtype) + k_out[start_idx:start_idx + n_tokens] = k_emb.reshape(n_tokens, n_heads_kv * head_size).to(k_out.dtype) + + return q_out, k_out, freqs + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 15), (1, 63)]) +@pytest.mark.parametrize("trained_emb", [False, True]) +@pytest.mark.parametrize("head_size", [64, 80, 96]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool, head_size: int): + """ + Validate that the copy works correctly + """ + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, head_size, 10000.0) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) +@pytest.mark.parametrize("trained_emb", [False, True]) +@pytest.mark.parametrize("head_size", [64, 80, 96]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool, head_size: int): + """ + Validate that the copy works correctly + """ + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, head_size, 10000.0) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("trained_emb", [False, True]) +@pytest.mark.parametrize("head_size", [64, 80, 96]) +def test_multi_sequences(trained_emb: bool, head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, head_size, 10000.0) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [80, 96]) +def test_rotary_dim(head_size: int) -> None: + trained_emb = False + rotary_dim = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size, rotary_dim=rotary_dim) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, rotary_dim, 10000.0) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py new file mode 100644 index 000000000000..1feefa9ee588 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import RaggedLogitsGather +from ....v2.inference_test_utils import allclose, get_dtypes +from .ragged_testing_utils import build_simple_batch + + +def baseline_implementation(hidden_states: torch.Tensor, seq_lens: List[int]) -> torch.Tensor: + output = torch.empty((len(seq_lens), hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + + offset = 0 + for i, seq_len in enumerate(seq_lens): + output[i] = hidden_states[offset + seq_len - 1] + offset += seq_len + + return output + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('dtype', get_dtypes()) +def test_supported_dtypes(dtype: torch.dtype) -> None: + """ + Validate support on nominally supported data types. + """ + model_dim = 4096 + + batch = build_simple_batch([256], padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, [256]) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], + [63, 27, 74, 83, 32, 17, 1, 1, 1, 1, 1]]) +def test_multiple_sequences(seq_lens: List[int]) -> None: + """ + Validate support on more multi-sequence inputs. + """ + model_dim = 4096 + dtype = torch.float16 + + batch = build_simple_batch(seq_lens, padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, seq_lens) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("model_dim", [1024, 6144, 6784]) +def test_problem_size_permutations(model_dim: int) -> None: + """ + Validate for different embedding sizes. + """ + dtype = torch.float16 + seq_lens = [128, 64, 192, 32] + + batch = build_simple_batch(seq_lens, padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, seq_lens) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py new file mode 100644 index 000000000000..3907fc3e3a4b --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import ( + MoEGather, + MoEScatter, + RaggedTopKGating, +) +from .ragged_testing_utils import build_simple_batch +""" +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` and +``MoEScatter`` to produce correct inputs. If either of these kernels is broken +these tests will fail, so double check the unit test results there before +debugging here. +""" + +TEST_CASES = [ + # (n_tokens, n_experts, n_top_k) + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] + + +def build_inputs(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): + + assert n_tokens <= 2048, "This test will break if n_tokens > 2048" + + # Sequence composition shouldn't matter here + batch = build_simple_batch([n_tokens], padding=do_padding) + + logits = torch.randn((batch.tensor_toks, n_experts), + dtype=torch.float16, + device=get_accelerator().current_device()) + + # This will make each token's value equal to its index. NOTE: This will break for + # tokens with index > 2048. + hidden_states = torch.arange(batch.tensor_toks, dtype=torch.float16, + device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( + batch.tensor_toks, 4096).contiguous() + + gate = RaggedTopKGating(DtypeEnum.fp16) + + # Gating outputs + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + # Scatter outputs + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) + expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + + scatter = MoEScatter(DtypeEnum.fp16, 4096) + scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + + return batch, moe_input, scores, mapped_slots, expert_counts + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CASES) +@pytest.mark.parametrize("do_padding", [False]) +def test_moe_gather(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): + get_accelerator().manual_seed(0xC0FFEE) + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) + + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + effective_score = scores[token_idx].sum().item() + assert torch.equal( + output[token_idx], + torch.full((4096, ), + token_idx * effective_score, + dtype=torch.float16, + device=get_accelerator().current_device())) + + +@pytest.mark.inference_v2_ops +def test_moe_gather_normalize_scales(): + get_accelerator().manual_seed(0xC0FFEE) + + n_tokens = 72 + n_experts = 8 + n_top_k = 2 + do_padding = False + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096, normalize_scores=True) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + assert torch.equal( + output[token_idx], + torch.full((4096, ), token_idx, dtype=torch.float16, device=get_accelerator().current_device())) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py new file mode 100644 index 000000000000..aae459f06a6f --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTopKGating +from .ragged_testing_utils import build_simple_batch +""" +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` to produce correct +inputs. If ``RaggedTopKGating`` is broken, these tests will fail, so double check +the unit test results there before debugging here. +""" + +TEST_CONFIGS = [ + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CONFIGS) +@pytest.mark.parametrize("do_padding", [False, True]) +def test_moe_scatter(n_tokens, n_experts, n_top_k, do_padding): + + # Sequence composition shouldn't matter here + batch = build_simple_batch([n_tokens], padding=do_padding) + + logits = torch.randn((batch.tensor_toks, n_experts), + dtype=torch.float16, + device=get_accelerator().current_device()) + + # This will make each token's value equal to its index. NOTE: This will break for + # tokens with index > 2048. + hidden_states = torch.arange(batch.tensor_toks, dtype=torch.float16, + device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( + batch.tensor_toks, 4096).contiguous() + + gate = RaggedTopKGating(DtypeEnum.fp16) + + # Gating outputs + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + # Scatter outputs + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) + expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + + scatter = MoEScatter(DtypeEnum.fp16, 4096) + scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + get_accelerator().synchronize() + assert torch.equal(expert_cumsum, torch.cumsum(expert_counts, dim=0).to(torch.int64)) + + if not do_padding: + assert torch.unique(mapped_slots).size(0) == n_top_k * n_tokens + + for token_idx in range(batch.tensor_toks): + if token_idx < n_tokens: + for k in range(n_top_k): + expert_idx = expert_assignment[token_idx][k].item() + if expert_idx == 0: + expert_cumsum_val = 0 + else: + expert_cumsum_val = expert_cumsum[expert_idx - 1] + offset = expert_offset[token_idx][k] + total_offset = offset + expert_cumsum_val + + assert total_offset == mapped_slots[token_idx][k].item() + assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) + else: + for k in range(n_top_k): + assert mapped_slots[token_idx][k].item() == -1 + + assert expert_cumsum[-1] == n_tokens * n_top_k diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_ragged_embed.py b/tests/unit/inference/v2/kernels/ragged_ops/test_ragged_embed.py new file mode 100644 index 000000000000..f179f62a9b12 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_ragged_embed.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import RaggedEmbeddingKernel +from ....v2.inference_test_utils import allclose, get_dtypes +from .ragged_testing_utils import build_batch_and_manager + + +def baseline_implementation(token_ids: torch.Tensor, + embedding_table: torch.Tensor, + unpadded_size: int, + positional_embedding_table: Optional[torch.Tensor] = None, + positional_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Baseline implementation for our ragged embedding kernel. + """ + if unpadded_size == token_ids.shape[0]: + token_embed = torch.nn.functional.embedding(token_ids, embedding_table) + + if positional_embedding_table is not None: + pos_embed = torch.nn.functional.embedding(positional_ids, positional_embedding_table) + token_embed += pos_embed + return token_embed + else: + real_token_ids = token_ids[:unpadded_size] + output = torch.empty((token_ids.shape[0], embedding_table.shape[1]), + dtype=embedding_table.dtype, + device=get_accelerator().current_device()) + unpadded_output = torch.nn.functional.embedding(real_token_ids, embedding_table) + + # Positional embeddings aren't padded because it's simulated + if positional_embedding_table is not None: + pos_embed = torch.nn.functional.embedding(positional_ids, positional_embedding_table) + unpadded_output += pos_embed + + output[:unpadded_size] = unpadded_output + return output + + +def _ragged_embed_test_helper(sequence_config: List[Tuple[int, int]], + embed_dtype: torch.dtype, + token_dtype: torch.dtype, + embed_dim: int, + vocab_size: int, + do_padding: bool = False, + pos_embed_size: int = -1, + pos_embed_offset: int = 0) -> None: + """ + Helper for embedding test to limit the number of tests to run. + + Params: + embed_dim (int): Model dimension + vocab_size (int): Leading dimension on embedding weight + pos_embed_size (int): Size of positional embedding. If negative, no positional embedding + is used. + pos_embed_offset (int): Offset for positional embedding. Effectively, the raw offsets + of a token into a sequence are offset by this amount into the embedding matrix. ( + i.e. the shape of the positional embeddings is (pos_embed_size + pos_embed_offset + embed_dim) + """ + device = get_accelerator().current_device() + + # Heads/Block size are irrelevant here but need something. + batch, _, _, = build_batch_and_manager(sequence_config, 64, 16, 64, vocab_range=vocab_size, padding=do_padding) + + embedding_table = torch.randn((vocab_size, embed_dim), dtype=embed_dtype, device=device) + + if pos_embed_size > 0: + pos_embedding_table = torch.randn((pos_embed_size + pos_embed_offset, embed_dim), + dtype=embed_dtype, + device=device) + positional_ids = torch.cat([ + torch.arange(start_idx, start_idx + seq_len, dtype=token_dtype, device=device) + for seq_len, start_idx in sequence_config + ]) + pos_embed_offset + else: + pos_embedding_table = None + positional_ids = None + + baseline_output = baseline_implementation(batch.input_ids().to(token_dtype), embedding_table, batch.current_tokens, + pos_embedding_table, positional_ids) + + kernel = RaggedEmbeddingKernel(embed_dtype, token_dtype, embed_dim) + output = torch.empty_like(baseline_output) + + kernel(output, + batch, + embedding_table, + position_embed_weight=pos_embedding_table, + position_embed_offset=pos_embed_offset) + + if do_padding: + assert output.shape[0] != batch.current_tokens + + assert allclose(output[:batch.current_tokens], baseline_output[:batch.current_tokens]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('token_dtype', [torch.int32, torch.int64]) +@pytest.mark.parametrize('embed_dtype', get_dtypes()) +def test_dtype_permutations(token_dtype: torch.dtype, embed_dtype: torch.dtype) -> None: + """ + Validate (on a single problem size) that the kernel support for different data types + is correct. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper([(256, 0)], embed_dtype, token_dtype, embed_dim, vocab_size) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('vocab_size, embed_dim', [(1024, 1024), (32000, 5120), (50304, 6144)]) +def test_problem_size_permutations(vocab_size: int, embed_dim: int) -> None: + """ + Validate on wider range of problem sizes. + """ + + _ragged_embed_test_helper([(256, 0)], torch.float16, torch.int32, embed_dim, vocab_size) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1]]) +@pytest.mark.parametrize('do_padding', [True, False]) +def test_complex_sequences(seq_lens: List[int], do_padding: bool) -> None: + """ + Validate on different ragged batch construction scenarios. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper([(seq_len, 0) for seq_len in seq_lens], + torch.float16, + torch.int32, + embed_dim, + vocab_size, + do_padding=do_padding) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_lens", [[(256, 0)], [(256, 0), + (128, 0)], [(256, 0), (128, 0), + (64, 0)], [(1, 877), (619, 0), (213, 372), (1, 45)]]) +def test_positional_embedding(seq_lens: List[Tuple[int, int]]) -> None: + """ + Validate that positional embedding works correctly. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper(seq_lens, torch.float16, torch.int32, embed_dim, vocab_size, pos_embed_size=2048) + + +@pytest.mark.inference_v2_ops +def test_positional_embedding_offset() -> None: + """ + Validate that positional embedding works correctly with an offset. + """ + embed_dim = 4096 + vocab_size = 50304 + seq_config = [(1, 877), (619, 0), (213, 372), (1, 45)] + + _ragged_embed_test_helper(seq_config, + torch.float16, + torch.int32, + embed_dim, + vocab_size, + pos_embed_size=2048, + pos_embed_offset=2) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py new file mode 100644 index 000000000000..5fa0c8a079f0 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import torch.nn.functional as F + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import RaggedTopKGating +from .ragged_testing_utils import build_simple_batch +from ...inference_test_utils import allclose + + +def _top_k_gating_testing_helper(n_tokens: int, n_experts: int, n_top_k: int, seed: int = 0xC0FFEE) -> None: + + torch.manual_seed(seed) + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + gate = RaggedTopKGating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + ref_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + ref_scores, ref_indices = torch.topk(ref_weights, n_top_k, dim=-1) + + assert allclose(scores, ref_scores), f"expected {ref_scores}, got {scores}" + assert torch.equal(expert_assignment, + ref_indices.to(torch.int32)), f"expected {ref_indices}, got {expert_assignment}" + assert expert_counts.sum( + ) == n_tokens * n_top_k, f"expected {n_tokens * n_top_k} tokens, got {expert_counts.sum()}" + + # Ensure that the expert offsets are unique + for i in range(n_experts): + expert_idxs = torch.where(expert_assignment == i, expert_offset, 0) + if expert_counts[i] > 0: + assert expert_idxs.unique().shape[0] == expert_counts[ + i], f"expected {expert_counts[i]} unique offsets, got {expert_idxs.unique().shape[0]}" + assert expert_idxs.max( + ) == expert_counts[i] - 1, f"expected max offset {expert_counts[i] - 1}, got {expert_idxs.max()}" + else: + # Should have all 0's so one unique value + assert expert_idxs.unique().shape[0] == 1 + assert expert_idxs.max() == 0 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens', [1, 17, 32, 89, 433]) +def test_top_2_e_8_gating(n_tokens: int) -> None: + _top_k_gating_testing_helper(n_tokens=n_tokens, n_experts=8, n_top_k=2) + + +def _test_single_mapping_helper(n_tokens: int, + n_experts: int, + assigned_expert: int, + logit_fill: float = 0.0, + match_fill: float = 1.0) -> None: + + n_top_k = 1 + logits = torch.full((n_tokens, n_experts), + logit_fill, + dtype=torch.float16, + device=get_accelerator().current_device()) + + logits[:, assigned_expert] = match_fill + + gate = RaggedTopKGating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert expert_counts[assigned_expert] == n_tokens + assert torch.all(expert_assignment == assigned_expert) + assert torch.unique(expert_offset).shape[0] == n_tokens + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert].reshape(-1, n_top_k)) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens, n_experts', [(1, 16), (17, 16), (32, 128), (89, 128), (433, 128)]) +def test_single_mapping_gating(n_tokens: int, n_experts: int) -> None: + """ + Evaluate our expert stacking behavior in complete isolation. This ensures all tokens + mapped to the same expert are getting unique offsets and identical scores. + """ + assigned_expert = 13 + _test_single_mapping_helper(n_tokens, n_experts, assigned_expert) + + +@pytest.mark.inference_v2_ops +def test_negative_logits(): + """ + Ensure that scores/values are propagated correctly when all the logits are negative. An + earlier implementation of the scoring would return NaN for this case. + """ + _test_single_mapping_helper(128, 32, 13, logit_fill=-2.0, match_fill=-1.0) + + +@pytest.mark.inference_v2_ops +def test_determinism(): + """ + Ensure that ties between two logits are broken deterministically. This is essential when + the gating is distributed across multiple devices that need to map the same token to + the same expert. + """ + + n_tokens = 512 + n_experts = 64 + n_top_k = 1 + + logits = torch.zeros((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + logits[:, 19] = 1.0 + logits[:, 26] = 1.0 + + gate = RaggedTopKGating(DtypeEnum.fp16) + + for _ in range(1024): + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert expert_counts[19] == n_tokens + assert expert_counts[26] == 0 + assert torch.all(expert_assignment == 19) + assert torch.unique(expert_offset).shape[0] == n_tokens + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19].reshape(-1, 1)) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens, n_experts', [(1, 16), (17, 16), (32, 128), (89, 128), (433, 2)]) +def test_score_accuracy(n_tokens: int, n_experts: int) -> None: + """ + Validate expert scores are correct. + """ + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + n_top_k = 1 + + gate = RaggedTopKGating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + + ref_scores = F.softmax(logits.float(), dim=1).max(dim=1).values + ref_scores = ref_scores.reshape(-1, 1) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert allclose(scores, ref_scores) + assert expert_counts.sum() == n_tokens diff --git a/tests/unit/inference/v2/model_implementations/__init__.py b/tests/unit/inference/v2/model_implementations/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/model_implementations/parameters/__init__.py b/tests/unit/inference/v2/model_implementations/parameters/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py b/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py new file mode 100644 index 000000000000..52ff0e134dfc --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.flat_model_helpers import ( + flatten_inference_model, + restore_inference_model, +) +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer +from .utils import SimpleParam, DummyInferenceModel + + +class TransformerLayerContainer(LayerContainer): + """ + Stub layer container + """ + PARAM_MAPPING = { + "param_1": "param_1.param", + "param_2": "param_2.param", + } + + param_1: SimpleParam + + param_2: SimpleParam + + +class NonTransformerContainer(LayerContainer): + """ + Stub layer container + """ + PARAM_MAPPING = { + "param_1": "param_1.param", + "param_2": "param_2.param", + "param_3": "param_3.param", + } + + param_1: SimpleParam + + param_2: SimpleParam + + param_3: SimpleParam + + +@pytest.mark.inference_v2 +def test_contiguify_roundtrip(): + """ + Validate that contiguify round trips and reconstructions are correct. + """ + model = DummyInferenceModel() + + n_layers = 2 + transformer_params = [] + transformer_containers = [] + + # Create parameters and populate them into the containers + for i in range(n_layers): + transformer_containers.append(TransformerLayerContainer(model)) + layer_params = [] + for j in range(2): + layer_params.append(torch.rand(16, 16)) + transformer_containers[i].set_dependency(f"param_{j+1}", layer_params[j]) + + layer_params = [p.to(get_accelerator().current_device()) for p in layer_params] + + transformer_params.append(layer_params) + assert transformer_containers[i].is_populated == True + + non_transformer_params = [] + non_transformer_container = NonTransformerContainer(model) + + for i in range(3): + non_transformer_params.append(torch.rand(16, 16).permute(1, 0)) + non_transformer_container.set_dependency(f"param_{i+1}", non_transformer_params[i]) + + non_transformer_params = [p.to(get_accelerator().current_device()) for p in non_transformer_params] + + def validate_containers(t_containers: List[LayerContainer], n_t_containers: LayerContainer, + t_params: List[List[torch.Tensor]], n_t_params: List[torch.Tensor]): + """ + Validate params match what is on the containers. + """ + for i in range(n_layers): + l_c = t_containers[i] + + assert l_c.is_initialized == True + + assert torch.equal(l_c.param_1, t_params[i][0]) + assert torch.equal(l_c.param_2, t_params[i][1]) + + assert n_t_containers.is_initialized == True + assert torch.equal(n_t_containers.param_1, n_t_params[0]) + assert torch.equal(n_t_containers.param_2, n_t_params[1]) + assert torch.equal(n_t_containers.param_3, n_t_params[2]) + assert not n_t_containers.param_1.is_contiguous() + assert not n_t_containers.param_2.is_contiguous() + assert not n_t_containers.param_3.is_contiguous() + + buffer, metadata = flatten_inference_model(transformer_containers, non_transformer_container, "NoOpPolicy") + + # Validate containers before contiguify + validate_containers(transformer_containers, non_transformer_container, transformer_params, non_transformer_params) + + # Validate restore pass + transformer_containers_r = [] + for i in range(n_layers): + transformer_containers_r.append(TransformerLayerContainer(model)) + + non_transformer_container_r = NonTransformerContainer(model) + + restore_inference_model(buffer, metadata, transformer_containers_r, non_transformer_container_r) + + validate_containers(transformer_containers_r, non_transformer_container_r, transformer_params, + non_transformer_params) diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_layer_inheritance.py b/tests/unit/inference/v2/model_implementations/parameters/test_layer_inheritance.py new file mode 100644 index 000000000000..07ad87e6168d --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_layer_inheritance.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.inference_parameter import InferenceParameter +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + +from .utils import SimpleParam, DummyInferenceModel + + +class ParentLayer(LayerContainer): + """ + A layer that has a dependency on a simple parameter. + """ + + param_1: SimpleParam + + +class ChildLayer(ParentLayer): + """ + A layer that inherits from another layer. + """ + + param_2: SimpleParam + + +@pytest.mark.inference_v2 +def test_layer_inheritance(): + inference_model = DummyInferenceModel() + + multi_param_layer = ChildLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + assert multi_param_layer.is_populated is True + assert isinstance(multi_param_layer.param_1, InferenceParameter) + assert isinstance(multi_param_layer.param_2, InferenceParameter) diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py b/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py new file mode 100644 index 000000000000..52313cb6f202 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.inference_parameter import InferenceParameter +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParamList +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MultiDependencyContainer(ParameterBase): + + dependency_1: torch.Tensor + + dependency_2: torch.Tensor + + @on_device + def finalize(self) -> torch.Tensor: + param = torch.cat([self.dependency_1, self.dependency_2]) + return InferenceParameter.initialize(param) + + +class ListDependencyContainer(ParameterBase): + + dependencies: ParamList("list_items") # noqa: F821 + + @on_device + def finalize(self) -> torch.Tensor: + param = torch.cat(tuple(self.dependencies)) + return InferenceParameter.initialize(param) + + +class MappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend.dependency_1", + "model.val.item.d_2": "multi_depend.dependency_2", + "model.list_vals.*.d": "list_depend.dependencies" + } + + multi_depend: MultiDependencyContainer + + list_depend: ListDependencyContainer + + +class SubMappingLayer(MappingLayer): + PARAM_MAPPING = { + "model.val.item2.d_1": "multi_depend2.dependency_1", + "model.val.item2.d_2": "multi_depend2.dependency_2", + } + + multi_depend2: MultiDependencyContainer + + +class DoubleMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": ["multi_depend.dependency_1", "multi_depend.dependency_2"], + } + + multi_depend: MultiDependencyContainer + + +class InferenceModel: + + @property + def list_items(self) -> int: + return 16 + + +@pytest.mark.inference_v2 +def test_mapping_syntax(): + model = InferenceModel() + + mapping_layer = MappingLayer(model) + + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + + for i in range(16): + mapping_layer.set_dependency(f"model.list_vals.{i}.d", torch.ones(1) * i) + if i != 16 - 1: + assert mapping_layer.is_populated == False + + assert isinstance(mapping_layer.list_depend, InferenceParameter) + assert mapping_layer.is_populated == True + + +@pytest.mark.inference_v2 +def test_sub_mapping_syntax(): + model = InferenceModel() + + mapping_layer = SubMappingLayer(model) + + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend, InferenceParameter) + + mapping_layer.set_dependency("model.val.item2.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item2.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend2, InferenceParameter) + + # We want to check into double digits to make sure that this isn't specific + # to single difit indexing. + for i in range(16): + mapping_layer.set_dependency(f"model.list_vals.{i}.d", torch.ones(1) * i) + if i != 16 - 1: + assert mapping_layer.is_populated == False + + assert isinstance(mapping_layer.list_depend, InferenceParameter) + assert mapping_layer.is_populated == True + + +@pytest.mark.inference_v2 +def test_double_mapping_syntax(): + model = InferenceModel() + + mapping_layer = DoubleMappingLayer(model) + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + + # The single parameter setting should immediately make the parameter finalized + # and the whole layer initialized. + assert isinstance(mapping_layer.multi_depend, InferenceParameter) + assert mapping_layer.is_populated == True + + +@pytest.mark.inference_v2 +def test_insufficient_mapping_syntax(): + """ + In the above example, we don't have a mapping for `multi_depend2.dependency_2`. + """ + + with pytest.raises(ValueError): + + class InsuffienctMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend1.dependency_1", + "model.val.item.d_2": "multi_depend1.dependency_2", + "model.val.item2.d_1": "multi_depend2.dependency_1", + } + + multi_depend1: MultiDependencyContainer + + multi_depend2: MultiDependencyContainer + + +@pytest.mark.inference_v2 +def test_unknown_target_mapping_syntax(): + """ + In the above example, `multi_depend_unknown` does not exist + """ + + with pytest.raises(ValueError): + + class UnknownTargetMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend1.dependency_1", + "model.val.item.d_2": "multi_depend1.dependency_2", + "model.val.item2.d_1": "multi_depend_unknown.dependency_1", + } + + multi_depend: MultiDependencyContainer diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_multi_parameter_layer.py b/tests/unit/inference/v2/model_implementations/parameters/test_multi_parameter_layer.py new file mode 100644 index 000000000000..b319bf6de4ad --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_multi_parameter_layer.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.inference_parameter import InferenceParameter +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + +from .utils import validate_device, SimpleParam, ListParam, DummyInferenceModel + + +class MultiParameterLayer(LayerContainer): + """ + Two dependencies, both of which are simple parameters. + """ + + param_1: SimpleParam + + param_2: SimpleParam + + +class MixedMultiParameterLayer(LayerContainer): + """ + Two dependencies, one of which is a simple parameter, the other is a list parameter. + """ + + param_1: SimpleParam + + param_2: ListParam + + +@pytest.mark.inference_v2 +def test_multi_parameter_layer(): + inference_model = DummyInferenceModel() + + multi_param_layer = MultiParameterLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_populated is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_populated is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + assert multi_param_layer.is_populated is True + assert isinstance(multi_param_layer.param_1, InferenceParameter) + assert isinstance(multi_param_layer.param_2, InferenceParameter) + + +@pytest.mark.inference_v2 +def test_mixed_multi_parameter_layer(): + inference_model = DummyInferenceModel() + + mixed_multi_param_layer = MixedMultiParameterLayer(inference_model) + + assert mixed_multi_param_layer.n_params == 2 + assert mixed_multi_param_layer.is_populated is False + + mixed_multi_param_layer.param_2.params[1] = torch.full((16, 16), 2.0) + assert mixed_multi_param_layer.is_populated is False + assert not isinstance(mixed_multi_param_layer.param_2, InferenceParameter) + + mixed_multi_param_layer.param_1.param = torch.ones(16, 16) + assert mixed_multi_param_layer.is_populated is False + assert isinstance(mixed_multi_param_layer.param_1, InferenceParameter) + + validate_device(mixed_multi_param_layer.param_1) + + mixed_multi_param_layer.param_2.params[0] = torch.full((16, 16), 2.0) + + assert mixed_multi_param_layer.is_populated is True + assert isinstance(mixed_multi_param_layer.param_2, InferenceParameter) + + validate_device(mixed_multi_param_layer.param_2) diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py new file mode 100644 index 000000000000..06ff9047d648 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.inference_parameter import InferenceParameter +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParamList +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer +from deepspeed.inference.v2.model_implementations.common_parameters import * + +from .utils import validate_device + + +class SimpleMoELayer(LayerContainer): + + moe_mlp_1: UnfusedMoEMLP1Parameter + + +class DummyInferenceModel: + + def __init__(self, experts_per_rank: int) -> None: + self._num_experts = experts_per_rank + + @property + def n_experts(self) -> int: + return self._num_experts + + @on_device + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor: + return InferenceParameter.initialize(param) + + +@pytest.mark.inference_v2 +def test_simple_moe_layer(): + + inference_model = DummyInferenceModel(experts_per_rank=2) + + simple_moe_layer = SimpleMoELayer(inference_model) + + assert simple_moe_layer.moe_mlp_1.experts[0] is None + assert simple_moe_layer.moe_mlp_1.experts[1] is None + + # Set the first expert + simple_moe_layer.moe_mlp_1.experts[0] = torch.zeros(16, 16) + + assert simple_moe_layer.moe_mlp_1.experts[0] is not None + assert simple_moe_layer.moe_mlp_1.experts[1] is None + + assert not simple_moe_layer.is_initialized + + # Set the second expert + simple_moe_layer.moe_mlp_1.experts[1] = torch.ones(16, 16) + + # We have all the experts, so the layer should be initialized + assert simple_moe_layer.is_initialized + assert isinstance(simple_moe_layer.moe_mlp_1, torch.Tensor) + + validate_device(simple_moe_layer.moe_mlp_1) + + +""" +Check that we can mix the number of elements in lists in the same context and have that +be tracked correctly. +""" + + +class CustomListParam1(ParameterBase): + + deps: ParamList("attr_1") + + +class CustomListParam2(ParameterBase): + + deps: ParamList("attr_2") + + +class MixedLayer(LayerContainer): + + list_1: CustomListParam1 + list_2: CustomListParam2 + + +class MixedInferenceModel: + + @property + def attr_1(self) -> int: + return 1 + + @property + def attr_2(self) -> int: + return 2 + + +@pytest.mark.inference_v2 +def test_mixed_param_lists(): + model = MixedInferenceModel() + + layer = MixedLayer(model) + + assert layer.list_1.deps.n_params == 1 + assert layer.list_2.deps.n_params == 2 diff --git a/tests/unit/inference/v2/model_implementations/parameters/utils.py b/tests/unit/inference/v2/model_implementations/parameters/utils.py new file mode 100644 index 000000000000..07d72059f9b3 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.inference_parameter import InferenceParameter +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParametrizedList + + +class SimpleParam(ParameterBase): + """ + Parameter with single dependency. + """ + + param: torch.Tensor + + @on_device + def finalize(self) -> torch.Tensor: + return self.inference_model.transform(self.param) + + +class SimpleParametrizedList(ParametrizedList): + """ + Parameter list based on `num_dependencies` attribute. + """ + + count_attr: str = "num_dependencies" + + +class ListParam(ParameterBase): + """ + Parameter with list dependency. + + NOTE: This uses the tuple workaround for the `ParametrizedList` class + as described in the docstring of `ParametrizedList`. + """ + + params: SimpleParametrizedList + + @on_device + def finalize(self) -> torch.Tensor: + return self.inference_model.transform(torch.cat(tuple(self.params))) + + +class DummyInferenceModel: + + @property + def num_dependencies(self) -> int: + return 2 + + def transform(self, param: torch.Tensor) -> torch.Tensor: + return InferenceParameter.initialize(param) + + +def validate_device(tensor: torch.Tensor): + assert tensor.device == torch.device(get_accelerator().current_device()) diff --git a/tests/unit/inference/v2/model_implementations/sharding/__init__.py b/tests/unit/inference/v2/model_implementations/sharding/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/model_implementations/sharding/test_attn_out_sharding.py b/tests/unit/inference/v2/model_implementations/sharding/test_attn_out_sharding.py new file mode 100644 index 000000000000..850c4c24fde6 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/test_attn_out_sharding.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + +# None of the logic should be dependent on head size. +HEAD_SIZE = 64 + + +def fill_with_head_ids(head_size: int, n_heads: int) -> torch.Tensor: + """ + Fills a tensor with the associated head ids. All columns should have the same value. + """ + head_ids = torch.arange(n_heads, dtype=torch.half, device=get_accelerator().current_device()) + + head_ids = head_ids.repeat_interleave(head_size).repeat(head_size * n_heads).reshape(n_heads * head_size, -1) + return head_ids + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads, n_shards", [(1, 1), (8, 4), (32, 8)]) +def test_mha_even_sharding(n_heads: int, n_shards: int): + """ + Even head sharding for MHA. + + Args: + n_heads (int): The number QKV heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads) + + n_local_heads = n_heads // n_shards + sharded_shape = (HEAD_SIZE * n_heads, HEAD_SIZE * n_local_heads) + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + assert sharded_param.shape == sharded_shape + + heads = torch.chunk(sharded_param, n_local_heads, dim=1) + + for i, head in enumerate(heads): + assert torch.all(head == i + shard_rank * n_local_heads) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads, n_shards", [(3, 2), (20, 8)]) +def test_mha_unbalanced_sharding(n_heads: int, n_shards: int): + """ + Unbalanced head sharding for MHA. + + Args: + n_heads (int): The number QKV heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads) + + max_heads = 0 + min_heads = n_heads + seen_heads = set() + total_heads = 0 + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + + n_local_heads = sharded_param.shape[1] // HEAD_SIZE + total_heads += n_local_heads + max_heads = max(max_heads, n_local_heads) + min_heads = min(min_heads, n_local_heads) + + for i in range(n_local_heads): + head_ids = torch.unique_consecutive(sharded_param[:, i * HEAD_SIZE:(i + 1) * HEAD_SIZE]) + assert len(head_ids) == 1 + seen_heads.add(head_ids.item()) + + assert max_heads == min_heads + 1 + assert total_heads == n_heads + assert len(seen_heads) == n_heads + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(20, 4, 8)]) +def test_gqa_uneven_sharding(n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + We only test the uneven GQA test case because even GQA shards the attention output + in the exact same manner as MHA. + + Args: + n_heads_q (int): The number of query heads. + n_heads_kv (int): The number of key/value heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads_q) + + min_heads = n_heads_q + max_heads = 0 + seen_heads = set() + total_heads = 0 + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE, n_heads_q, n_heads_kv) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + + n_local_heads = sharded_param.shape[1] // HEAD_SIZE + total_heads += n_local_heads + max_heads = max(max_heads, n_local_heads) + min_heads = min(min_heads, n_local_heads) + + for i in range(n_local_heads): + head_id = torch.unique_consecutive(sharded_param[:, i * HEAD_SIZE:(i + 1) * HEAD_SIZE]) + assert len(head_id) == 1 + seen_heads.add(head_id.item()) + + assert max_heads == min_heads + 1 + assert total_heads == n_heads_q + assert len(seen_heads) == n_heads_q diff --git a/tests/unit/inference/v2/model_implementations/sharding/test_mlp_sharding.py b/tests/unit/inference/v2/model_implementations/sharding/test_mlp_sharding.py new file mode 100644 index 000000000000..aac7e5391d8f --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/test_mlp_sharding.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + + +def round_up_to_256(x: int) -> int: + """ + Round up to the nearest multiple of 256. + """ + return x + (256 - x % 256) + + +def make_params(model_dim: int, ffn_multiplier: int, n_experts: int, gated: bool = False) -> torch.Tensor: + """ + + """ + if gated: + mlp_1_intermediate = round_up_to_256(int(model_dim * ffn_multiplier * 4 / 3)) + mlp_2_intermediate = mlp_1_intermediate // 2 + else: + mlp_1_intermediate = ffn_multiplier * model_dim + mlp_2_intermediate = ffn_multiplier * model_dim + + mlp_1_shared_dim = torch.arange(mlp_1_intermediate, dtype=torch.float32, device=get_accelerator().current_device()) + + mlp_1_w = mlp_1_shared_dim.repeat_interleave(model_dim).reshape(mlp_1_intermediate, model_dim) + mlp_1_b = mlp_1_shared_dim + + mlp_2_shared_dim = torch.arange(mlp_2_intermediate, dtype=torch.float32, device=get_accelerator().current_device()) + mlp_2_w = mlp_2_shared_dim.repeat(model_dim).reshape(model_dim, mlp_2_intermediate) + mlp_2_b = torch.ones(model_dim, dtype=torch.float32, device=get_accelerator().current_device()) + + if n_experts > 1: + mlp_1_w = mlp_1_w.expand(n_experts, -1, -1) + mlp_1_b = mlp_1_b.expand(n_experts, -1) + mlp_2_w = mlp_2_w.expand(n_experts, -1, -1) + mlp_2_b = mlp_2_b.expand(n_experts, -1) + + return (mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("model_dim, ffn_multiplier, n_shards", [(1024, 4, 1), (1024, 4, 8), (1024, 4, 6)]) +@pytest.mark.parametrize("n_experts", [1, 16]) +def test_even_ffn_sharding(model_dim: int, ffn_multiplier: int, n_shards: int, n_experts: int): + """ + FFN sharding tends to be much simpler than attention sharding since it works on larger granularities. + While the test case of (1024, 4, 6) is not a use case we're likely to see, this does ensure that + the sharding logic will round correctly for the alignments we care about. + """ + mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b = make_params(model_dim, ffn_multiplier, n_experts) + + total_ffn_dim = model_dim * ffn_multiplier + mapped_neurons = 0 + + is_moe = n_experts > 1 + + for shard_rank in range(n_shards): + shard_1_w = shard_mlp_1_param(mlp_1_w, shard_rank, n_shards, is_moe=is_moe) + shard_1_b = shard_mlp_1_param(mlp_1_b, shard_rank, n_shards, is_moe=is_moe) + shard_2_w = shard_mlp_2_param(mlp_2_w, shard_rank, n_shards, is_moe=is_moe) + shard_2_b = shard_mlp_2_param(mlp_2_b, shard_rank, n_shards, is_moe=is_moe) + + assert shard_1_w.shape[-2] == shard_2_w.shape[-1] + assert shard_1_w.shape[-2] % DEFAULT_SHARD_GRANULARITY == 0 + assert shard_1_w.shape[-2] == shard_1_b.shape[-1] + + mapped_neurons += shard_1_w.shape[-2] + + if shard_rank != 0: + assert shard_2_b is None + else: + assert shard_2_b.shape[-1] == model_dim + + assert mapped_neurons == total_ffn_dim + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("model_dim, ffn_multiplier, n_shards", [(1024, 4, 1), (1024, 4, 8), (1024, 4, 6)]) +@pytest.mark.parametrize("n_experts", [1, 16]) +def test_gated_ffn_sharding(model_dim: int, ffn_multiplier: int, n_shards: int, n_experts: int): + """ + Test the same cases assuming a gated regime. + """ + mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b = make_params(model_dim, ffn_multiplier, n_experts, gated=True) + + total_ffn_dim = round_up_to_256(int(model_dim * ffn_multiplier * 4 / 3)) + mapped_neurons = 0 + + is_moe = n_experts > 1 + + for shard_rank in range(n_shards): + shard_1_w = shard_mlp_1_param(mlp_1_w, shard_rank, n_shards, gated=True, is_moe=is_moe) + shard_1_b = shard_mlp_1_param(mlp_1_b, shard_rank, n_shards, gated=True, is_moe=is_moe) + shard_2_w = shard_mlp_2_param(mlp_2_w, shard_rank, n_shards, is_moe=is_moe) + shard_2_b = shard_mlp_2_param(mlp_2_b, shard_rank, n_shards, is_moe=is_moe) + + assert shard_1_w.shape[-2] == shard_2_w.shape[-1] * 2 + assert shard_1_w.shape[-2] % DEFAULT_SHARD_GRANULARITY == 0 + assert shard_1_w.shape[-2] == shard_1_b.shape[-1] + + mapped_neurons += shard_1_w.shape[-2] + + if shard_rank != 0: + assert shard_2_b is None + else: + assert shard_2_b.shape[-1] == model_dim + + assert mapped_neurons == total_ffn_dim diff --git a/tests/unit/inference/v2/model_implementations/sharding/test_qkv_sharding.py b/tests/unit/inference/v2/model_implementations/sharding/test_qkv_sharding.py new file mode 100644 index 000000000000..9a1cb9c09c64 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/test_qkv_sharding.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + + +def fill_with_head_ids(head_size: int, n_heads_q: int, n_heads_kv: Optional[int] = None) -> torch.Tensor: + """ + + """ + head_ids_q = torch.arange(n_heads_q, dtype=torch.half, device=get_accelerator().current_device()) + head_vals_q = head_ids_q.repeat_interleave(head_size * head_size * n_heads_q).reshape(n_heads_q * head_size, -1) + + if n_heads_kv is None: + return torch.cat([head_vals_q, head_vals_q, head_vals_q], dim=0) + + head_ids_k = torch.arange(n_heads_kv, dtype=torch.half, device=get_accelerator().current_device()) + head_vals_k = head_ids_k.repeat_interleave(head_size * head_size * n_heads_q).reshape(n_heads_kv * head_size, -1) + + return torch.cat([head_vals_q, head_vals_k, head_vals_k], dim=0) + + +def validate_inferred_shape(shard: torch.Tensor, head_size: int, n_local_q_heads: int, n_local_kv_heads: int): + """ + Validate that the leading dim of the shard is of the expected size and aligns with the sharding + logic for the attention computation itself. + """ + inferred_leading_dim = head_size * (n_local_q_heads + 2 * n_local_kv_heads) + assert shard.shape[0] == inferred_leading_dim + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads,n_shards", [(1, 1), (32, 1), (32, 8)]) +def test_even_mha_sharding(head_size: int, n_heads: int, n_shards: int): + """ + Test for MHA sharding. In these scenarios, we expect that each of the shards + should be the same size. + """ + param = fill_with_head_ids(head_size, n_heads) + + heads_per_shard = n_heads // n_shards + + for shard_rank in range(n_shards): + + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads, n_heads) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + assert shard.shape == (3 * head_size * heads_per_shard, head_size * n_heads) + + heads = shard.chunk(heads_per_shard * 3, dim=0) + for i in range(heads_per_shard): + assert torch.all(heads[i] == i + shard_rank * heads_per_shard) + assert torch.all(heads[i + heads_per_shard] == i + shard_rank * heads_per_shard) + assert torch.all(heads[i + heads_per_shard * 2] == i + shard_rank * heads_per_shard) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads, n_shards", [(3, 2), (20, 8)]) +def test_unbalanced_mha_sharding(head_size: int, n_heads: int, n_shards: int): + """ + Test MHA sharding when the distribution of heads will not be equal across all ranks. + """ + param = fill_with_head_ids(head_size, n_heads) + + max_heads = 0 + min_heads = n_heads + total_heads = 0 + seen_heads = set() + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads, n_heads) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + n_heads_in_shard = shard.shape[0] // head_size // 3 + + max_heads = max(max_heads, n_heads_in_shard) + min_heads = min(min_heads, n_heads_in_shard) + total_heads += n_heads_in_shard + + heads = shard.chunk(n_heads_in_shard * 3, dim=0) + + for local_head_id in range(n_heads_in_shard): + head_qkv = torch.cat([ + heads[local_head_id], heads[local_head_id + n_heads_in_shard], + heads[local_head_id + 2 * n_heads_in_shard] + ], + dim=0) + assert head_qkv.shape == (3 * head_size, head_size * n_heads) + + global_head_id = torch.unique_consecutive(head_qkv) + assert len(global_head_id) == 1 + + seen_heads.add(global_head_id.item()) + + assert max_heads - min_heads <= 1 + assert total_heads == n_heads + assert len(seen_heads) == n_heads + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(4, 2, 1), (8, 2, 1), (64, 16, 8)]) +def test_gqa_even_sharding(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + Test GQA sharding when the KV heads are evenly divisible by the number of shards. + """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + n_kv_heads_in_shard = n_heads_kv // n_shards + n_q_heads_in_shard = n_heads_q // n_shards + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + assert shard.shape[0] == (n_q_heads_in_shard + n_kv_heads_in_shard * 2) * head_size + + q = shard[:n_q_heads_in_shard * head_size] + k = shard[n_q_heads_in_shard * head_size:(n_q_heads_in_shard + n_kv_heads_in_shard) * head_size] + v = shard[(n_q_heads_in_shard + n_kv_heads_in_shard) * head_size:] + + for local_head_id in range(n_q_heads_in_shard): + assert torch.all(q[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_q_heads_in_shard) + + for local_head_id in range(n_kv_heads_in_shard): + assert torch.all(k[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_kv_heads_in_shard) + assert torch.all(v[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_kv_heads_in_shard) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(4, 2, 4), (20, 4, 8)]) +def test_gqa_uneven_sharding(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + Test GQA sharding when there are more shards than KV heads. + """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + n_kv_heads_in_shard = 1 + n_shards_per_kv_head = n_shards // n_heads_kv + + max_heads = 0 + min_heads = n_heads_q + total_heads = 0 + seen_heads = set() + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + local_n_heads_q = (shard.shape[0] - 2 * n_kv_heads_in_shard * head_size) // head_size + + max_heads = max(max_heads, local_n_heads_q) + min_heads = min(min_heads, local_n_heads_q) + total_heads += local_n_heads_q + + q = shard[:local_n_heads_q * head_size] + kv = shard[local_n_heads_q * head_size:] + + for local_head_id in range(local_n_heads_q): + q_head_id = torch.unique_consecutive(q[local_head_id * head_size:(local_head_id + 1) * head_size]) + assert len(q_head_id) == 1 + + seen_heads.add(q_head_id.item()) + + kv_id_calc = shard_rank // n_shards_per_kv_head + kv_id = torch.unique_consecutive(kv) + assert len(kv_id) == 1 + assert kv_id.item() == kv_id_calc + + assert max_heads - min_heads <= 1 + assert total_heads == n_heads_q + assert len(seen_heads) == n_heads_q + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads, n_shards", [(6, 8)]) +def test_unsupported_mha_configs(head_size: int, n_heads: int, n_shards: int): + """ + Sharding should fail if there are fewer heads than shards. + + TODO(cmikeh2): Look to support this configuration. + """ + param = fill_with_head_ids(head_size, n_heads) + + for shard_rank in range(n_shards): + with pytest.raises(ValueError): + shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(5, 2, 1), (40, 10, 8), (30, 5, 8)]) +def test_unsupported_gqa_configs(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + GQA has stricter requirements. We must be able to evenly shard or distribute the KV heads. + + Test cases are to test the following preconditions specifically: + 1. n_heads_q % n_heads_kv == 0 + 2. We must be able to evenly distribute KV heads + 3. We must be able to evely split KV heads + """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + for shard_rank in range(n_shards): + with pytest.raises(ValueError): + shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + + +@pytest.mark.inference_v2 +def test_mha_input_shape_error(): + + param = torch.empty(256, 128) + + n_heads = 2 + head_size = 64 + + with pytest.raises(ValueError): + shard_qkv_param(param, 0, 1, 64) + + +@pytest.mark.inference_v2 +def test_gqa_input_shape_error(): + + head_size = 64 + n_heads_q = 16 + n_heads_kv = 4 + + # Correct shape is 1536 (=16 * 64 + 2 * 4 * 64), 1024 + param = torch.empty(2048, 1024) + + with pytest.raises(ValueError): + shard_qkv_param(param, 0, 1, head_size, n_heads_q, n_heads_kv) diff --git a/tests/unit/inference/v2/modules/__init__.py b/tests/unit/inference/v2/modules/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/modules/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/modules/test_blas_linear_module.py b/tests/unit/inference/v2/modules/test_blas_linear_module.py new file mode 100644 index 000000000000..f4d0b1991238 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_blas_linear_module.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSLinearConfig +from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry +from ...v2.inference_test_utils import allclose + + +def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + out_states = torch.nn.functional.linear(hidden_states, weight, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _blas_linear_helper(tokens: int, + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, + use_bias: bool = True) -> None: + linear_config = DSLinearConfig(max_tokens=2048, + in_channels=in_channels, + out_channels=out_channels, + activation=act_fn, + input_dtype=dtype, + output_dtype=dtype) + + bundle = ConfigBundle(name='blas_fp_linear', config=linear_config) + + module = DSLinearRegistry.instantiate_config(bundle) + + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + + weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if use_bias: + bias = torch.randn( + (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + else: + bias = None + + # Reference output + ref_output = reference_implementation(hidden_states, weight, bias, act_fn) + + # New output + ds_output = module(hidden_states, weight, bias) + + # Check + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, in_channels, out_channels", [(1, 4608, 1728), (37, 8192, 4096), (1280, 3072, 6144)]) +def test_blas_linear_shapes(tokens: int, in_channels: int, out_channels: int) -> None: + + _blas_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, ActivationType.IDENTITY) + + +all_acts = [ + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_blas_linear_act_fn(act_fn: ActivationType, use_bias: bool) -> None: + + _blas_linear_helper(283, 512, 4096, DtypeEnum.fp16, act_fn, use_bias=use_bias) diff --git a/tests/unit/inference/v2/modules/test_blocked_attn.py b/tests/unit/inference/v2/modules/test_blocked_attn.py new file mode 100644 index 000000000000..6556aa460a44 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_blocked_attn.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import itertools + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType, RotateHalfConfig +from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase + +from ..kernels.ragged_ops.ragged_testing_utils import build_batch_and_manager +from ...v2.inference_test_utils import allclose + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + validate_accuracy = True +except ImportError: + validate_accuracy = False + + +def _blocked_flash_testing_helper(head_size: int, + n_heads_q: int, + n_heads_kv: int, + seq_params: List[Tuple[int, int]], + trained_freqs: bool = None) -> None: + """ + Helper function for testing blocked flash attention. This implementation is based on + the implemnentation in ``unit.inference.kernels.ragged_ops.test_blocked_flash`` but + integrates functionality to validate the composability. + """ + if trained_freqs is None: + embed_type = PositionalEmbeddingType.none + embed_args = None + else: + embed_type = PositionalEmbeddingType.rotate_half + embed_args = RotateHalfConfig(use_trained_freqs=trained_freqs) + + attn_config = DSSelfAttentionConfig(max_tokens=2048, + n_heads_q=n_heads_q, + n_heads_kv=n_heads_kv, + head_size=head_size, + max_sequences=32, + positional_embedding_type=embed_type, + positional_embedding_config=embed_args) + + config = ConfigBundle(name='dense_blocked_attention', config=attn_config) + attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config) + + kv_block_size = attn_module.kv_block_size + + kvs = [] + for _, history_len in seq_params: + if history_len > 0: + kvs.append( + torch.randn((history_len, 2 * n_heads_kv * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16)) + else: + kvs.append(None) + + batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + attn_module.build_atoms(batch) + if not trained_freqs: + out = attn_module(qkv, kv_cache, batch) + else: + inv_freqs = torch.randn((head_size // 2, ), device=get_accelerator().current_device(), dtype=torch.float16) + out = attn_module(qkv, kv_cache, batch, inv_freqs) + + if validate_accuracy and trained_freqs is None: + cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + + inflight_kv = qkv[:, head_size * n_heads_q:] + full_kvs = [] + for i, kv in enumerate(kvs): + if kv is not None: + full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0)) + else: + full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) + run_kvs = torch.cat(full_kvs, dim=0) + k = run_kvs[:, :head_size * n_heads_kv] + v = run_kvs[:, head_size * n_heads_kv:] + + q = qkv[:, :head_size * n_heads_q] + q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size)) + k_ref = k.reshape((k.shape[0], n_heads_kv, head_size)) + v_ref = v.reshape((v.shape[0], n_heads_kv, head_size)) + + max_seqlen_q = max([seq[0] for seq in seq_params]) + max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params]) + + ref_o = flash_attn_varlen_func(q_ref, + k_ref, + v_ref, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + softmax_scale=1.0, + causal=True) + + ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q) + + assert allclose(out, ref_o) + + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens", [2, 33, 65, 128, 256, 2037]) +def test_single_prompt(n_tokens: int) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(n_tokens, 0)] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("prompt_lengths", [(128, 128), (192, 38), (514, 713), (83, 312, 610)]) +def test_multiple_prompts(prompt_lengths: Tuple[int, int]) -> None: + """ + Test multiple prompts in a single batch. + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(prompt_lengths[i], 0) for i in range(len(prompt_lengths))] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_params", [(1, 34), (43, 40), (1, 144), (64, 128), (332, 628)]) +def test_continuation(seq_params: Tuple[int, int]) -> None: + """ + Test continued generation/prompt processing. + """ + head_size = 64 + n_heads_q = 32 + n_heads_kv = 32 + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, [seq_params]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 128]) +def test_head_size(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_config", [(32, 8), (64, 16), (40, 8)]) +def test_gqa(head_config: Tuple[int, int]) -> None: + head_size = 128 + n_heads_q = head_config[0] + n_heads_kv = head_config[1] + + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +def test_fully_composed() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("trained_freqs", [True, False]) +def test_rotary_emb(trained_freqs: bool) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params, trained_freqs=trained_freqs) diff --git a/tests/unit/inference/v2/modules/test_cuda_pre_ln_module.py b/tests/unit/inference/v2/modules/test_cuda_pre_ln_module.py new file mode 100644 index 000000000000..386f3b3ef0b3 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_cuda_pre_ln_module.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPreNormRegistry +from ...v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: Optional[torch.Tensor], gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = residual.dtype + + residual = residual.to(torch.float32) + gamma = gamma.to(torch.float32) + beta = beta.to(torch.float32) + + if hidden_states is not None: + hidden_states = hidden_states.to(torch.float32) + residual = residual + hidden_states + hidden_states = torch.nn.functional.layer_norm(residual, (residual.size(-1), ), + weight=gamma, + bias=beta, + eps=epsilon) + return residual.to(dtype), hidden_states.to(dtype) + + +def _pre_ln_test_helper(n_tokens: int, n_channels: int, dtype: torch.dtype, res_add: bool = False): + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=n_channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_pre_ln', config=config) + + # Input vals + if res_add: + hidden_states = torch.randn((n_tokens, n_channels), + dtype=dtype, + device=get_accelerator().current_device_name()) + else: + hidden_states = None + + residual = torch.randn((n_tokens, n_channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_residual, ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + pre_ln_module = DSPreNormRegistry.instantiate_config(bundle) + gamma = pre_ln_module.transform_param(gamma) + beta = pre_ln_module.transform_param(beta) + + ds_residual, ds_output = pre_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_residual, ref_residual) + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +def test_token_channels(tokens: int, channels: int) -> None: + _pre_ln_test_helper(tokens, channels, torch.float16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtype(dtype: torch.dtype) -> None: + _pre_ln_test_helper(733, 2560, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_res_add(): + _pre_ln_test_helper(733, 2560, torch.float16, res_add=False) diff --git a/tests/unit/inference/v2/modules/test_custom_module.py b/tests/unit/inference/v2/modules/test_custom_module.py new file mode 100644 index 000000000000..eb54b7a913f2 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_custom_module.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.implementations import cuda_post_ln +from ...v2.inference_test_utils import allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@DSPostNormRegistry.register_module +class CustomPostLNModule(cuda_post_ln.DSPostLNCUDAModule): + + @staticmethod + def name(): + return 'custom_post_ln' + + +""" +Here, we explicitly register an LN implementation outside the core deepspeed repo. This should +validate that the registry is working as expected and we can implement modules outside the core +repo. +""" + + +@pytest.mark.inference_v2_ops +def test_custom_registration(): + channels = 4096 + dtype = torch.float16 + tokens = 1024 + + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='custom_post_ln', config=config) + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_module = DSPostNormRegistry.instantiate_config(bundle) + gamma = post_ln_module.transform_param(gamma) + beta = post_ln_module.transform_param(beta) + ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/modules/test_cutlass_moe.py b/tests/unit/inference/v2/modules/test_cutlass_moe.py new file mode 100644 index 000000000000..b14ba127c6be --- /dev/null +++ b/tests/unit/inference/v2/modules/test_cutlass_moe.py @@ -0,0 +1,328 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSMoEConfig +from deepspeed.inference.v2.modules.interfaces import DSMoERegistry + +from ..kernels.ragged_ops.ragged_testing_utils import build_simple_batch +from ...v2.inference_test_utils import allclose, get_dtypes + + +def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Reference gating code. + """ + logits = logits.float() + probs = torch.nn.functional.softmax(logits, dim=1) + + indices1_s = torch.argmax(probs, dim=-1) + mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1]) + indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + gates1_s = (probs * mask1).sum(dim=1) + + sorted_indices = indices1_s.sort()[1] + original_indices = sorted_indices.sort()[1] + + exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long() + exp_count_cumsum = exp_count.cumsum(dim=0) + + return sorted_indices, original_indices, exp_count_cumsum, gates1_s + + +def _reference_impl(hidden_states: torch.Tensor, gate_weight: torch.Tensor, mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, mlp_1_b: torch.Tensor, mlp_2_b: torch.Tensor, + act_fn: ActivationType) -> torch.Tensor: + """ + Reference implementation of the MoE module. + """ + + act_fn_dict = { + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + logits = torch.matmul(hidden_states, gate_weight.t()) + sorted_indices, original_indices, exp_count_cumsum, gate_scales = _gating_reference(logits) + + moe_input = hidden_states[sorted_indices] + + output_unordered = torch.empty_like(hidden_states) + + for expert_idx in range(mlp_1_w.shape[0]): + min_bound = 0 if expert_idx == 0 else exp_count_cumsum[expert_idx - 1] + max_bound = exp_count_cumsum[expert_idx] + + input_slice = moe_input[min_bound:max_bound] + intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx]) + + intermediate = act_fn_dict[act_fn](intermediate) + output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx]) + + output_unordered[min_bound:max_bound] = output_slice + + output = output_unordered[original_indices] + + output.mul_(gate_scales.unsqueeze(-1)).reshape(hidden_states.shape) + return output + + +def _cutlass_moe_testing_helper(tokens: int, + in_channels: int, + intermediate_dim: int, + experts: int, + dtype: int, + activation_type: ActivationType = ActivationType.GELU, + use_bias: bool = True, + iters: int = 1) -> None: + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=activation_type, + input_dtype=dtype, + output_dtype=dtype) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([tokens]) + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_1_w = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_2_w = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + if use_bias: + mlp_1_b = torch.randn( + (experts, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_2_b = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + else: + mlp_1_b = None + mlp_2_b = None + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_1_w_ds = moe_module.transform_moe_mlp_1_param(mlp_1_w) + mlp_1_b_ds = moe_module.transform_moe_mlp_1_param(mlp_1_b) + mlp_2_w_ds = moe_module.transform_moe_mlp_2_param(mlp_2_w) + mlp_2_b_ds = moe_module.transform_moe_mlp_2_param(mlp_2_b) + + for _ in range(iters): + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + # Reference implementation + ref_output = _reference_impl(hidden_states, gate_weight, mlp_1_w, mlp_2_w, mlp_1_b, mlp_2_b, activation_type) + + output = moe_module(hidden_states, + batch, + gate_ds, + mlp_1_w_ds, + mlp_2_w_ds, + mlp_1_b=mlp_1_b_ds, + mlp_2_b=mlp_2_b_ds) + + # Increase the tolerance for larger meta ops since the error is additive + assert allclose(output, ref_output, tolerances=(1e-2, 1e-2)) + + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("experts", [2, 32, 64]) +def test_expert_variance(experts: int) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=experts, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +@pytest.mark.inference_v2_ops +def test_successive_inputs(): + """ + The CUTLASS MoE uses persistent state (expert counts) that is assumed to be cleared + on each forward pass. This ensures that the module is clearing that metadata. + """ + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True, + iters=10) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtypes(dtype: torch.dtype) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum(dtype), + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("activation_type", [ActivationType.GELU, ActivationType.RELU, ActivationType.SILU]) +def test_activation_types(activation_type: ActivationType) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=activation_type, + use_bias=True) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("in_channels, out_channels", [(4096, 2048), (2048, 8192), (6144, 3072)]) +def test_in_out_channels(in_channels: int, out_channels: int) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=in_channels, + intermediate_dim=out_channels, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +def _mixtral_moe_baseline(hidden_states: torch.Tensor, + gate_weight: torch.Tensor, + mlp_w1: torch.Tensor, + mlp_w2: torch.Tensor, + mlp_w3: torch.Tensor, + force_float: bool = False) -> torch.Tensor: + """ + Baseline implementation for mixtral MoE module. + + Based on transformers implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + """ + output_dtype = hidden_states.dtype + if force_float: + hidden_states = hidden_states.float() + gate_weight = gate_weight.float() + mlp_w1 = mlp_w1.float() + mlp_w2 = mlp_w2.float() + mlp_w3 = mlp_w3.float() + + router_logits = torch.nn.functional.linear(hidden_states, gate_weight) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = routing_weights.topk(k=2, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # NOTE(cmikeh2): This is a difference implementation, ours will preserve the original scale + # as float32 and perform in-kernel fused FP16->FP32->FP16 conversion. + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros_like(hidden_states) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=gate_weight.shape[0]).permute(2, 1, 0) + get_accelerator().synchronize() + + for expert_idx in range(gate_weight.shape[0]): + exp_mlp_w1 = mlp_w1[expert_idx] + exp_mlp_w2 = mlp_w2[expert_idx] + exp_mlp_w3 = mlp_w3[expert_idx] + + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[top_x_list] + + linear = torch.nn.functional.linear + intermediate = torch.nn.functional.silu(linear(current_state, exp_mlp_w1)) * linear(current_state, exp_mlp_w3) + output = linear(intermediate, exp_mlp_w2) * routing_weights[top_x_list, idx_list].unsqueeze(-1) + final_hidden_states.index_add_(0, top_x, output.to(final_hidden_states.dtype)) + + return final_hidden_states.to(output_dtype) + + +@pytest.mark.inference_v2_ops +def test_mixtral_moe_config(): + + experts = 8 + n_top_k = 2 + in_channels = 4096 + intermediate_dim = 2048 + dtype = DtypeEnum.bf16 + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_w1 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w3 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w2 = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + n_tokens = 256 + hidden_states = torch.randn( + (n_tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + baseline = _mixtral_moe_baseline(hidden_states, gate_weight, mlp_w1, mlp_w2, mlp_w3) + + mlp_w13_fused = torch.cat([mlp_w1, mlp_w3], dim=-1).reshape(experts, 2 * intermediate_dim, in_channels) + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=ActivationType.SiGLU, + input_dtype=dtype, + output_dtype=dtype, + top_k=n_top_k, + normalize_scores=True) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([n_tokens]) + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_w1_ds = moe_module.transform_moe_mlp_1_param(mlp_w13_fused) + mlp_w2_ds = moe_module.transform_moe_mlp_2_param(mlp_w2) + + output = moe_module(hidden_states, batch, gate_ds, mlp_w1_ds, mlp_w2_ds) + + # NOTE(cmikeh2): These are higher than the other tests for reasons that aren't quite + # clear to me. My best guess is that the SiGLU activation is causing larger numerical + # divergence. The thresholds chosen here is based on the observed error between the + # float and bfloat16 reference implementations. + assert allclose(output, baseline.to(dtype.value), tolerances=(5e-2, 5e-2)) diff --git a/tests/unit/inference/v2/modules/test_post_ln_module.py b/tests/unit/inference/v2/modules/test_post_ln_module.py new file mode 100644 index 000000000000..f9dcfd272170 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_post_ln_module.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry +from ...v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_post_ln_module(tokens: int, channels: int, dtype: torch.dtype) -> None: + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_post_ln', config=config) + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_module = DSPostNormRegistry.instantiate_config(bundle) + gamma = post_ln_module.transform_param(gamma) + beta = post_ln_module.transform_param(beta) + ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/modules/test_pre_rms_module.py b/tests/unit/inference/v2/modules/test_pre_rms_module.py new file mode 100644 index 000000000000..bbd108a35a5a --- /dev/null +++ b/tests/unit/inference/v2/modules/test_pre_rms_module.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPreNormRegistry +from ...v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: Optional[torch.Tensor], gamma: torch.Tensor, + epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = residual.dtype + + if hidden_states is not None: + hidden_states = hidden_states + residual = residual + hidden_states + + rms_vals = residual.to(torch.float32) + variance = rms_vals.pow(2).mean(-1, keepdim=True) + rms_vals = rms_vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + rms_vals = rms_vals.to(gamma.dtype) + + hidden_states = gamma * rms_vals + + return residual.to(dtype), hidden_states.to(dtype) + + +def _pre_rms_test_helper(n_tokens: int, n_channels: int, dtype: torch.dtype, res_add: bool = False): + config = DSNormConfig(max_tokens=2048, + type="rms_norm", + channels=n_channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_pre_rms', config=config) + + # Input vals + if res_add: + hidden_states = torch.randn((n_tokens, n_channels), + dtype=dtype, + device=get_accelerator().current_device_name()) + else: + hidden_states = None + + residual = torch.randn((n_tokens, n_channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_residual, ref_output = reference_implementation(residual, hidden_states, gamma, epsilon) + + # New output + pre_ln_module = DSPreNormRegistry.instantiate_config(bundle) + gamma = pre_ln_module.transform_param(gamma) + + ds_residual, ds_output = pre_ln_module(residual, hidden_states, gamma) + + # Check + assert allclose(ds_residual, ref_residual) + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +def test_token_channels(tokens: int, channels: int) -> None: + _pre_rms_test_helper(tokens, channels, torch.float16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtype(dtype: torch.dtype) -> None: + _pre_rms_test_helper(733, 2560, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_res_add(): + _pre_rms_test_helper(733, 2560, torch.float16, res_add=False) diff --git a/tests/unit/inference/v2/modules/test_quantized_linear_module.py b/tests/unit/inference/v2/modules/test_quantized_linear_module.py new file mode 100644 index 000000000000..050f21c3bf3a --- /dev/null +++ b/tests/unit/inference/v2/modules/test_quantized_linear_module.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSLinearConfig +from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry +from ...v2.inference_test_utils import allclose + + +def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + out_states = torch.nn.functional.linear(hidden_states, weight, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _fp6_quant_dequant_weights(weight: torch.Tensor) -> torch.Tensor: + from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize + weight_quantized_fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3) + return weight_quantized_fake_fp6 * scales + + +def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + weight_dequantized = _fp6_quant_dequant_weights(weight) + out_states = torch.nn.functional.linear(hidden_states, weight_dequantized, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _fp6_quantized_linear_helper(tokens: int, + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, + use_bias: bool = True, + expect_failure: bool = False) -> None: + # The current FP6 kernel only supports NVIDIA Ampere GPUs. + if not 'cuda' in get_accelerator().current_device_name(): + return + major, _ = torch.cuda.get_device_capability() #ignore-cuda + if major != 8: + return + + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + + weight_out_channels = 2 * \ + out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if use_bias: + bias = torch.randn( + (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + else: + bias = None + + # quantize and dequantize output + ref_quant_dequant_output = quant_dequant_implementation(hidden_states, weight, bias, act_fn) + + linear_config = DSLinearConfig(max_tokens=2048, + in_channels=in_channels, + out_channels=out_channels, + activation=act_fn, + input_dtype=dtype, + output_dtype=dtype) + bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) + fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) + weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) + + if expect_failure: + with pytest.raises(ValueError) as excinfo: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + assert "The out and in channel should be multiple of 256 and 64 respectively." in str(excinfo.value) + else: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + # The current FP6 kernel uses FP16 Tensor Core. + tolerances = (3e-2, 2e-3) # tolerances for fp16 + + # Check DeepSpeed implementation + assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) + + +all_acts = [ + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, +] +all_tokens = [37] +all_in_out_channels = [ + (4096, 4096), +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens", all_tokens) +@pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, + use_bias: bool) -> None: + _fp6_quantized_linear_helper(tokens=tokens, + in_channels=in_channels, + out_channels=out_channels, + dtype=DtypeEnum.fp16, + act_fn=act_fn, + use_bias=use_bias) + + +# Other shapes, not supported by FP6 kernels. Will raise ValueError. +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens", all_tokens) +@pytest.mark.parametrize("in_channels, out_channels", [(4608, 1728)]) +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_linear_act_fn_fail(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, + use_bias: bool) -> None: + _fp6_quantized_linear_helper(tokens=tokens, + in_channels=in_channels, + out_channels=out_channels, + dtype=DtypeEnum.fp16, + act_fn=act_fn, + use_bias=use_bias, + expect_failure=True) diff --git a/tests/unit/inference/v2/ragged/__init__.py b/tests/unit/inference/v2/ragged/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/ragged/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/ragged/test_blocked_allocator.py b/tests/unit/inference/v2/ragged/test_blocked_allocator.py new file mode 100644 index 000000000000..4596e81c5652 --- /dev/null +++ b/tests/unit/inference/v2/ragged/test_blocked_allocator.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +from typing import List + +import pytest +import torch + +from deepspeed.inference.v2.ragged.blocked_allocator import BlockedAllocator + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('bad_size', [0, -1]) +def test_bad_initialization(bad_size: int) -> None: + with pytest.raises(ValueError): + BlockedAllocator(bad_size) + + +@pytest.mark.inference_v2 +def test_allocation() -> None: + + allocator = BlockedAllocator(16) + + a1 = allocator.allocate(4) + assert a1.numel() == 4 + assert allocator.free_blocks == 12 + + a2_allocs = [] + for i in range(3): + a2_allocs.append(allocator.allocate(2)) + assert allocator.free_blocks == 12 - (i + 1) * 2 + + a3 = allocator.allocate(6) + assert a3.numel() == 6 + + assert allocator.free_blocks == 0 + + # Test that we can't allocate more blocks than we have. + with pytest.raises(ValueError): + allocator.allocate(1) + + all_vals = torch.cat([a1, *a2_allocs, a3], dim=0) + unique_vals = torch.unique(all_vals, sorted=False) + assert unique_vals.numel() == all_vals.numel() + + +@pytest.mark.inference_v2 +def test_too_large_allocation(): + allocator = BlockedAllocator(16) + + with pytest.raises(ValueError): + allocator.allocate(17) + + +@pytest.mark.inference_v2 +def test_deallocation() -> None: + allocator = BlockedAllocator(16) + + # Allocate + all_blocks = allocator.allocate(16) + assert allocator.free_blocks == 0 + + # Deallocate all blocks + allocator.free(all_blocks) + assert allocator.free_blocks == 16 + + # Get all the blocks again + all_blocks = allocator.allocate(16) + + # Deallocate in chunks + c1 = all_blocks[:4] + c2 = all_blocks[4:8] + + allocator.free(c1) + assert allocator.free_blocks == 4 + + allocator.free(c2) + assert allocator.free_blocks == 8 + + with pytest.raises(ValueError): + allocator.free(c1) + + with pytest.raises(ValueError): + allocator.free(c2) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('index', [-1, 2]) +def test_invalid_dealloc_indices(index: int): + allocator = BlockedAllocator(1) + + with pytest.raises(ValueError): + allocator.free(torch.tensor([index])) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('index', [-1, 2]) +def test_invalid_alloc_indices(index: int): + allocator = BlockedAllocator(1) + allocator.allocate(1) + + to_free = [0, index] + + with pytest.raises(ValueError): + allocator.free(torch.tensor(to_free)) + + # Block 0 should not be freed if passed with an invalid index. + assert allocator.free_blocks == 0 + + allocator.free(torch.tensor([0])) + assert allocator.free_blocks == 1 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('test_iters', [8192]) +def test_long_running_allocation(test_iters: int) -> None: + """ + Evaluate the stability of the allocator over a longer sequence of allocations/deallocations. + """ + TOTAL_BLOCKS = 128 + + allocator = BlockedAllocator(TOTAL_BLOCKS) + + def validate_uniqueness(all_blocks: List[torch.Tensor]) -> None: + all_vals = torch.cat(all_blocks, dim=0) + assert all_vals.numel() <= TOTAL_BLOCKS + + unique_vals = torch.unique(all_vals, sorted=False) + assert unique_vals.numel() == all_vals.numel() + + all_allocs: List[torch.Tensor] = [] + num_allocs = 0 + num_frees = 0 + num_blocks_allocated = 0 + num_blocks_freed = 0 + + for _ in range(test_iters): + decision = random.randint(0, 1) + + if decision == 0: + blocks_to_allocate = random.randint(1, 24) + if blocks_to_allocate > allocator.free_blocks: + with pytest.raises(ValueError): + allocator.allocate(blocks_to_allocate) + else: + all_allocs.append(allocator.allocate(blocks_to_allocate)) + num_allocs += 1 + num_blocks_allocated += blocks_to_allocate + else: + if len(all_allocs) > 0: + idx = random.randint(0, len(all_allocs) - 1) + allocator.free(all_allocs[idx]) + + num_frees += 1 + num_blocks_freed += all_allocs[idx].numel() + + del all_allocs[idx] + + if len(all_allocs) > 0: + validate_uniqueness(all_allocs) + + assert num_allocs == num_frees + len(all_allocs) + assert num_blocks_allocated == num_blocks_freed + (TOTAL_BLOCKS - allocator.free_blocks) diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py new file mode 100644 index 000000000000..bdd513445ddb --- /dev/null +++ b/tests/unit/inference/v2/ragged/test_manager_configs.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +from pydantic import ValidationError + +from deepspeed.inference.v2.ragged import DSStateManagerConfig + + +@pytest.mark.inference_v2 +def test_negative_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=0) + + +@pytest.mark.inference_v2 +def test_negative_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=0) + + +@pytest.mark.inference_v2 +def test_negative_max_ragged_sequence_count() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_sequence_count=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_ragged_sequence_count() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_sequence_count=0) + + +@pytest.mark.inference_v2 +def test_too_small_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=512, max_ragged_sequence_count=1024) + + +@pytest.mark.inference_v2 +def test_too_small_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=512, max_ragged_sequence_count=1024) diff --git a/tests/unit/inference/v2/ragged/test_ragged_wrapper.py b/tests/unit/inference/v2/ragged/test_ragged_wrapper.py new file mode 100644 index 000000000000..3cb74f4c49d2 --- /dev/null +++ b/tests/unit/inference/v2/ragged/test_ragged_wrapper.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.ragged import ( + PlaceholderSequenceDescriptor, + RaggedBatchWrapper, + DSStateManagerConfig, +) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('max_ragged_sequence_count, max_ragged_batch_size', [(128, 512), (128, 1024)]) +def test_wrapper_initialization(max_ragged_sequence_count: int, max_ragged_batch_size: int) -> None: + config = DSStateManagerConfig(max_tracked_sequences=max_ragged_sequence_count, + max_ragged_batch_size=max_ragged_batch_size, + max_ragged_sequence_count=max_ragged_sequence_count) + + batch = RaggedBatchWrapper(config) + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('seq_len', [1, 37, 128, 512]) +def test_single_sequence_batch(seq_len: int) -> None: + """ + Test we successfully construct single sequence batches and the on device metadata is accurate. + """ + + config = DSStateManagerConfig() + batch = RaggedBatchWrapper(config) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + seq_desc = PlaceholderSequenceDescriptor() + tokens = torch.randint(0, 100, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + + batch.finalize() + + assert batch.current_tokens == seq_len + assert batch.current_sequences == 1 + assert torch.equal(batch.input_ids(), tokens.to(get_accelerator().current_device())) + assert torch.equal(batch.tokens_to_seq(), torch.zeros_like(tokens, device=get_accelerator().current_device())) + assert torch.equal(batch.batch_metadata_buffer(), + torch.tensor([seq_len, 1], device=get_accelerator().current_device())) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('seq_lens', [[128, 128], [1, 32, 243], [64, 1, 1, 1, 1, 393, 27, 2]]) +def test_multi_sequence_batch(seq_lens: List[int]) -> None: + """ + Test sequentially adding new tokens to a batch and validate device data structures hold + the appropriate data. + """ + config = DSStateManagerConfig() + batch = RaggedBatchWrapper(config) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + all_toks = [torch.randint(0, 100, (seq_len, )) for seq_len in seq_lens] + + for i, toks in enumerate(all_toks): + seq_desc = PlaceholderSequenceDescriptor() + batch.insert_sequence(seq_desc, toks) + + assert batch.current_tokens == sum(seq_lens[:i + 1]) + assert batch.current_sequences == i + 1 + + batch.finalize() + + assert batch.current_tokens == sum(seq_lens) + assert batch.current_sequences == len(seq_lens) + + assert torch.equal(batch.input_ids(), torch.cat(all_toks, dim=0).to(get_accelerator().current_device())) + assert torch.equal( + batch.tokens_to_seq(), + torch.cat([torch.full((seq_len, ), i, dtype=torch.int32) for i, seq_len in enumerate(seq_lens)], + dim=0).to(get_accelerator().current_device())) + + for i, seq_len in enumerate(seq_lens): + assert batch.inflight_seq_descriptors()[i][0] == sum(seq_lens[:i]) + assert batch.inflight_seq_descriptors()[i][1] == seq_len + assert batch.inflight_seq_descriptors()[i][2] == 0 + + assert torch.equal(batch.batch_metadata_buffer(), + torch.tensor([sum(seq_lens), len(seq_lens)], device=get_accelerator().current_device())) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 diff --git a/tests/unit/launcher/test_ds_arguments.py b/tests/unit/launcher/test_ds_arguments.py index 7155beebc902..ee6d4ce6b7be 100644 --- a/tests/unit/launcher/test_ds_arguments.py +++ b/tests/unit/launcher/test_ds_arguments.py @@ -6,7 +6,7 @@ import argparse import pytest import deepspeed -from deepspeed.launcher.launch import parse_range_list +from deepspeed.utils.numa import parse_range_list def basic_parser(): @@ -40,7 +40,7 @@ def test_no_ds_arguments(): assert args.deepspeed == False assert hasattr(args, 'deepspeed_config') - assert args.deepspeed_config == None + assert args.deepspeed_config is None def test_no_ds_enable_argument(): @@ -74,7 +74,7 @@ def test_no_ds_config_argument(): assert args.deepspeed == True assert hasattr(args, 'deepspeed_config') - assert args.deepspeed_config == None + assert args.deepspeed_config is None def test_no_ds_parser(): diff --git a/tests/unit/launcher/test_multinode_runner.py b/tests/unit/launcher/test_multinode_runner.py index 743fffd8426f..3e8de2e83874 100644 --- a/tests/unit/launcher/test_multinode_runner.py +++ b/tests/unit/launcher/test_multinode_runner.py @@ -22,7 +22,7 @@ def runner_info(): def test_pdsh_runner(runner_info): env, resource_pool, world_info, args = runner_info runner = mnrunner.PDSHRunner(args, world_info) - cmd, kill_cmd = runner.get_cmd(env, resource_pool) + cmd, kill_cmd, env = runner.get_cmd(env, resource_pool) assert cmd[0] == 'pdsh' assert env['PDSH_RCMD_TYPE'] == 'ssh' @@ -32,6 +32,25 @@ def test_openmpi_runner(runner_info): runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool) cmd = runner.get_cmd(env, resource_pool) assert cmd[0] == 'mpirun' + assert 'eth0' in cmd + + +def test_btl_nic_openmpi_runner(runner_info): + env, resource_pool, world_info, _ = runner_info + args = parse_args(['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py']) + runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool) + cmd = runner.get_cmd(env, resource_pool) + assert 'eth0' not in cmd + assert 'eth1' in cmd + + +def test_btl_nic_two_dashes_openmpi_runner(runner_info): + env, resource_pool, world_info, _ = runner_info + args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py']) + runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool) + cmd = runner.get_cmd(env, resource_pool) + assert 'eth0' not in cmd + assert 'eth1' in cmd def test_mpich_runner(runner_info): diff --git a/tests/unit/launcher/test_user_args.py b/tests/unit/launcher/test_user_args.py new file mode 100644 index 000000000000..fd1489803812 --- /dev/null +++ b/tests/unit/launcher/test_user_args.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import subprocess + +from types import SimpleNamespace + +from deepspeed.accelerator import get_accelerator +from deepspeed.launcher.multinode_runner import MultiNodeRunner + + +class DummyRunner(MultiNodeRunner): + + def backend_exists(self): + return True + + def get_cmd(self, environment, active_resources): + return [] + + +if not get_accelerator().is_available(): + pytest.skip("only supported in accelerator environments.", allow_module_level=True) + +user_arg_test_script = """import argparse +parser = argparse.ArgumentParser() +parser.add_argument("--prompt", type=str) +parser.add_argument("--local_rank", type=int, default=0) +parser.add_argument("--world_size", type=int, default=1) +args = parser.parse_args() +print("ARG PARSE SUCCESS") +""" + + +@pytest.fixture(scope="function") +def user_script_fp(tmpdir): + script_fp = tmpdir.join("user_arg_test.py") + with open(script_fp, "w") as f: + f.write(user_arg_test_script) + return script_fp + + +@pytest.fixture(scope="function") +def cmd(user_script_fp, prompt, multi_node): + if multi_node: + cmd = ("deepspeed", "--force_multi", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt) + else: + cmd = ("deepspeed", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt) + return cmd + + +@pytest.fixture +def dummy_runner(): + args = SimpleNamespace(user_args=[], user_script="dummy_script.py") + return DummyRunner(args, "dummy_world_info") + + +@pytest.mark.parametrize("prompt", [ + '''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""", + '''I'm going to tell them "DeepSpeed is the best"''' +]) +@pytest.mark.parametrize("multi_node", [True, False]) +def test_user_args(cmd, multi_node): + if multi_node and get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" + + +def test_bash_string_args(tmpdir, user_script_fp): + bash_script = f""" + ARGS="--prompt 'DeepSpeed is the best'" + echo ${{ARGS}}|xargs deepspeed --num_nodes 1 --num_gpus 1 {user_script_fp} + """ + + bash_fp = tmpdir.join("bash_script.sh") + with open(bash_fp, "w") as f: + f.write(bash_script) + + p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" + + +def test_add_export_with_special_characters(dummy_runner): + """ + Values with special characters (e.g., 64(x2)) must be quoted to avoid bash syntax errors. + """ + dummy_runner.add_export("SLURM_JOB_CPUS_PER_NODE", "64(x2)") + assert dummy_runner.exports["SLURM_JOB_CPUS_PER_NODE"] == "\"64(x2)\"" + + +def test_add_export_no_special_characters(dummy_runner): + """ + Values without special characters should remain unquoted (e.g., PYTHONPATH). + This avoids issues where unnecessary quotes break module imports. + """ + dummy_runner.add_export("PYTHONPATH", "/usr/local/lib/python3.9/site-packages") + assert dummy_runner.exports["PYTHONPATH"] == "/usr/local/lib/python3.9/site-packages" diff --git a/tests/unit/linear/test_ctx.py b/tests/unit/linear/test_ctx.py new file mode 100644 index 000000000000..90fa55a489ff --- /dev/null +++ b/tests/unit/linear/test_ctx.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +import pytest +from unit.common import DistributedTest + +import deepspeed.comm as dist +from deepspeed.linear import LoRAConfig, init_lora +from deepspeed.linear.optimized_linear import LoRAOptimizedLinear +from unit.simple_model import random_dataloader, SimpleModel + +try: + import transformers +except ImportError: + transformers = None + +if transformers is None: + pytest.skip("transformers is required for this test", allow_module_level=True) + + +def injection_assert(model): + # pick out random linear that should have been replaced and initialized + q_proj = model.model.layers[1].self_attn.q_proj + + assert isinstance(q_proj, LoRAOptimizedLinear), "injection did not happen" + assert q_proj._initialized, "lora was not initialized properly" + assert isinstance(q_proj.lora_weight_1, torch.nn.Linear) + assert isinstance(q_proj.lora_weight_2, torch.nn.Linear) + + +class TestEngine(DistributedTest): + world_size = 2 + + def test_model(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + hidden_dim = 64 + nlayers = 4 + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) + + init_lora(model) + + model_norms = [model.linears[i].weight.norm().item() for i in range(nlayers)] + + ds_config = { + "train_batch_size": 2, + "steps_per_print": 1, + "bf16": { + "enabled": True + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 1 + } + } + model, *_ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + + engine_norms = [model.module.linears[i].weight.norm().item() for i in range(nlayers)] + + # Ensure that sharded weights are not broadcast during engine init + assert engine_norms == model_norms, f"{dist.get_rank()=} base weight norms are not the same after engine init, {engine_norms=} != {model_norms=}" + + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + +@pytest.mark.skip( + "Skipping test for now - the context manager has an issue with ._initialized and .disabled - worked with older transformers probably because it was setting some flags with the same name" +) +class TestInitTransformers(DistributedTest): + world_size = 2 + + def test_pretrained_init(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = transformers.AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3") + + injection_assert(model) + + def test_config_init(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + + config = transformers.AutoConfig.from_pretrained("llamafactory/tiny-random-Llama-3") + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = transformers.AutoModelForCausalLM.from_config(config) + + injection_assert(model) diff --git a/tests/unit/linear/test_linear.py b/tests/unit/linear/test_linear.py new file mode 100644 index 000000000000..2058791dba4a --- /dev/null +++ b/tests/unit/linear/test_linear.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +import deepspeed.comm as dist + +from deepspeed.accelerator import get_accelerator +from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig +from unit.common import DistributedTest + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + + +class TestBasicLinear(DistributedTest): + world_size = 2 + + def test(self): + lora_config = None + quantization_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 1 # Number of samples in a batch + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + + dummy_input = torch.rand(batch_size, input_features, dtype=torch.bfloat16) + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("base_weight_sharding", [1, 2]) +class TestLoRALinear(DistributedTest): + world_size = 2 + + def test(self, base_weight_sharding): + rank = dist.get_rank() + quantization_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + if rank == 0: + for n, p in linear_layer.named_parameters(): + print(f"{n}, {p.shape}") + + dummy_input = torch.rand(batch_size, input_features, device=device, dtype=torch.bfloat16) + + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("q_bits", [8, 6]) +class TestQuantLinear(DistributedTest): + world_size = 2 + + def test(self, q_bits): + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = None + quantization_config = QuantizationConfig(q_bits=q_bits) + quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16) + + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("base_weight_sharding", [1, 2], ids=['bws1', 'bws2']) +@pytest.mark.parametrize("q_bits", [8, 6], ids=['qbit8', 'qbit6']) +class TestOptimizedLinear(DistributedTest): + world_size = 2 + + def test(self, base_weight_sharding, q_bits): + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) + quantization_config = QuantizationConfig(q_bits=q_bits) + quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16) + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py new file mode 100644 index 000000000000..283d81b4bf36 --- /dev/null +++ b/tests/unit/linear/test_quant_param.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.accelerator import get_accelerator +from deepspeed.linear.quantization import QuantizedParameter +from deepspeed.linear.config import QuantizationConfig + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + + +class TestQuantParam(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('dtype', [torch.half, torch.float]) + def test_unsupported_dtypes(self, dtype): + device = get_accelerator().current_device_name() + data = torch.rand(5, 5, device='cpu', dtype=dtype) + qp = QuantizedParameter(data) + with pytest.raises(AssertionError): + qp.to(device) + + def test_requires_grad(self): + data = torch.rand(5, 5, dtype=torch.bfloat16) + with pytest.raises(ValueError): + QuantizedParameter(data, requires_grad=True) + + def test_move_to_accelerator(self): + device = get_accelerator().current_device() + data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16) + quantization_config = QuantizationConfig() + quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + qp = QuantizedParameter(data, quantization_config=quantization_config) + assert qp.device == torch.device('cpu') + qp = qp.to(get_accelerator().current_device_name()) + assert qp.device == torch.device(device) + assert qp.dtype == quantization_config.q_dtype + + def test_hf_clone(self): + device = get_accelerator().current_device_name() + data = torch.rand(5, 5, device=device, dtype=torch.bfloat16) + + quantization_config = QuantizationConfig(q_bits=6) + qp = QuantizedParameter(data, quantization_config=quantization_config) + + # should be able to clone parameter via dict, HF expects this to work + qp_copy = QuantizedParameter(qp.data, **qp.__dict__) + + assert all(qp.data == qp_copy.data) + assert qp.quantization_config == qp_copy.quantization_config diff --git a/tests/unit/model_parallelism/test_autotp_custom_patterns.py b/tests/unit/model_parallelism/test_autotp_custom_patterns.py new file mode 100644 index 000000000000..01c2c0b8f547 --- /dev/null +++ b/tests/unit/model_parallelism/test_autotp_custom_patterns.py @@ -0,0 +1,569 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed.comm as dist +import deepspeed +from copy import deepcopy +from torch import nn + +from unit.common import DistributedTest, preferred_dtype +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import groups +from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer, fused_LinearLayer) +from deepspeed.module_inject.autotp_config import AutoTPConfig +from deepspeed.module_inject.auto_tp import AutoTP + + +def skip_on_device(): + if get_accelerator().device_name() == 'xpu': + pytest.skip("XPU requires a higher version for test") + + +class SequentialLinearModel(torch.nn.Module): + + def __init__(self, hidden_dim, nlayers=1): + super(SequentialLinearModel, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) + + def forward(self, x): + for layer in self.linears: + x = layer(x) + return x + + +class CustomLinearModule(torch.nn.Module): + + def __init__(self, hidden_dim): + super(CustomLinearModule, self).__init__() + self.weight = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim)) + self.bias = torch.nn.Parameter(torch.empty(hidden_dim)) + torch.nn.init.uniform_(self.weight, -0.02, 0.02) + torch.nn.init.uniform_(self.bias, -0.02, 0.02) + + def forward(self, x): + return torch.matmul(x, self.weight.transpose(-1, -2)) + self.bias + + +class CustomLinearModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(CustomLinearModel, self).__init__() + self.custom = CustomLinearModule(hidden_dim) + + def forward(self, x): + return self.custom(x) + + +class QKVLinearModule(torch.nn.Module): + + def __init__(self, hidden_dim): + super(QKVLinearModule, self).__init__() + self.qkv_proj = torch.nn.Linear(hidden_dim, hidden_dim * 3) + + def forward(self, x): + return self.qkv_proj(x) + + +class QKVLinearModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(QKVLinearModel, self).__init__() + self.self_attn = QKVLinearModule(hidden_dim) + + def forward(self, x): + return self.self_attn(x) + + +class DeepAttention(torch.nn.Module): + """Mimics HF attention module with separate projection layers.""" + + def __init__(self, hidden_dim): + super().__init__() + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + return self.o_proj(self.q_proj(x)) + + +class DeepBlock(torch.nn.Module): + """Mimics a single HF transformer block.""" + + def __init__(self, hidden_dim): + super().__init__() + self.self_attn = DeepAttention(hidden_dim) + + def forward(self, x): + return self.self_attn(x) + + +class DeepModel(torch.nn.Module): + """Mimics HF transformer structure: model.layers.[N].self_attn.{q,o}_proj. + + This creates a 4-level-deep module hierarchy to test that _replace_module + correctly propagates the full module path during recursion. + """ + + def __init__(self, hidden_dim, nlayers=2): + super().__init__() + self.layers = torch.nn.ModuleList([DeepBlock(hidden_dim) for _ in range(nlayers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def init_tp_engine(tp_size, partition_config=None): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size, + }, + "zero_optimization": { + "stage": 0, + } + } + if partition_config is not None: + config_dict["tensor_parallel"]["partition_config"] = partition_config + else: + config_dict["tensor_parallel"]["partition_config"] = { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=8) + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + +def apply_autotp_with_partition_config(model, tp_size, partition_config): + groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size) + autotp_config = AutoTPConfig.from_dict(partition_config) + autotp = AutoTP(module=model, + all_reduce_linears=[], + prefix="", + state_dict=None, + linear_layer_setting=None, + orig_layer_impl=None, + keep_module_on_host=False, + partition_config=autotp_config) + autotp.set_tensor_parallel_config(tp_size, groups.get_tensor_model_parallel_group()) + autotp.update_linear_policies() + autotp._replace_module(model) + return model + + +def gather_subparam_output(output, subparam_sizes, mp_group): + tp_world_size = dist.get_world_size(group=mp_group) + local_sizes = [size // tp_world_size for size in subparam_sizes] + output_chunks = torch.split(output, local_sizes, dim=-1) + gathered_chunks = [] + for chunk in output_chunks: + chunk = chunk.contiguous() + gathered = [torch.empty_like(chunk) for _ in range(tp_world_size)] + dist.all_gather(gathered, chunk, group=mp_group) + gathered_chunks.append(torch.cat(gathered, dim=-1)) + return torch.cat(gathered_chunks, dim=-1) + + +def assert_close_for_preferred_dtype(actual, expected): + atol = 1e-3 + rtol = 2e-2 + if preferred_dtype() is torch.float32: + atol = 1e-5 + rtol = 1e-5 + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + + +class TestAutoTPCustomPatterns(DistributedTest): + world_size = 2 + reuse_dist_env = False + + def test_custom_pattern_replacement(self): + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "row", + }, + { + "patterns": [".*linears\\.1\\.weight$"], + "partition_type": "column", + }, + { + "patterns": [".*linears\\.2\\.weight$"], + "partition_type": "skip", + }, + ], + } + model = SequentialLinearModel(hidden_dim=16, nlayers=3) + model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config) + + assert isinstance(model.linears[0], LinearAllreduce) + assert isinstance(model.linears[1], LinearLayer) + assert isinstance(model.linears[2], nn.Linear) + + def test_custom_patterns_applied_via_config(self): + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "row", + }, + { + "patterns": [".*linears\\.1\\.weight$"], + "partition_type": "column", + }, + { + "patterns": [".*linears\\.2\\.weight$"], + "partition_type": "skip", + }, + ], + } + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": 2, + "partition_config": partition_config, + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=16, nlayers=3) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert isinstance(engine.module.linears[0], LinearAllreduce) + assert isinstance(engine.module.linears[1], LinearLayer) + assert isinstance(engine.module.linears[2], nn.Linear) + + def test_use_default_specs_false_skips_unmatched_layers(self): + skip_on_device() + # Verify unmatched layers remain unsharded when defaults are disabled. + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "row", + }, + { + "patterns": [".*linears\\.1\\.weight$"], + "partition_type": "column", + }, + ], + } + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": 2, + "partition_config": partition_config, + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=16, nlayers=3) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert isinstance(engine.module.linears[0], LinearAllreduce) + assert isinstance(engine.module.linears[1], LinearLayer) + assert isinstance(engine.module.linears[2], nn.Linear) + + def test_custom_module_replacement_with_patterns(self): + skip_on_device() + # Verify custom linear-like modules are partitioned via patterns. + partition_config = { + "use_default_specs": False, + "layer_specs": [ + { + "patterns": [".*custom\\.weight$"], + "partition_type": "column", + }, + ], + } + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": 2, + "partition_config": partition_config, + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = CustomLinearModel(hidden_dim=16) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert isinstance(engine.module.custom, LinearLayer) + + def test_custom_pattern_disables_fused_qkv_heuristic(self): + skip_on_device() + # Use a qkv_proj name that would trigger the fused-QKV heuristic, then + # verify custom patterns override that path and preserve correctness. + torch.manual_seed(1234) + hidden_dim = 16 + qkv_sizes = (hidden_dim, hidden_dim, hidden_dim) + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*self_attn\\.qkv_proj\\.weight$"], + "partition_type": "column", + "shape": [list(qkv_sizes), -1], + "partition_dim": 0, + }, + ], + } + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": 2, + "partition_config": partition_config, + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = QKVLinearModel(hidden_dim=hidden_dim) + baseline = deepcopy(model).to(get_accelerator().current_device(), dtype=preferred_dtype()) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + qkv_layer = engine.module.self_attn.qkv_proj + # Custom pattern should force SubParamLinearLayer (shape-based path), + # and avoid the legacy fused-QKV heuristic despite the qkv_proj name. + assert isinstance(qkv_layer, SubParamLinearLayer) + assert not isinstance(qkv_layer, fused_LinearLayer) + + assert qkv_layer.partition_dim == 0 + assert qkv_layer._subparam_sizes == qkv_sizes + assert qkv_layer._orig_weight_shape == (hidden_dim * 3, hidden_dim) + + qkv_layer.gather_params([qkv_layer.weight, qkv_layer.bias]) + torch.testing.assert_close(qkv_layer.weight, baseline.self_attn.qkv_proj.weight) + if qkv_layer.bias is not None: + torch.testing.assert_close(qkv_layer.bias, baseline.self_attn.qkv_proj.bias) + + torch.manual_seed(4321) + inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device()) + full_output = baseline(inputs) + tp_output = engine.module(inputs) + assert_close_for_preferred_dtype(tp_output, full_output) + + def test_first_match_precedence(self): + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "skip", + }, + { + "patterns": [".*linears\\.0\\.weight$"], + "partition_type": "column", + }, + ], + } + model = SequentialLinearModel(hidden_dim=16, nlayers=1) + model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config) + + assert isinstance(model.linears[0], nn.Linear) + + def test_deep_model_full_path_propagation(self): + """Verify _replace_module propagates accumulated paths through deep hierarchies. + + Uses a 4-level-deep model (layers.N.self_attn.{q,o}_proj) with patterns + that require intermediate path components (layers.N). Without correct + full_name propagation, the recursive path is truncated and patterns + that include intermediate levels will silently fail to match. + """ + skip_on_device() + partition_config = { + "use_default_specs": + False, + "layer_specs": [ + { + "patterns": [r".*layers\.\d+\.self_attn\.q_proj\.weight$"], + "partition_type": "column", + }, + { + "patterns": [r".*layers\.\d+\.self_attn\.o_proj\.weight$"], + "partition_type": "row", + }, + ], + } + model = DeepModel(hidden_dim=16, nlayers=2) + model = apply_autotp_with_partition_config(model, tp_size=2, partition_config=partition_config) + + # All 4 projections (2 layers x {q_proj, o_proj}) must be replaced. + # Before the full_name fix, 0 modules were replaced because the mangled + # path "self_attn.q_proj.weight" could not match "layers.N.self_attn...". + for i in range(2): + assert isinstance(model.layers[i].self_attn.q_proj, LinearLayer), \ + f"layers.{i}.self_attn.q_proj was not replaced (path propagation bug?)" + assert isinstance(model.layers[i].self_attn.o_proj, LinearAllreduce), \ + f"layers.{i}.self_attn.o_proj was not replaced (path propagation bug?)" + + +def test_invalid_custom_shape_rejected(): + bad_config = { + "layer_specs": [{ + "patterns": [".*"], + "partition_type": "column", + "shape": [2, [1, 1]], + }] + } + with pytest.raises(ValueError, match="nested tuple only allowed at partition_dim"): + AutoTPConfig.from_dict(bad_config) + + +class TestAutoTPFusedWeights(DistributedTest): + world_size = 2 + reuse_dist_env = False + + def test_gate_up_fused_weight_partition(self): + skip_on_device() + init_tp_engine(tp_size=2) + + hidden_dim = 8 + torch.manual_seed(42) + linear = nn.Linear(hidden_dim, + hidden_dim * 2, + bias=True, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + full_weight = deepcopy(linear.weight.data) + full_bias = deepcopy(linear.bias.data) + + layer = SubParamLinearLayer(deepcopy(linear), + groups.get_tensor_model_parallel_group(), + shape=(2, -1), + partition_dim=0, + name="mlp.gate_up_proj") + assert layer._subparam_sizes == (hidden_dim, hidden_dim) + assert layer.weight.shape == (hidden_dim, hidden_dim) + + layer.gather_params([layer.weight, layer.bias]) + torch.testing.assert_close(layer.weight.data, full_weight) + torch.testing.assert_close(layer.bias.data, full_bias) + + def test_gqa_uneven_qkv_fused_weight_partition(self): + skip_on_device() + init_tp_engine(tp_size=2) + + hidden_dim = 8 + q_size, k_size, v_size = 8, 4, 4 + torch.manual_seed(123) + linear = nn.Linear(hidden_dim, + q_size + k_size + v_size, + bias=True, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + full_weight = deepcopy(linear.weight.data) + full_bias = deepcopy(linear.bias.data) + + layer = SubParamLinearLayer(deepcopy(linear), + groups.get_tensor_model_parallel_group(), + shape=((q_size, k_size, v_size), -1), + partition_dim=0, + name="self_attn.qkv_proj") + assert layer._subparam_sizes == (q_size, k_size, v_size) + assert layer.weight.shape == ((q_size + k_size + v_size) // 2, hidden_dim) + + layer.gather_params([layer.weight, layer.bias]) + torch.testing.assert_close(layer.weight.data, full_weight) + torch.testing.assert_close(layer.bias.data, full_bias) + + def test_gqa_uneven_qkv_fused_forward(self): + skip_on_device() + groups._init_tp_mesh_device(tensor_model_parallel_size=2) + + hidden_dim = 8 + q_size, k_size, v_size = 8, 4, 4 + torch.manual_seed(321) + linear = nn.Linear(hidden_dim, + q_size + k_size + v_size, + bias=True, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + layer = SubParamLinearLayer(deepcopy(linear), + groups.get_tensor_model_parallel_group(), + shape=((q_size, k_size, v_size), -1), + partition_dim=0, + name="self_attn.qkv_proj") + + torch.manual_seed(42) + inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device()) + full_output = linear(inputs) + tp_output = layer(inputs) + + gathered_output = gather_subparam_output(tp_output, (q_size, k_size, v_size), + groups.get_tensor_model_parallel_group()) + assert_close_for_preferred_dtype(gathered_output, full_output) diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py new file mode 100644 index 000000000000..64f0b1113b16 --- /dev/null +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -0,0 +1,763 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import deepspeed.comm as dist +import torch +import math +from copy import deepcopy + +from unit.common import DistributedTest, preferred_dtype +import deepspeed +from deepspeed.accelerator import get_accelerator +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.utils import groups +from contextlib import contextmanager +from torch import nn +from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode, is_autotp_training_mode +from unit.checkpoint.common import compare_lr_scheduler_states, compare_optimizer_states +import os +from deepspeed.runtime.utils import is_model_parallel_parameter + + +def skip_on_device(): + if get_accelerator().device_name() == 'xpu': + pytest.skip("XPU requires a higher version for test") + + +def reset_tp_model_init_state(): + deepspeed._TP_MODEL_INIT_ARGS = None + set_autotp_mode(training=False) + + +class DummyMPU: + + def __init__(self, tp_world_size=1): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.tp_world_size = tp_world_size + self.dp_group = dist.get_world_group() + self.tp_group = dist.get_world_group() + + def get_model_parallel_rank(self): + return self.rank % self.tp_world_size + + def get_model_parallel_world_size(self): + return self.tp_world_size + + def get_data_parallel_rank(self): + return self.rank // self.tp_world_size + + def get_data_parallel_world_size(self): + return self.world_size // self.tp_world_size + + def get_data_parallel_group(self): + return self.dp_group + + def get_model_parallel_group(self): + return self.tp_group + + +class SequentialLinearModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False, nlayers=1): + super(SequentialLinearModel, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) + if empty_grad: + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + self.empty_grad = empty_grad + + def forward(self, x, y): + if len(self.linears) == 1: + x = self.linears[0](x) + else: + for i, l in enumerate(self.linears): + x = self.linears[i](x) + return self.cross_entropy_loss(x, y) + + +@contextmanager +def should_assert_with_msg(expected_message): + try: + yield + except AssertionError as e: + if dist.get_rank() == 0: + print(expected_message) + print(str(e)) + if str(e) == expected_message: + pass + else: + raise e + else: + raise AssertionError(f"Expected AssertionError with message '{expected_message}' " + "but no exception was raised.") + + +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpParallelStates(DistributedTest): + world_size = 4 + + def test(self, tp_size: int): + skip_on_device() + dp_size = 4 / tp_size + hidden_dim = 128 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + }, + "zero_optimization": { + "stage": 0 + } + } + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert groups.get_tensor_model_parallel_world_size() == tp_size + assert groups.get_data_parallel_world_size() == dp_size + + +class TestTpModelInitCompatibility(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test_tp_model_init_merges_config(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype()) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 0, + } + } + engine, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=DummyMPU()) + assert engine.autotp_size() == 1 + assert is_autotp_training_mode() + + def test_tp_model_init_config_autotp_size_mismatch(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype()) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2, + }, + "zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="tensor_parallel.autotp_size"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU()) + + def test_tp_model_init_autocreates_tp_group(self): + skip_on_device() + reset_tp_model_init_state() + # Verify tp_model_init creates TP groups when no mpu is provided. + model = SimpleModel(hidden_dim=8) + tp_size = 2 + deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype()) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + }, + "zero_optimization": { + "stage": 0, + } + } + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert engine.autotp_size() == tp_size + assert groups.get_tensor_model_parallel_world_size() == tp_size + assert groups.get_data_parallel_world_size() == dist.get_world_size() // tp_size + + def test_tp_model_init_tp_group_rejects_mpu(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + tp_group = dist.new_group(ranks=[0]) + deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype(), tp_group=tp_group) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="tp_model_init provided tp_group"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU()) + + def test_tp_model_init_dtype_mismatch(self): + skip_on_device() + reset_tp_model_init_state() + model = SimpleModel(hidden_dim=8) + deepspeed.tp_model_init(model, tp_size=1, dtype=torch.float16) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "bf16": { + "enabled": True, + }, + "zero_optimization": { + "stage": 0, + } + } + with pytest.raises(ValueError, match="Conflicting dtype"): + deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU()) + + @pytest.mark.sequential + @pytest.mark.parametrize("tp_size", [2, 4]) + @pytest.mark.parametrize("tp_overlap_comm", [True, False]) + def test_tp_model_init_row_parallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=False, use_tp_model_init=True) + + @pytest.mark.sequential + @pytest.mark.parametrize("tp_size", [2, 4]) + @pytest.mark.parametrize("tp_overlap_comm", [True, False]) + def test_tp_model_init_column_parallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=True, use_tp_model_init=True) + + +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpDataloaderCorrectness(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test(self, tp_size: int): + skip_on_device() + hidden_dim = 128 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + torch.manual_seed(42) + + data_loader = random_dataloader(model=model, + total_samples=3, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + dist.barrier() + with should_assert_with_msg( + "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency." + ): + for batch in data_loader: + # batch[0].requires_grad = requires_grad + batch[0] += dist.get_rank() + model(batch[0], batch[1]) + + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=3, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + for batch in data_loader: + dist.broadcast(batch[0], + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group()) + dist.broadcast(batch[1], + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group()) + model(batch[0], batch[1]) + + +def process_linear_layer(hidden_dim, input): + torch.manual_seed(42) + torch_linear = nn.Linear(hidden_dim, + hidden_dim, + dtype=preferred_dtype(), + device=get_accelerator().current_device()) + torch_out = torch_linear(input) + torch_loss = torch_out.sum() + torch_loss.backward() + return torch_linear, torch_out + + +def run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel, use_tp_model_init=False): + skip_on_device() + hidden_dim = 128 + batch_size_per_device = 1 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size, + "tp_overlap_comm": tp_overlap_comm + }, + "zero_optimization": { + "stage": 0, + } + } + partition_type = "column" if column_parallel else "row" + config_dict["tensor_parallel"]["partition_config"] = { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": partition_type, + }], + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=hidden_dim) + if use_tp_model_init: + reset_tp_model_init_state() + deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype()) + mpu = DummyMPU(tp_world_size=tp_size) + model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=config_dict, + mpu=mpu) + else: + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + input = torch.randn(batch_size_per_device, + hidden_dim, + dtype=preferred_dtype(), + requires_grad=True, + device=get_accelerator().current_device()) + dist.broadcast(input, groups.get_tensor_model_parallel_src_rank(), group=groups.get_tensor_model_parallel_group()) + + # Note: correctness checks below use standalone TP wrappers and do not + # rely on the model's AutoTP-partitioned parameters. + torch_linear, torch_out = process_linear_layer(hidden_dim, input) + if column_parallel: + linear = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + out = linear(input.to(get_accelerator().current_device())) + loss = out.sum() + loss.backward() + + cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] + torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] + torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] + + torch.testing.assert_close(linear.bias.grad, + torch_bias_grad.to(get_accelerator().current_device()), + atol=1e-3, + rtol=1e-3) + torch.testing.assert_close(linear.weight.grad, + torch_grad.to(get_accelerator().current_device()), + atol=1e-3, + rtol=1e-3) + torch.testing.assert_close(cur_device_out.to(get_accelerator().current_device()).contiguous(), + out.contiguous(), + atol=1e-2, + rtol=1e-2) + else: + linear = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + input_ = torch.chunk(input, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] + out = linear(input_.to(get_accelerator().current_device())) + loss = out.sum() + loss.backward() + + torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()] + torch_bias_grad = torch_linear.bias.grad + torch.testing.assert_close(linear.bias.grad, + torch_bias_grad.to(get_accelerator().current_device()), + atol=1e-3, + rtol=1e-3) + torch.testing.assert_close(linear.weight.grad, + torch_grad.to(get_accelerator().current_device()), + atol=1e-3, + rtol=1e-3) + torch.testing.assert_close(out, torch_out.to(get_accelerator().current_device()), atol=1e-2, rtol=1e-2) + + +@pytest.mark.sequential +@pytest.mark.parametrize("tp_size", [2, 4]) +@pytest.mark.parametrize("tp_overlap_comm", [True, False]) +class TestTpLayerFwdBwd(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def testRowParallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=False) + + def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool): + run_tp_layer_fwd_bwd(tp_size, tp_overlap_comm, column_parallel=True) + + +# @pytest.mark.sequential +class TestParamsGather(DistributedTest): + world_size = 4 + reuse_dist_env = False + + @pytest.mark.parametrize("layer_type", ["linear", "linearallreduce"]) + def test(self, layer_type): + skip_on_device() + tp_size = 4 + hidden_dim = 128 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + torch.manual_seed(42) + model = SequentialLinearModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + torch_linear = nn.Linear(hidden_dim, hidden_dim, dtype=preferred_dtype(), device="cpu") + total_params = sum(p.numel() for p in torch_linear.parameters()) + tp_layer = None + if layer_type == "linear": + tp_layer = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + elif layer_type == "linearallreduce": + tp_layer = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + else: + raise ValueError(f"Invalid linear type: {config_dict['linear_type']}") + + tp_params = sum(p.numel() for p in tp_layer.parameters()) + + expected_tp_params = 0 + # compute expected TP params: + # - column-parallel (LinearLayer): weight & bias both split => total // tp_size + # - row-parallel (LinearAllreduce): weight split, bias (1d tensors) replicated + if layer_type == "linearallreduce": + weight_params = torch_linear.weight.numel() + bias_params = torch_linear.bias.numel() + expected_tp_params = weight_params // tp_size + bias_params + else: + expected_tp_params = total_params // tp_size + assert expected_tp_params == tp_params, ( + f"{layer_type}: expected {expected_tp_params} tp params, got {tp_params}") + + for name, param in tp_layer.named_parameters(recurse=False): + if is_model_parallel_parameter(param): + param.gather_params([param]) + + torch_linear = torch_linear.to(get_accelerator().current_device()) + is_same_weights = all( + torch.equal(param1, param2) for param1, param2 in zip(tp_layer.parameters(), torch_linear.parameters())) + + assert is_same_weights + + params1 = sum(p.numel() for p in tp_layer.parameters()) + assert total_params == params1 + + for name, param in tp_layer.named_parameters(recurse=False): + if is_model_parallel_parameter(param): + param._tp_partition([param]) + + tp_params2 = sum(p.numel() for p in tp_layer.parameters()) + + assert expected_tp_params == tp_params2 + + +def dummy_init_engine(config): + # This is a dummy initialization function for the DeepSpeed engine. + # We only need to use the config to initialize the distributed settings for the test. + # Add default partition_config for simple test models if not provided + if "tensor_parallel" in config and "partition_config" not in config["tensor_parallel"]: + config["tensor_parallel"]["partition_config"] = { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + } + model = SequentialLinearModel(hidden_dim=8) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) + + +def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, group, return_global_copy=False): + model = SequentialLinearModel(hidden_dim=hidden_dim, nlayers=nlayers).to(preferred_dtype()) + base_model = None + if return_global_copy: + base_model = deepcopy(model) + for i in linear_indices: + layer = LinearLayer(model.linears[i], group) + model.linears[i] = layer + + for i in allreduce_indices: + layer = LinearAllreduce(model.linears[i], group) + model.linears[i] = layer + + return model, base_model + + +@pytest.mark.parametrize("zero_stage", [0, 1, 2]) +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestSave(DistributedTest): + + world_size = 4 + reuse_dist_env = False + + def test_save_original_weight(self, tp_size: int, zero_stage: int): + skip_on_device() + hidden_dim = 64 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + dummy_init_engine(config_dict) + torch.manual_seed(42) + + model, base_model = prepare_tp_model(hidden_dim, + 8, [2, 5], [3, 6], + groups.get_tensor_model_parallel_group(), + return_global_copy=True) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + cur_params_numel = sum(p.numel() for p in model.parameters()) + base_params_numel = sum(p.numel() for p in base_model.parameters()) + assert cur_params_numel < base_params_numel + + tp_state_dict = model._consolidated_16bit_state_dict() + + def compare_state_dicts(state_dict1, state_dict2): + if state_dict1.keys() != state_dict2.keys(): + print("The state_dicts have different keys!") + return False + + for key in state_dict1: + if not torch.allclose(state_dict1[key], state_dict2[key], atol=1e-3): + assert state_dict1[key].device == "cpu" + print(f"Parameters for {key} are different!") + return False + + return True + + base_state_dict = base_model.state_dict() + if dist.get_rank() == 0: + # we should consider the case when zero3 is used in the future. + assert compare_state_dicts(base_state_dict, tp_state_dict), "State_dict is not the same!" + else: + assert tp_state_dict is None, "noly rank0 should have the state_dict" + + def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int): + skip_on_device() + hidden_dim = 64 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000 + } + } + } + + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + dummy_init_engine(config_dict) + + trained_model, _ = prepare_tp_model(hidden_dim, 8, [2, 5], [3, 6], groups.get_tensor_model_parallel_group()) + loaded_model, _ = prepare_tp_model(hidden_dim, 8, [2, 5], [3, 6], groups.get_tensor_model_parallel_group()) + + trained_model, _, _, _ = deepspeed.initialize(model=trained_model, + model_parameters=trained_model.parameters(), + config=config_dict) + torch.manual_seed(42) + + data_loader = random_dataloader(model=trained_model, + total_samples=3, + hidden_dim=hidden_dim, + device=trained_model.device, + dtype=preferred_dtype()) + ckpt_path = os.path.join(tmpdir, 'tp_saved_checkpoint') + for i, batch in enumerate(data_loader): + batch[0].requires_grad = True + loss = trained_model(batch[0], batch[1]) + loss = loss + trained_model.backward(loss) + trained_model.step() + trained_model.save_checkpoint(ckpt_path) + + loaded_model, _, _, _ = deepspeed.initialize(model=loaded_model, + model_parameters=loaded_model.parameters(), + config=config_dict) + loaded_model.load_checkpoint(ckpt_path, load_optimizer_states=True, load_lr_scheduler_states=True) + compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16=(preferred_dtype() == torch.float16)) + compare_lr_scheduler_states(trained_model, loaded_model) + + +@pytest.mark.parametrize("zero_stage", [0, 1, 2]) +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpGradNorm(DistributedTest): + + world_size = 4 + reuse_dist_env = False + + def test(self, tp_size: int, zero_stage: int): + skip_on_device() + hidden_dim = 64 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + if zero_stage == 0: + pytest.skip( + "This test has an overflow data and needs to implement an overflow skip mechanism in BF16_optimizer" + ) + config_dict["bf16"] = {"enabled": True} + + torch.manual_seed(42) + + dummy_init_engine(config=config_dict) + tp_model, base_model = prepare_tp_model(hidden_dim, + 8, [2, 5], [3, 6], + groups.get_tensor_model_parallel_group(), + return_global_copy=True) + + base_model, base_optimizer, _, _ = deepspeed.initialize(model=base_model, + model_parameters=base_model.parameters(), + config=config_dict) + data_loader = random_dataloader(model=base_model, + total_samples=20, + hidden_dim=hidden_dim, + device=base_model.device, + dtype=preferred_dtype()) + + for i, batch in enumerate(data_loader): + batch[0].requires_grad = True + loss = base_model(batch[0], batch[1]) + loss = loss + base_model.backward(loss) + base_model.step() + + base_norm = base_optimizer._global_grad_norm + + base_model.destroy() + + tp_model, tp_optimizer, _, _ = deepspeed.initialize(model=tp_model, + model_parameters=tp_model.parameters(), + config=config_dict) + for i, batch in enumerate(data_loader): + batch[0].requires_grad = True + loss = tp_model(batch[0], batch[1]) + loss = loss + tp_model.backward(loss) + tp_model.step() + + tp_norm = tp_optimizer._global_grad_norm + + assert math.isclose(base_norm, tp_norm, abs_tol=1e-3), f"base_norm: {base_norm}, tp_norm: {tp_norm}" + tp_params_numel = sum(p.numel() for p in tp_model.parameters()) + base_params_numel = sum(p.numel() for p in base_model.parameters()) + assert tp_params_numel < base_params_numel, f"tp_params_numel: {tp_params_numel}, base_params_numel: {base_params_numel}" diff --git a/tests/unit/model_parallelism/test_configurable_parallel_mp.py b/tests/unit/model_parallelism/test_configurable_parallel_mp.py index f8491b8411a7..a7b0d3431ee9 100644 --- a/tests/unit/model_parallelism/test_configurable_parallel_mp.py +++ b/tests/unit/model_parallelism/test_configurable_parallel_mp.py @@ -13,14 +13,13 @@ from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest, DistributedFixture from unit.megatron_model import get_gpt2_model, get_megatron_version -from unit.util import required_minimum_torch_version, required_maximum_torch_version +from deepspeed.utils.torch import required_torch_version -pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=5), - reason='Megatron-LM package requires Pytorch version 1.5 or above') -pytestmark = pytest.mark.skipif(not required_maximum_torch_version(major_version=1, minor_version=13), - reason='Megatron-LM package requires Pytorch version 1.13 or below') +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.5, max_version=1.13), + reason='Megatron-LM package requires Pytorch version >=1.5 and <=1.13') +# TODO: integrated testing of TP and ZeRO 1/2/3 def get_deepspeed_model(model): ds_config_dict = { "train_micro_batch_size_per_gpu": 1, @@ -60,6 +59,7 @@ def inputs(self, bs=1, seq_len=20): class TestConfigurableMP(ConfigurableMP): @pytest.mark.world_size(1) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_gpt2_basic(self, tmpdir, inputs): args_defaults = { 'num_layers': 2, @@ -87,6 +87,7 @@ def test_gpt2_basic(self, tmpdir, inputs): atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}" @pytest.mark.world_size(2) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_gpt2_mp2_no_resize(self, tmpdir, inputs): args_defaults = { 'num_layers': 2, @@ -148,6 +149,7 @@ def run(self, inputs, class_tmpdir): class TestConfigurableResizeMP(ConfigurableMP): world_size = [1, 4] + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test(self, baseline_mp2, inputs, class_tmpdir): args_defaults = { 'num_layers': 2, @@ -168,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir): test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name)) if dist.get_rank() == 0: load_path = os.path.join(class_tmpdir, "output.pt") - baseline = torch.load(load_path) + baseline = torch.load(load_path, weights_only=False) test = test.cpu() assert torch.allclose( baseline, test, diff --git a/tests/unit/model_parallelism/test_configurable_parallel_pp.py b/tests/unit/model_parallelism/test_configurable_parallel_pp.py index aaab061ed056..df469044e186 100644 --- a/tests/unit/model_parallelism/test_configurable_parallel_pp.py +++ b/tests/unit/model_parallelism/test_configurable_parallel_pp.py @@ -15,12 +15,10 @@ from unit.megatron_model import MockGPT2ModelPipe as GPT2ModelPipe from deepspeed.utils import RepeatingLoader from deepspeed.accelerator import get_accelerator -from unit.util import required_minimum_torch_version, required_maximum_torch_version +from deepspeed.utils.torch import required_torch_version -pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=5), - reason='Megatron-LM package requires Pytorch version 1.5 or above') -pytestmark = pytest.mark.skipif(not required_maximum_torch_version(major_version=1, minor_version=13), - reason='Megatron-LM package requires Pytorch version 1.13 or below') +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.5, max_version=1.13), + reason='Megatron-LM package requires Pytorch version >=1.5 and <=1.13') def get_deepspeed_model(model): @@ -69,6 +67,7 @@ class TestConfigurablePP(ConfigurablePP): pp_size = 2 world_size = 4 # mp_size * pp_size + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_pp_basic(self, inputs, tmpdir): # basic test case, mp_size=2, pp_size=2, verify ckpt saving/loading. args_defaults = { @@ -226,7 +225,7 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz assert torch.is_tensor(test[0][0]) test = test[0][0].cpu() load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt") - baseline = torch.load(load_path) + baseline = torch.load(load_path, weights_only=False) assert torch.allclose( baseline, test, atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}" @@ -234,30 +233,35 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz # These tests are divided by baseline model worldsize and test model worldsize @pytest.mark.world_size(1) @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 2, 1, 1)]) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_world_size_2to1(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws2, mp_size, pp_size, mp_resize, pp_resize): self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize) @pytest.mark.world_size(1) @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 1, 1)]) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_world_size_4to1(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws4, mp_size, pp_size, mp_resize, pp_resize): self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize) @pytest.mark.world_size(2) @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 2, 1)]) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_world_size_4to2(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws4, mp_size, pp_size, mp_resize, pp_resize): self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize) @pytest.mark.world_size(4) @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 1, 2, 2)]) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_world_size_1to4(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws1, mp_size, pp_size, mp_resize, pp_resize): self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize) @pytest.mark.world_size(4) @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 2, 1, 4), (2, 1, 2, 2)]) + @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.") def test_world_size_2to4(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws2, mp_size, pp_size, mp_resize, pp_resize): self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize) diff --git a/tests/unit/model_parallelism/test_tp_plan_e2e.py b/tests/unit/model_parallelism/test_tp_plan_e2e.py new file mode 100644 index 000000000000..1c85b45947dc --- /dev/null +++ b/tests/unit/model_parallelism/test_tp_plan_e2e.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed.comm as dist +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import groups +from deepspeed.runtime.utils import is_model_parallel_parameter +from unit.common import DistributedTest, preferred_dtype + + +def skip_on_device(): + return + + +class TestTPPlanEndToEnd(DistributedTest): + world_size = 2 + + class SimpleHFModel(torch.nn.Module): + + class Block(torch.nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.q_proj = torch.nn.Linear(hidden_size, hidden_size * 2) + self.o_proj = torch.nn.Linear(hidden_size * 2, hidden_size) + + def forward(self, x): + return self.o_proj(self.q_proj(x)) + + def __init__(self, hidden_size=64): + super().__init__() + self.hidden_size = hidden_size + self.config = type( + "Config", + (), + {"base_model_tp_plan": { + "*.q_proj": "colwise", + "*.o_proj": "rowwise", + }}, + )() + self.layers = torch.nn.ModuleList([self.Block(hidden_size)]) + + def forward(self, x): + return self.layers[0](x) + + def _setup_baseline_linears(self, model): + torch_q = torch.nn.Linear(model.hidden_size, model.hidden_size * 2) + torch_o = torch.nn.Linear(model.hidden_size * 2, model.hidden_size) + torch_q.load_state_dict(model.layers[0].q_proj.state_dict(), strict=True) + torch_o.load_state_dict(model.layers[0].o_proj.state_dict(), strict=True) + + if preferred_dtype() == torch.float16: + torch_q = torch_q.half() + torch_o = torch_o.half() + elif preferred_dtype() == torch.bfloat16: + torch_q = torch_q.bfloat16() + torch_o = torch_o.bfloat16() + + device = get_accelerator().current_device_name() + torch_q = torch_q.to(device) + torch_o = torch_o.to(device) + return torch_q, torch_o + + def _compare_tp_gradients(self, model, torch_q, torch_o, input_tensor, engine): + + def _get_grad(param): + if param.grad is not None: + return param.grad + return getattr(param, "grad_accum", None) + + torch_q.zero_grad(set_to_none=True) + torch_o.zero_grad(set_to_none=True) + torch_q_out = torch_q(input_tensor) + torch_o_out = torch_o(torch_q_out) + torch_loss = torch_o_out.sum() + torch_loss.backward() + + output = engine(input_tensor) + loss = output.sum() + engine.backward(loss) + + tp_rank = groups.get_tensor_model_parallel_rank() + tp_size = engine.autotp_size() + q_proj = model.layers[0].q_proj + o_proj = model.layers[0].o_proj + + torch_q_grad = torch.chunk(torch_q.weight.grad, tp_size, dim=0)[tp_rank] + torch_q_bias_grad = torch.chunk(torch_q.bias.grad, tp_size, dim=0)[tp_rank] + torch_o_grad = torch.chunk(torch_o.weight.grad, tp_size, dim=1)[tp_rank] + + q_weight_grad = _get_grad(q_proj.weight) + q_bias_grad = _get_grad(q_proj.bias) if q_proj.bias is not None else None + o_weight_grad = _get_grad(o_proj.weight) + + torch.testing.assert_close(q_weight_grad, torch_q_grad, atol=2e-2, rtol=2e-2) + if q_bias_grad is not None: + torch.testing.assert_close(q_bias_grad, torch_q_bias_grad, atol=2e-2, rtol=2e-2) + torch.testing.assert_close(o_weight_grad, torch_o_grad, atol=2e-2, rtol=2e-2) + + def _gather_and_compare_params(self, model, torch_q, torch_o, compare_values=True): + q_proj = model.layers[0].q_proj + o_proj = model.layers[0].o_proj + + original_shards = [] + for _, param in q_proj.named_parameters(recurse=False): + if is_model_parallel_parameter(param): + original_shards.append((param, param.data.detach().clone())) + + for _, param in o_proj.named_parameters(recurse=False): + if is_model_parallel_parameter(param): + original_shards.append((param, param.data.detach().clone())) + + for param, _ in original_shards: + param.gather_params([param]) + + if compare_values: + torch.testing.assert_close(q_proj.weight, torch_q.weight, atol=2e-2, rtol=2e-2) + if q_proj.bias is not None: + torch.testing.assert_close(q_proj.bias, torch_q.bias, atol=2e-2, rtol=2e-2) + torch.testing.assert_close(o_proj.weight, torch_o.weight, atol=2e-2, rtol=2e-2) + if o_proj.bias is not None: + torch.testing.assert_close(o_proj.bias, torch_o.bias, atol=2e-2, rtol=2e-2) + + for param, original in original_shards: + param._tp_partition([param]) + torch.testing.assert_close(param.data, original, atol=2e-2, rtol=2e-2) + + def test_tp_plan_basic_training(self): + skip_on_device() + + model = self.SimpleHFModel() + if preferred_dtype() == torch.float16: + model = model.half() + elif preferred_dtype() == torch.bfloat16: + model = model.bfloat16() + + torch_q, torch_o = self._setup_baseline_linears(model) + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 0 + }, + "steps_per_print": 1, + } + + if preferred_dtype() == torch.float16: + ds_config["fp16"] = {"enabled": True} + elif preferred_dtype() == torch.bfloat16: + ds_config["bf16"] = {"enabled": True} + + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config) + + assert engine.autotp_size() == 2 + + input_tensor = torch.randn(2, 4, 64, dtype=preferred_dtype()).to(get_accelerator().current_device_name()) + dist.broadcast( + input_tensor, + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group(), + ) + if preferred_dtype() == torch.float16: + torch_q = torch_q.half() + torch_o = torch_o.half() + elif preferred_dtype() == torch.bfloat16: + torch_q = torch_q.bfloat16() + torch_o = torch_o.bfloat16() + + self._compare_tp_gradients(model, torch_q, torch_o, input_tensor, engine) + + def test_tp_plan_with_zero1(self): + skip_on_device() + + model = self.SimpleHFModel() + torch_q, torch_o = self._setup_baseline_linears(model) + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 1 + }, + "steps_per_print": 1, + } + + if preferred_dtype() == torch.float16: + ds_config["fp16"] = {"enabled": True} + elif preferred_dtype() == torch.bfloat16: + ds_config["bf16"] = {"enabled": True} + + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config) + + assert engine.autotp_size() == 2 + + for _ in range(1): + input_tensor = torch.randn(2, 4, 64, dtype=preferred_dtype()).to(get_accelerator().current_device_name()) + dist.broadcast( + input_tensor, + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group(), + ) + self._gather_and_compare_params(model, torch_q, torch_o, compare_values=False) + output = engine(input_tensor) + loss = output.mean() + engine.backward(loss) + engine.step() + + for p in engine.parameters(): + assert not torch.isnan(p).any() + + def test_tp_plan_with_zero2(self): + skip_on_device() + + model = self.SimpleHFModel() + torch_q, torch_o = self._setup_baseline_linears(model) + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 2 + }, + "steps_per_print": 1, + } + + if preferred_dtype() == torch.float16: + ds_config["fp16"] = {"enabled": True} + elif preferred_dtype() == torch.bfloat16: + ds_config["bf16"] = {"enabled": True} + + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config) + + assert engine.autotp_size() == 2 + + input_tensor = torch.randn(2, 4, 64, dtype=preferred_dtype()).to(get_accelerator().current_device_name()) + dist.broadcast( + input_tensor, + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group(), + ) + self._gather_and_compare_params(model, torch_q, torch_o, compare_values=False) + output = engine(input_tensor) + loss = output.mean() + engine.backward(loss) + engine.step() diff --git a/tests/unit/model_parallelism/test_tp_plan_real_models.py b/tests/unit/model_parallelism/test_tp_plan_real_models.py new file mode 100644 index 000000000000..7bea8643700e --- /dev/null +++ b/tests/unit/model_parallelism/test_tp_plan_real_models.py @@ -0,0 +1,174 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed.comm as dist +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import groups +from unit.common import DistributedTest + + +def skip_on_device(): + if get_accelerator().device_name() == "xpu": + pytest.skip("XPU requires a higher version for test") + + +class TestTPPlanRealHFModels(DistributedTest): + """End-to-end tests using real HuggingFace models""" + + world_size = 2 + + def test_qwen2_tp_plan_with_zero2(self): + """Test Qwen2 model + tp_plan + ZeRO2""" + skip_on_device() + + try: + from transformers import AutoModelForCausalLM, AutoConfig + except ImportError: + pytest.skip("transformers not installed") + + # Create small Qwen2 config + config = AutoConfig.from_pretrained( + "Qwen/Qwen2-7B", + vocab_size=1000, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + ) + + model = AutoModelForCausalLM.from_config(config) + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 2 + }, + "bf16": { + "enabled": True + }, + "steps_per_print": 1, + } + + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config) + + assert engine.autotp_size() == 2 + + # Train for a few steps + for _ in range(3): + input_ids = torch.randint(0, 1000, (1, 16)).to(get_accelerator().current_device_name()) + dist.broadcast( + input_ids, + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group(), + ) + outputs = engine(input_ids, labels=input_ids) + engine.backward(outputs.loss) + engine.step() + + assert not torch.isnan(outputs.loss) + + def test_custom_model_with_custom_tp_plan(self): + """Test custom model + custom tp_plan""" + skip_on_device() + + class CustomTransformerModel(torch.nn.Module): + + def __init__(self, hidden_size=64): + super().__init__() + self.config = type( + "Config", + (), + { + "base_model_tp_plan": { + "encoder.*.attention.query": "colwise", + "encoder.*.attention.key": "colwise", + "encoder.*.attention.value": "colwise", + "encoder.*.attention.output": "rowwise", + "encoder.*.ffn.intermediate": "colwise", + "encoder.*.ffn.output": "rowwise", + } + }, + )() + + # Simple encoder layers + self.encoder = torch.nn.ModuleList([ + torch.nn.ModuleDict({ + "attention": + torch.nn.ModuleDict({ + "query": torch.nn.Linear(hidden_size, hidden_size), + "key": torch.nn.Linear(hidden_size, hidden_size), + "value": torch.nn.Linear(hidden_size, hidden_size), + "output": torch.nn.Linear(hidden_size, hidden_size), + }), + "ffn": + torch.nn.ModuleDict({ + "intermediate": torch.nn.Linear(hidden_size, hidden_size * 4), + "output": torch.nn.Linear(hidden_size * 4, hidden_size), + }), + }) for _ in range(2) + ]) + + def forward(self, x): + for layer in self.encoder: + # Simplified attention + q = layer.attention.query(x) + k = layer.attention.key(x) + v = layer.attention.value(x) + attn_out = layer.attention.output(q + k + v) + + # FFN + intermediate = torch.relu(layer.ffn.intermediate(attn_out)) + x = layer.ffn.output(intermediate) + return x + + model = CustomTransformerModel(hidden_size=64) + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": 2 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 0 + }, + "bf16": { + "enabled": True + }, + } + + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config) + + assert engine.autotp_size() == 2 + + # Training step + input_tensor = torch.randn(2, 4, 64, dtype=torch.bfloat16).to(get_accelerator().current_device_name()) + dist.broadcast( + input_tensor, + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group(), + ) + output = engine(input_tensor) + loss = output.mean() + engine.backward(loss) + engine.step() diff --git a/tests/unit/modeling.py b/tests/unit/modeling.py index 7930fdafe541..835975697afc 100644 --- a/tests/unit/modeling.py +++ b/tests/unit/modeling.py @@ -29,17 +29,11 @@ import json import logging import math -import os -import shutil -import tarfile -import tempfile from io import open import torch from torch import nn -from torch.nn import CrossEntropyLoss from torch.utils import checkpoint -import deepspeed.comm as dist from torch.nn import Module import torch.nn.functional as F @@ -51,84 +45,6 @@ from deepspeed.accelerator import get_accelerator logger = logging.getLogger(__name__) - -PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", -} -CONFIG_NAME = 'bert_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' -TF_WEIGHTS_NAME = 'model.ckpt' - - -def load_tf_weights_in_bert(model, tf_checkpoint_path): - """ Load tf checkpoints in a pytorch model - """ - try: - import re - import numpy as np - import tensorflow as tf - except ImportError: - print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") - raise - tf_path = os.path.abspath(tf_checkpoint_path) - print("Converting TensorFlow checkpoint from {}".format(tf_path)) - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - print("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m"] for n in name): - print("Skipping {}".format("/".join(name))) - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - l = re.split(r'_(\d+)', m_name) - else: - l = [m_name] - if l[0] == 'kernel' or l[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif l[0] == 'output_bias' or l[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif l[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - else: - pointer = getattr(pointer, l[0]) - if len(l) >= 2: - num = int(l[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - print("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - return model - - """ @torch.jit.script def f_gelu(x): @@ -299,8 +215,7 @@ def __init__(self, if isinstance(vocab_size_or_config_json_file, str): with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value + self.__dict__.update(json_config) elif isinstance(vocab_size_or_config_json_file, int): self.vocab_size = vocab_size_or_config_json_file self.hidden_size = hidden_size @@ -323,8 +238,7 @@ def __init__(self, def from_dict(cls, json_object): """Constructs a `BertConfig` from a Python dictionary of parameters.""" config = BertConfig(vocab_size_or_config_json_file=-1) - for key, value in json_object.items(): - config.__dict__[key] = value + config.__dict__.update(json_object) return config @classmethod @@ -373,38 +287,6 @@ def forward(self, x): return self.weight * x + self.bias -class BertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - - def __init__(self, config): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - class BertSelfAttention(nn.Module): def __init__(self, i, config, weights, biases): @@ -601,16 +483,6 @@ def __init__(self, config, weights, biases): def get_grads(self): return self.grads - # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): - # all_encoder_layers = [] - # for layer_module in self.layer: - # hidden_states = layer_module(hidden_states, attention_mask) - # if output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # if not output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # return all_encoder_layers - def get_modules(self, big_node, input): for mdl in big_node.named_children(): self.graph.append(mdl) @@ -650,875 +522,3 @@ def custom_forward(*inputs): if not output_all_encoded_layers or checkpoint_activations: all_encoder_layers.append((hidden_states)) return all_encoder_layers - - -#class BertEncoder(nn.Module): -# def __init__(self, config): -# super(BertEncoder, self).__init__() -# layer = BertLayer(config) -# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) -# -# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): -# all_encoder_layers = [] -# for layer_module in self.layer: -# hidden_states = layer_module(hidden_states, attention_mask) -# if output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# if not output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# return all_encoder_layers - - -class BertPooler(nn.Module): - - def __init__(self, config): - super(BertPooler, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act="tanh") - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense_act(first_token_tensor) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - - def __init__(self, config): - super(BertPredictionHeadTransform, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act=config.hidden_act) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - - def forward(self, hidden_states): - hidden_states = self.dense_act(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - - def __init__(self, config, bert_model_embedding_weights): - super(BertLMPredictionHead, self).__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), - bert_model_embedding_weights.size(0), - bias=False) - self.decoder.weight = bert_model_embedding_weights - self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - get_accelerator().range_push("decoder input.size() = {}, weight.size() = {}".format( - hidden_states.size(), self.decoder.weight.size())) - hidden_states = self.decoder(hidden_states) + self.bias - get_accelerator().range_pop() - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - - def __init__(self, config, bert_model_embedding_weights): - super(BertOnlyMLMHead, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - - def __init__(self, config): - super(BertOnlyNSPHead, self).__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class BertPreTrainingHeads(nn.Module): - - def __init__(self, config, bert_model_embedding_weights): - super(BertPreTrainingHeads, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPreTrainedModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - def __init__(self, config, *inputs, **kwargs): - super(BertPreTrainedModel, self).__init__() - if not isinstance(config, BertConfig): - raise ValueError("Parameter config in `{}(config)` should be an instance of class `BertConfig`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__)) - self.config = config - - def init_bert_weights(self, module): - """ Initialize the weights. - """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path, - state_dict=None, - cache_dir=None, - from_tf=False, - *inputs, - **kwargs): - """ - Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] - else: - archive_file = pretrained_model_name_or_path - if resolved_archive_file == archive_file: # noqa: F821 - logger.info("loading archive file {}".format(archive_file)) - else: - logger.info("loading archive file {} from cache at {}".format(archive_file, - resolved_archive_file)) # noqa: F821 - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: # noqa: F821 - serialization_dir = resolved_archive_file # noqa: F821 - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, # noqa: F821 - tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: # noqa: F821 - archive.extractall(tempdir) - serialization_dir = tempdir - # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = BertConfig.from_json_file(config_file) - logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(config, *inputs, **kwargs) - if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path, map_location='cpu' if not get_accelerator().is_available() else None) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) - if from_tf: - # Directly load from a TensorFlow checkpoint - weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) - return load_tf_weights_in_bert(model, weights_path) - # Load from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, - error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - - start_prefix = '' - if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): - start_prefix = 'bert.' - load(model, prefix=start_prefix) - if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) - if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, - unexpected_keys)) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - return model - - -class BertModel(BertPreTrainedModel): - """BERT model ("Bidirectional Embedding Representations from a Transformer"). - - Params: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertModel, self).__init__(config) - self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_all_encoded_layers=True, - checkpoint_activations=False): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder(embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - checkpoint_activations=checkpoint_activations) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - -class BertForPreTraining(BertPreTrainedModel): - """BERT model with pre-training heads. - This module comprises the BERT model followed by the two pre-training heads: - - the masked language modeling head, and - - the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `masked_lm_labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `masked_lm_labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForPreTraining(config) - masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, args): - super(BertForPreTraining, self).__init__(config) - self.summary_writer = None - if dist.get_rank() == 0: - self.summary_writer = args.summary_writer - self.samples_per_step = dist.get_world_size() * args.train_batch_size - self.sample_count = self.samples_per_step - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def log_summary_writer(self, logs: dict, base='Train'): - if dist.get_rank() == 0: - module_name = "Samples" #self._batch_module_name.get(batch_type, self._get_batch_type_error(batch_type)) - for key, log in logs.items(): - self.summary_writer.add_scalar(f'{base}/{module_name}/{key}', log, self.sample_count) - self.sample_count += self.samples_per_step - - def forward(self, batch, log=True): - #input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): - input_ids = batch[1] - token_type_ids = batch[3] - attention_mask = batch[2] - masked_lm_labels = batch[5] - next_sentence_label = batch[4] - checkpoint_activations = False - - sequence_output, pooled_output = self.bert(input_ids, - token_type_ids, - attention_mask, - output_all_encoded_layers=False, - checkpoint_activations=checkpoint_activations) - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - #print("loss is {} {}".format(masked_lm_loss, next_sentence_loss)) - total_loss = masked_lm_loss + next_sentence_loss - # if log: - # self.log_summary_writer(logs={'train_loss': total_loss.item()}) - return total_loss - else: - return prediction_scores, seq_relationship_score - - -class BertForMaskedLM(BertPreTrainedModel): - """BERT model with the masked language modeling head. - This module comprises the BERT model followed by the masked language modeling head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - - Outputs: - if `masked_lm_labels` is not `None`: - Outputs the masked language modeling loss. - if `masked_lm_labels` is `None`: - Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForMaskedLM(config) - masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForMaskedLM, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - masked_lm_labels=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - prediction_scores = self.cls(sequence_output) - - if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - return masked_lm_loss - else: - return prediction_scores - - -class BertForNextSentencePrediction(BertPreTrainedModel): - """BERT model with next sentence prediction head. - This module comprises the BERT model followed by the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `next_sentence_label` is not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `next_sentence_label` is `None`: - Outputs the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForNextSentencePrediction(config) - seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForNextSentencePrediction, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyNSPHead(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - next_sentence_label=None, - checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - seq_relationship_score = self.cls(pooled_output) - - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - return next_sentence_loss - else: - return seq_relationship_score - - -class BertForSequenceClassification(BertPreTrainedModel): - """BERT model for classification. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForSequenceClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_labels): - super(BertForSequenceClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForMultipleChoice(BertPreTrainedModel): - """BERT model for multiple choice tasks. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_choices`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` - and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_choices]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) - input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) - token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_choices = 2 - - model = BertForMultipleChoice(config, num_choices) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_choices): - super(BertForMultipleChoice, self).__init__(config) - self.num_choices = num_choices - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - _, pooled_output = self.bert(flat_input_ids, - flat_token_type_ids, - flat_attention_mask, - output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, self.num_choices) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - return loss - else: - return reshaped_logits - - -class BertForTokenClassification(BertPreTrainedModel): - """BERT model for token-level classification. - This module is composed of the BERT model with a linear layer on top of - the full hidden state of the last layer. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForTokenClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_labels): - super(BertForTokenClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForQuestionAnswering(BertPreTrainedModel): - """BERT model for Question Answering (span extraction). - This module is composed of the BERT model with a linear layer on top of - the sequence output that computes start_logits and end_logits - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - - Outputs: - if `start_positions` and `end_positions` are not `None`: - Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. - if `start_positions` or `end_positions` is `None`: - Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end - position tokens of shape [batch_size, sequence_length]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForQuestionAnswering(config) - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForQuestionAnswering, self).__init__(config) - self.bert = BertModel(config) - # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version - # self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - start_positions=None, - end_positions=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - return total_loss - else: - return start_logits, end_logits diff --git a/tests/unit/modelingpreln.py b/tests/unit/modelingpreln.py index 7058c1a744fd..7e7b46a0d4d5 100644 --- a/tests/unit/modelingpreln.py +++ b/tests/unit/modelingpreln.py @@ -29,106 +29,18 @@ import json import logging import math -import os -import shutil -import tarfile -import tempfile from io import open import torch from torch import nn -from torch.nn import CrossEntropyLoss from torch.utils import checkpoint -import deepspeed.comm as dist from torch.nn import Module import torch.nn.functional as F import torch.nn.init as init from deepspeed.accelerator import get_accelerator -#from numba import cuda - -#from deepspeed_cuda import DeepSpeedSoftmaxConfig, DeepSpeedSoftmax - logger = logging.getLogger(__name__) - -PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", -} -CONFIG_NAME = 'bert_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' -TF_WEIGHTS_NAME = 'model.ckpt' - - -def load_tf_weights_in_bert(model, tf_checkpoint_path): - """ Load tf checkpoints in a pytorch model - """ - try: - import re - import numpy as np - import tensorflow as tf - except ImportError: - print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") - raise - tf_path = os.path.abspath(tf_checkpoint_path) - print("Converting TensorFlow checkpoint from {}".format(tf_path)) - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - print("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m"] for n in name): - print("Skipping {}".format("/".join(name))) - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - l = re.split(r'_(\d+)', m_name) - else: - l = [m_name] - if l[0] == 'kernel' or l[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif l[0] == 'output_bias' or l[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif l[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - else: - pointer = getattr(pointer, l[0]) - if len(l) >= 2: - num = int(l[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - print("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - return model - - """ @torch.jit.script def f_gelu(x): @@ -695,16 +607,6 @@ def __init__(self, config, weights, biases): def get_grads(self): return self.grads - # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): - # all_encoder_layers = [] - # for layer_module in self.layer: - # hidden_states = layer_module(hidden_states, attention_mask) - # if output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # if not output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # return all_encoder_layers - def get_modules(self, big_node, input): for mdl in big_node.named_children(): self.graph.append(mdl) @@ -745,875 +647,3 @@ def custom_forward(*inputs): hidden_states = self.FinalLayerNorm(hidden_states) all_encoder_layers.append((hidden_states)) return all_encoder_layers - - -#class BertEncoder(nn.Module): -# def __init__(self, config): -# super(BertEncoder, self).__init__() -# layer = BertLayer(config) -# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) -# -# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): -# all_encoder_layers = [] -# for layer_module in self.layer: -# hidden_states = layer_module(hidden_states, attention_mask) -# if output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# if not output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# return all_encoder_layers - - -class BertPooler(nn.Module): - - def __init__(self, config): - super(BertPooler, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act="tanh") - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense_act(first_token_tensor) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - - def __init__(self, config): - super(BertPredictionHeadTransform, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act=config.hidden_act) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - - def forward(self, hidden_states): - hidden_states = self.dense_act(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - - def __init__(self, config, bert_model_embedding_weights): - super(BertLMPredictionHead, self).__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), - bert_model_embedding_weights.size(0), - bias=False) - self.decoder.weight = bert_model_embedding_weights - self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - get_accelerator().range_push("decoder input.size() = {}, weight.size() = {}".format( - hidden_states.size(), self.decoder.weight.size())) - hidden_states = self.decoder(hidden_states) + self.bias - get_accelerator().range_pop() - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - - def __init__(self, config, bert_model_embedding_weights): - super(BertOnlyMLMHead, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - - def __init__(self, config): - super(BertOnlyNSPHead, self).__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class BertPreTrainingHeads(nn.Module): - - def __init__(self, config, bert_model_embedding_weights): - super(BertPreTrainingHeads, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPreTrainedModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - def __init__(self, config, *inputs, **kwargs): - super(BertPreTrainedModel, self).__init__() - if not isinstance(config, BertConfig): - raise ValueError("Parameter config in `{}(config)` should be an instance of class `BertConfig`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__)) - self.config = config - - def init_bert_weights(self, module): - """ Initialize the weights. - """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path, - state_dict=None, - cache_dir=None, - from_tf=False, - *inputs, - **kwargs): - """ - Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] - else: - archive_file = pretrained_model_name_or_path - if resolved_archive_file == archive_file: # noqa: F821 - logger.info("loading archive file {}".format(archive_file)) - else: - logger.info("loading archive file {} from cache at {}".format(archive_file, - resolved_archive_file)) # noqa: F821 - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: # noqa: F821 - serialization_dir = resolved_archive_file # noqa: F821 - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, # noqa: F821 - tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: # noqa: F821 - archive.extractall(tempdir) - serialization_dir = tempdir - # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = BertConfig.from_json_file(config_file) - logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(config, *inputs, **kwargs) - if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path, map_location='cpu' if not get_accelerator().is_available() else None) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) - if from_tf: - # Directly load from a TensorFlow checkpoint - weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) - return load_tf_weights_in_bert(model, weights_path) - # Load from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, - error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - - start_prefix = '' - if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): - start_prefix = 'bert.' - load(model, prefix=start_prefix) - if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) - if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, - unexpected_keys)) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - return model - - -class BertModel(BertPreTrainedModel): - """BERT model ("Bidirectional Embedding Representations from a Transformer"). - - Params: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertModel, self).__init__(config) - self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_all_encoded_layers=True, - checkpoint_activations=False): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder(embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - checkpoint_activations=checkpoint_activations) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - -class BertForPreTraining(BertPreTrainedModel): - """BERT model with pre-training heads. - This module comprises the BERT model followed by the two pre-training heads: - - the masked language modeling head, and - - the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `masked_lm_labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `masked_lm_labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForPreTraining(config) - masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, args): - super(BertForPreTraining, self).__init__(config) - self.summary_writer = None - if dist.get_rank() == 0: - self.summary_writer = args.summary_writer - self.samples_per_step = dist.get_world_size() * args.train_batch_size - self.sample_count = self.samples_per_step - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def log_summary_writer(self, logs: dict, base='Train'): - if dist.get_rank() == 0: - module_name = "Samples" #self._batch_module_name.get(batch_type, self._get_batch_type_error(batch_type)) - for key, log in logs.items(): - self.summary_writer.add_scalar(f'{base}/{module_name}/{key}', log, self.sample_count) - self.sample_count += self.samples_per_step - - def forward(self, batch, log=True): - #input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): - input_ids = batch[1] - token_type_ids = batch[3] - attention_mask = batch[2] - masked_lm_labels = batch[5] - next_sentence_label = batch[4] - checkpoint_activations = False - - sequence_output, pooled_output = self.bert(input_ids, - token_type_ids, - attention_mask, - output_all_encoded_layers=False, - checkpoint_activations=checkpoint_activations) - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - #print("loss is {} {}".format(masked_lm_loss, next_sentence_loss)) - total_loss = masked_lm_loss + next_sentence_loss - # if log: - # self.log_summary_writer(logs={'train_loss': total_loss.item()}) - return total_loss - else: - return prediction_scores, seq_relationship_score - - -class BertForMaskedLM(BertPreTrainedModel): - """BERT model with the masked language modeling head. - This module comprises the BERT model followed by the masked language modeling head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - - Outputs: - if `masked_lm_labels` is not `None`: - Outputs the masked language modeling loss. - if `masked_lm_labels` is `None`: - Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForMaskedLM(config) - masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForMaskedLM, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - masked_lm_labels=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - prediction_scores = self.cls(sequence_output) - - if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - return masked_lm_loss - else: - return prediction_scores - - -class BertForNextSentencePrediction(BertPreTrainedModel): - """BERT model with next sentence prediction head. - This module comprises the BERT model followed by the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `next_sentence_label` is not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `next_sentence_label` is `None`: - Outputs the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForNextSentencePrediction(config) - seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForNextSentencePrediction, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyNSPHead(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - next_sentence_label=None, - checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - seq_relationship_score = self.cls(pooled_output) - - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - return next_sentence_loss - else: - return seq_relationship_score - - -class BertForSequenceClassification(BertPreTrainedModel): - """BERT model for classification. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForSequenceClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_labels): - super(BertForSequenceClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForMultipleChoice(BertPreTrainedModel): - """BERT model for multiple choice tasks. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_choices`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` - and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_choices]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) - input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) - token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_choices = 2 - - model = BertForMultipleChoice(config, num_choices) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_choices): - super(BertForMultipleChoice, self).__init__(config) - self.num_choices = num_choices - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - _, pooled_output = self.bert(flat_input_ids, - flat_token_type_ids, - flat_attention_mask, - output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, self.num_choices) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - return loss - else: - return reshaped_logits - - -class BertForTokenClassification(BertPreTrainedModel): - """BERT model for token-level classification. - This module is composed of the BERT model with a linear layer on top of - the full hidden state of the last layer. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForTokenClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config, num_labels): - super(BertForTokenClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForQuestionAnswering(BertPreTrainedModel): - """BERT model for Question Answering (span extraction). - This module is composed of the BERT model with a linear layer on top of - the sequence output that computes start_logits and end_logits - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - - Outputs: - if `start_positions` and `end_positions` are not `None`: - Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. - if `start_positions` or `end_positions` is `None`: - Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end - position tokens of shape [batch_size, sequence_length]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForQuestionAnswering(config) - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config): - super(BertForQuestionAnswering, self).__init__(config) - self.bert = BertModel(config) - # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version - # self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - start_positions=None, - end_positions=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - return total_loss - else: - return start_logits, end_logits diff --git a/tests/unit/module_inject/test_tp_plan_converter.py b/tests/unit/module_inject/test_tp_plan_converter.py new file mode 100644 index 000000000000..da787153943f --- /dev/null +++ b/tests/unit/module_inject/test_tp_plan_converter.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.module_inject.tp_plan_converter import TPPlanConverter +from deepspeed.module_inject.autotp_config import PartitionType + + +class TestTPPlanConverter: + + def test_wildcard_to_regex_basic(self): + assert TPPlanConverter._wildcard_to_regex("layers.*.q_proj") == r".*layers\..*\.q_proj" + assert TPPlanConverter._wildcard_to_regex("self_attn.*.weight") == r".*self_attn\..*\.weight" + assert TPPlanConverter._wildcard_to_regex("layers.*.self_attn.q_proj") == r".*layers\..*\.self_attn\.q_proj" + + def test_wildcard_to_regex_special_chars(self): + assert TPPlanConverter._wildcard_to_regex("layers.0.q_proj") == r".*layers\.0\.q_proj" + assert TPPlanConverter._wildcard_to_regex("mlp.gate_proj") == r".*mlp\.gate_proj" + + def test_colwise_rowwise_conversion(self): + hf_plan = {"layers.*.q_proj": "colwise", "layers.*.o_proj": "rowwise"} + specs = TPPlanConverter.convert(hf_plan) + + assert len(specs) == 2 + + q_spec = [s for s in specs if "q_proj" in s.patterns[0]][0] + o_spec = [s for s in specs if "o_proj" in s.patterns[0]][0] + + assert q_spec.partition_type == PartitionType.COLUMN + assert o_spec.partition_type == PartitionType.ROW + + def test_pattern_weight_suffix(self): + hf_plan = {"layers.*.q_proj": "colwise"} + specs = TPPlanConverter.convert(hf_plan) + + assert len(specs) == 1 + assert specs[0].patterns[0].endswith(r"\.weight$") + + def test_pattern_weight_suffix_already_present(self): + hf_plan = {"layers.*.q_proj.weight": "colwise"} + specs = TPPlanConverter.convert(hf_plan) + + assert len(specs) == 1 + assert specs[0].patterns[0].endswith(r"\.weight$") + + def test_empty_plan(self): + hf_plan = {} + specs = TPPlanConverter.convert(hf_plan) + + assert len(specs) == 0 + + def test_multiple_patterns(self): + hf_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + specs = TPPlanConverter.convert(hf_plan) + + assert len(specs) == 7 + + colwise_count = sum(1 for s in specs if s.partition_type == PartitionType.COLUMN) + rowwise_count = sum(1 for s in specs if s.partition_type == PartitionType.ROW) + + assert colwise_count == 5 + assert rowwise_count == 2 + + def test_pattern_matches_param_name(self): + import re + + hf_plan = {"layers.*.self_attn.q_proj": "colwise", "layers.*.mlp.down_proj": "rowwise"} + specs = TPPlanConverter.convert(hf_plan) + + q_pattern = [s for s in specs if "q_proj" in s.patterns[0]][0] + down_pattern = [s for s in specs if "down_proj" in s.patterns[0]][0] + + assert re.match(q_pattern.patterns[0], "model.layers.0.self_attn.q_proj.weight") + assert re.match(q_pattern.patterns[0], "model.layers.10.self_attn.q_proj.weight") + assert not re.match(q_pattern.patterns[0], "model.layers.0.self_attn.k_proj.weight") + + assert re.match(down_pattern.patterns[0], "model.layers.5.mlp.down_proj.weight") + + def test_unsupported_style_returns_none(self): + """Unsupported styles cause convert() to return None for fallback.""" + hf_plan = {"layers.*.q_proj": "colwise_rep", "layers.*.o_proj": "rowwise"} + result = TPPlanConverter.convert(hf_plan) + assert result is None + + def test_alternate_prefixes(self): + """Test tp_plan with non-layers prefix""" + hf_plan = { + "model.layers.*.self_attn.q_proj": "colwise", + "transformer.layers.*.self_attn.o_proj": "rowwise", + } + + layer_specs = TPPlanConverter.convert(hf_plan) + assert len(layer_specs) == 2 + assert any("model\\.layers" in s.patterns[0] for s in layer_specs) + assert any("transformer\\.layers" in s.patterns[0] for s in layer_specs) + + def test_alternate_projection_names(self): + """Test tp_plan with qkv and Wq/Wk/Wv style names""" + hf_plan = { + "layers.*.attn.qkv": "colwise", + "layers.*.attn.out_proj": "rowwise", + "layers.*.attn.Wq": "colwise", + "layers.*.attn.Wk": "colwise", + "layers.*.attn.Wv": "colwise", + } + + layer_specs = TPPlanConverter.convert(hf_plan) + assert len(layer_specs) == 5 + colwise_count = sum(1 for s in layer_specs if s.partition_type == PartitionType.COLUMN) + rowwise_count = sum(1 for s in layer_specs if s.partition_type == PartitionType.ROW) + + assert colwise_count == 4 # qkv + Wq/Wk/Wv + assert rowwise_count == 1 # out_proj diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index 83894b296892..6283007d3e8e 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -6,39 +6,157 @@ import torch import deepspeed import pytest +import gc +import random from unit.common import DistributedTest from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader -from unit.util import required_torch_version +import deepspeed.comm as dist +import deepspeed.moe.sharded_moe as sharded_moe +from deepspeed import get_accelerator +from deepspeed.moe.layer import MoE +from deepspeed.moe.sharded_moe import top1gating, top2gating, topkgating +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param +from deepspeed.utils.torch import required_torch_version + + +@pytest.mark.parametrize("zero_stage", [0, 1, 2]) +class TestSimpleMoE(DistributedTest): + world_size = 2 + + def test(self, zero_stage): + if not required_torch_version(min_version=1.8): + pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage + } + } + # should automatically create moe param groups in deepspeed backend + hidden_dim = 16 + model = SimpleMoEModel(hidden_dim=hidden_dim, ep_size=1) + model, optimizer, _, _ = deepspeed.initialize(config=config_dict, model=model) + data_loader = sequence_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() @pytest.mark.parametrize("ep_size", [2, 4]) +@pytest.mark.parametrize("zero_stage", [0, 1, 2]) @pytest.mark.parametrize("use_residual", [True, False]) class TestMoE(DistributedTest): world_size = 4 - def test(self, ep_size, use_residual): - if not required_torch_version(): + def test(self, ep_size, zero_stage, use_residual): + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") - config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}} + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage + } + } hidden_dim = 16 # E+D -- ep_size = 2 # E only -- ep_size = 4 model = SimpleMoEModel(hidden_dim, ep_size=ep_size, use_residual=use_residual) - optimizer = torch.optim.AdamW(params=model.parameters()) - model, _, _, _ = deepspeed.initialize(config=config_dict, - model=model, - optimizer=optimizer, - dist_init_required=False) + param_group = {'params': [p for p in model.parameters()], 'name': 'random-unique-name'} + params = split_params_into_different_moe_groups_for_optimizer(param_group) + optimizer = torch.optim.AdamW(params=params) + model, optimizer, _, _ = deepspeed.initialize(config=config_dict, + model=model, + optimizer=optimizer, + dist_init_required=False) #dist_init_required=False -- parameterize to True/False? - data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = sequence_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + + def strict_average_tensor(tensor, communication_data_type: torch.dtype): + process_group = optimizer.dp_process_group + curr_size = 0 + pg_offsets = [] + + ipg_bucket = optimizer.ipg_buckets[communication_data_type] + for i, param_idx, param_id in ipg_bucket.params: + param = optimizer.bit16_groups[i][param_idx] + process_group = optimizer.dp_process_group + if ipg_bucket.has_moe_params: + process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param( + param) else optimizer.dp_process_group + partition_ids = optimizer.param_to_partition_ids[i][param_id] + # Get all partition ids + their offsets + partition_offsets = [] + for partition_id in partition_ids: + offset = optimizer.grad_start_offset[i][partition_id][param_id] + partition_offsets.append(offset) + partition_offsets.sort() + # Calculate rank and offsets for grad slices + for idx, offset in enumerate(partition_offsets): + # Calculate numel for grad slice depending on partition location + if idx == len(partition_offsets) - 1: + # Last partition_id uses its own offset + numel = param.numel() - offset + else: + # Set numel to next partition's offset + numel = partition_offsets[idx + 1] - offset + pg_offsets.append((curr_size, process_group)) + curr_size += numel + + def strict_narrow(dim, start, length): + lo, hi = 0, len(pg_offsets) - 1 + while lo < hi: + mi = lo + (hi - lo) // 2 + if pg_offsets[mi][0] >= start: + hi = mi + else: + lo = mi + 1 + curr_slice, reduce_process_group = lo, pg_offsets[lo][1] + while curr_slice < len(pg_offsets) and start + length > pg_offsets[curr_slice][0]: + assert reduce_process_group == pg_offsets[curr_slice][ + 1], "reduce process_group does not match the parameter's process_group" + curr_slice += 1 + return orig_narrow(dim, start, length) # real call + + orig_narrow, tensor.narrow = tensor.narrow, strict_narrow + type(optimizer).average_tensor(optimizer, tensor, communication_data_type) # real call + tensor.narrow = orig_narrow + + if "average_tensor" in dir(optimizer): + optimizer.average_tensor = strict_average_tensor for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + gc.collect() # Must do this or we get a memory leak in this test @pytest.mark.parametrize("ep_size, use_residual", [(2, True), (2, False)]) @@ -46,7 +164,7 @@ class TestPRMoE(DistributedTest): world_size = 4 def test(self, ep_size, use_residual): - if not required_torch_version(): + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}} @@ -61,9 +179,287 @@ def test(self, ep_size, use_residual): optimizer=optimizer, dist_init_required=False) - data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = sequence_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + +class TestTopk(DistributedTest): + world_size = 2 + + def test(self): + device = get_accelerator().current_device_name() + if dist.get_rank() == 0: + logits = torch.rand(2, 2, device=device) + elif dist.get_rank() == 1: + logits = torch.rand(10, 2, device=device) + + output = top1gating(logits=logits, + capacity_factor=1, + min_capacity=0, + used_token=None, + noisy_gate_policy=None, + drop_tokens=False, + use_rts=True, + use_tutel=False) + + +class TestMoESingleton(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("ep_size, expected_calls", [(1, 0), (2, 2)], ids=["single", "multi"]) + def test_all_to_all(self, monkeypatch, ep_size, expected_calls): + if not required_torch_version(min_version=1.8): + pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") + + config_dict = {"train_micro_batch_size_per_gpu": 1, "steps_per_print": 1} + hidden_dim = 8 + expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim)) + model = MoE(hidden_size=hidden_dim, expert=expert, num_experts=2, ep_size=ep_size, k=1, min_capacity=0) + optimizer = torch.optim.AdamW(params=model.parameters()) + model, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + optimizer=optimizer, + dist_init_required=False) + + all_to_all_calls = [] + + def counted_all_to_all(group, input): + all_to_all_calls.append((group, input.shape)) + return input + + monkeypatch.setattr(sharded_moe._AllToAll, "apply", counted_all_to_all) + + x = torch.randn(1, 4, hidden_dim, device=model.device, requires_grad=True) + output, l_aux, _ = model(x) + assert len(all_to_all_calls) == expected_calls + + loss = output.float().sum() + l_aux.float() + model.backward(loss) + assert len(all_to_all_calls) == expected_calls + assert x.grad is not None + assert any(param.grad is not None for param in model.module.parameters()) + + @pytest.mark.parametrize("gate_fn, capacity_args", [(top1gating, (1, 0)), (top2gating, (1, 0)), + (topkgating, (3, 1, 0))], + ids=["top1", "top2", "topk"]) + @pytest.mark.parametrize("ep_world_size, expected_calls", [(1, 0), (2, 1)], ids=["single", "multi"]) + def test_capacity(self, monkeypatch, gate_fn, capacity_args, ep_world_size, expected_calls): + if not required_torch_version(min_version=1.8): + pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") + + ep_group = None + if ep_world_size == 1: + for rank in range(dist.get_world_size()): + group = dist.new_group([rank]) + if rank == dist.get_rank(): + ep_group = group + else: + ep_group = dist.new_group(list(range(dist.get_world_size()))) + + all_reduce_calls = [] + original_all_reduce = sharded_moe.dist.all_reduce + + def counted_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None): + all_reduce_calls.append((tensor, op, group)) + return original_all_reduce(tensor, op=op, group=group) + + monkeypatch.setattr(sharded_moe.dist, "all_reduce", counted_all_reduce) + + device = get_accelerator().current_device_name() + logits = torch.randn(8, 4, device=device) + gate_fn(logits, *capacity_args, drop_tokens=False, ep_group=ep_group) + + assert len(all_reduce_calls) == expected_calls + if all_reduce_calls: + _, op, group = all_reduce_calls[0] + assert op == dist.ReduceOp.MAX + assert group is ep_group + + def test_no_ep_group(self, monkeypatch): + if not required_torch_version(min_version=1.8): + pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") + + def fail_collective(*args, **kwargs): + raise AssertionError("ep_group=None should not enter expert-parallel collective code") + + monkeypatch.setattr(sharded_moe.dist, "get_world_size", fail_collective) + monkeypatch.setattr(sharded_moe.dist, "all_reduce", fail_collective) + + device = get_accelerator().current_device_name() + logits = torch.randn(8, 4, device=device) + top2gating(logits, 1, 0, drop_tokens=False, ep_group=None, top2_2nd_expert_sampling=False) + + +class TestTopkGate(DistributedTest): + + def test(self): + + def check_equal(logits, cap, sparse_truth, res): + m, n = logits.shape + dispatch_mask_truth = torch.zeros(m, n, cap) + i, j, k = sparse_truth.t() + dispatch_mask_truth[i, j, k] = 1 + assert (torch.equal(dispatch_mask_truth, res)) + + #s=4 e=4 topk=2 cap=2(s*topk/e) + logits = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5], + [0.1, 0.11, 0.7, 0.8]]) + logits *= dist.get_rank() + 1 + probs_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='probs')[2] + probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]]) + check_equal(logits, 2, probs_sec_sparse, probs_dispatch_res) + + position_sec_sparse = torch.tensor([[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1], + [3, 2, 1]]) + position_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='position')[2] + check_equal(logits, 2, position_sec_sparse, position_dispatch_res) + + #s=4 e=6 topk=3 cap=2(s*topk/e) + logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034], + [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450], + [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068], + [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]]) + logits2 *= dist.get_rank() + 1 + + #top3 full mask #prob_mask #postion_mask + #0 0 1 0 1 1 #0 0 1 0 1 0 #0 0 1 0 1 1 + #0 1 0 1 0 1 #0 1 0 1 0 1 #0 1 0 1 0 1 + #0 1 1 1 0 0 #0 1 1 1 0 0 #0 1 1 1 0 0 + #1 1 0 0 0 1 #1 0 0 0 0 1 #1 0 0 0 0 0 + probs_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='probs')[2] + probs_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [1, 1, 0], [1, 3, 0], [1, 5, 0], [2, 1, 1], [2, 2, 1], + [2, 3, 1], [3, 0, 0], [3, 5, 1]]) + check_equal(logits2, 2, probs_sec_sparse, probs_dispatch_res) + + position_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1], + [2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]]) + position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2] + check_equal(logits2, 2, position_sec_sparse, position_dispatch_res) + + #s=4 e=4 topk=2 drop_tokens=False + logits3 = torch.tensor([[0.95, 0.85, 0.90, 0.80], [0.70, 0.65, 0.75, 0.60], [0.50, 0.55, 0.45, 0.40], + [0.35, 0.30, 0.25, 0.20]]) + logits3 *= dist.get_rank() + 1 + dispatch_res = topkgating(logits3, 2, 1, min_capacity=1, drop_tokens=False)[2] + sec_sparse = torch.tensor([[0, 0, 0], [0, 2, 0], [1, 0, 1], [1, 2, 1], [2, 0, 2], [2, 1, 0], [3, 0, 3], + [3, 1, 1]]) + check_equal(logits3, 4, sec_sparse, dispatch_res) + + +class TestExpertWeightGradWithZero(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("zero_stage", [0, 1, 2]) + def test(self, zero_stage): + + if not required_torch_version(min_version=1.8): + pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") + + def seed_everything(seed=11): + random.seed(seed) + torch.manual_seed(seed) + get_accelerator().manual_seed(seed) + get_accelerator().manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def get_state_dict_ep2(state_dict): + """ + convert state_dict from EP=1 to EP=2 + """ + rank = int(deepspeed.comm.get_rank()) + ep_state_dict = dict() + dst_sub_key = "deepspeed_moe.experts.deepspeed_experts.0" + src_sub_key = f"deepspeed_moe.experts.deepspeed_experts.{rank}" + for moe_layer in ["moe_1", "moe_2"]: + for mlp_in_moe in [0, 1]: + dst_key = f"{moe_layer}.{dst_sub_key}.{mlp_in_moe}" + src_key = f"{moe_layer}.{src_sub_key}.{mlp_in_moe}" + ep_state_dict[f"{dst_key}.weight"] = state_dict[f"{src_key}.weight"].detach().clone() + ep_state_dict[f"{dst_key}.bias"] = state_dict[f"{src_key}.bias"].detach().clone() + + for key in state_dict.keys(): + if "deepspeed_moe.experts.deepspeed_experts" not in key: + ep_state_dict[key] = state_dict[key].detach().clone() + return ep_state_dict + + def get_models(hidden_dim): + model_ep1 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=1, use_rts=False) + model_ep2 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=2, use_rts=False) + + state_dict_ep1 = model_ep1.state_dict() + state_dict_ep2 = get_state_dict_ep2(state_dict_ep1) + model_ep2.load_state_dict(state_dict_ep2) + + model_ep1, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep1) + model_ep2, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep2) + + return model_ep1, model_ep2 + + def extract_expert_grad(model, expert_id): + + def _get_weight_bias(experts): + return ([deepspeed.utils.safe_get_full_grad(expert[0].weight) + for expert in experts][expert_id].detach().clone(), + [deepspeed.utils.safe_get_full_grad(expert[0].bias) + for expert in experts][expert_id].detach().clone(), + [deepspeed.utils.safe_get_full_grad(expert[1].weight) + for expert in experts][expert_id].detach().clone(), + [deepspeed.utils.safe_get_full_grad(expert[1].bias) + for expert in experts][expert_id].detach().clone()) + + return (*_get_weight_bias(model.moe_1.deepspeed_moe.experts.deepspeed_experts), + *_get_weight_bias(model.moe_2.deepspeed_moe.experts.deepspeed_experts)) + + seed_everything() + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.1, + } + }, + "zero_optimization": { + "stage": zero_stage + } + } + + hidden_dim = 4 + total_samples = 2 + rank = deepspeed.comm.get_rank() + model_ep1, model_ep2 = get_models(hidden_dim) + + data_loader = sequence_dataloader(model=model_ep1, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model_ep1.device, + dtype=torch.float32) + expert_weight_grad_ep1 = [] + expert_weight_grad_ep2 = [] + for batch in data_loader: + loss_ep1 = model_ep1(batch[0], batch[1]) + loss_ep2 = model_ep2(batch[0], batch[1]) + + model_ep1.backward(loss_ep1) + model_ep2.backward(loss_ep2) + + expert_weight_grad_ep1.extend(extract_expert_grad(model_ep1, rank)) + expert_weight_grad_ep2.extend(extract_expert_grad(model_ep2, 0)) + + model_ep1.step() + model_ep2.step() + + assert len(expert_weight_grad_ep1) == len(expert_weight_grad_ep2) + for grad_from_ep1, grad_from_ep2 in zip(expert_weight_grad_ep1, expert_weight_grad_ep2): + assert torch.allclose(grad_from_ep1, grad_from_ep2, atol=0, rtol=1e-4) diff --git a/tests/unit/moe/test_moe_tp.py b/tests/unit/moe/test_moe_tp.py index e53e0dc1effc..eb4668015c01 100644 --- a/tests/unit/moe/test_moe_tp.py +++ b/tests/unit/moe/test_moe_tp.py @@ -7,7 +7,7 @@ import deepspeed import pytest from unit.common import DistributedTest -from unit.util import required_torch_version +from deepspeed.utils.torch import required_torch_version from deepspeed.moe.layer import MoE @@ -58,7 +58,7 @@ class TestMOETensorParallel(DistributedTest): def test(self, ep_size, tp_size, enable_expert_tp, use_residual): # TODO: replace this with a true parallel mlp in the future # and run convergence tests - if not required_torch_version(): + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}} diff --git a/tests/unit/monitor/test_monitor.py b/tests/unit/monitor/test_monitor.py index 3e04bebfb6c1..d4b3cf43921d 100644 --- a/tests/unit/monitor/test_monitor.py +++ b/tests/unit/monitor/test_monitor.py @@ -7,10 +7,14 @@ from deepspeed.monitor.wandb import WandbMonitor from deepspeed.monitor.csv_monitor import csvMonitor from deepspeed.monitor.config import DeepSpeedMonitorConfig +from deepspeed.monitor.comet import CometMonitor from unit.common import DistributedTest +from unittest.mock import Mock, patch from deepspeed.runtime.config import DeepSpeedConfig +import deepspeed.comm as dist + class TestTensorBoard(DistributedTest): world_size = 2 @@ -97,3 +101,66 @@ def test_empty_csv_monitor(self): assert csv_monitor.enabled == defaults.enabled assert csv_monitor.output_path == defaults.output_path assert csv_monitor.job_name == defaults.job_name + + +class TestCometMonitor(DistributedTest): + world_size = 2 + + def test_comet_monitor(self): + import comet_ml + mock_experiment = Mock() + mock_start = Mock(return_value=mock_experiment) + + config_dict = { + "train_batch_size": 2, + "comet": { + "enabled": True, + "samples_log_interval": 42, + "workspace": "some-workspace", + "project": "some-project", + "api_key": "some-api-key", + "experiment_name": "some-experiment-name", + "experiment_key": "some-experiment-key", + "mode": "get_or_create", + "online": True + } + } + + ds_config = DeepSpeedConfig(config_dict) + + with patch.object(comet_ml, "start", mock_start): + comet_monitor = CometMonitor(ds_config.monitor_config.comet) + + assert comet_monitor.enabled is True + assert comet_monitor.samples_log_interval == 42 + + # experiment should be initialized via comet_ml.start only if rank == 0 + if dist.get_rank() == 0: + mock_start.assert_called_once_with( + api_key="some-api-key", + project="some-project", + workspace="some-workspace", + experiment_key="some-experiment-key", + mode="get_or_create", + online=True, + ) + + mock_experiment.set_name.assert_called_once_with("some-experiment-name") + assert comet_monitor.experiment is mock_experiment + else: + mock_start.assert_not_called() + + def test_empty_comet(self): + import comet_ml + mock_start = Mock() + + config_dict = {"train_batch_size": 2, "comet": {}} + ds_config = DeepSpeedConfig(config_dict) + + with patch.object(comet_ml, "start", mock_start): + comet_monitor = CometMonitor(ds_config.monitor_config.comet) + + defaults = DeepSpeedMonitorConfig().comet + assert comet_monitor.enabled == defaults.enabled + assert comet_monitor.samples_log_interval == defaults.samples_log_interval + mock_start.assert_not_called() diff --git a/tests/unit/multi_output_model.py b/tests/unit/multi_output_model.py index e84215fb4e95..d7a5f9a46b97 100644 --- a/tests/unit/multi_output_model.py +++ b/tests/unit/multi_output_model.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from .common import preferred_dtype class MultiOutputModel(torch.nn.Module): @@ -28,8 +29,11 @@ def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, ta batch_size = model.train_micro_batch_size_per_gpu() train_data = [ - torch.full(size=(total_samples, hidden_dim), fill_value=x, device=device, dtype=torch.half, requires_grad=True) - for x in inputs + torch.full(size=(total_samples, hidden_dim), + fill_value=x, + device=device, + dtype=preferred_dtype(), + requires_grad=True) for x in inputs ] train_label = [torch.empty(total_samples, device=device, dtype=torch.long).fill_(y) for y in targets] diff --git a/tests/unit/ops/accelerators/test_accelerator_backward.py b/tests/unit/ops/accelerators/test_accelerator_backward.py index 4c5719bb9c1e..4b1b392e933a 100644 --- a/tests/unit/ops/accelerators/test_accelerator_backward.py +++ b/tests/unit/ops/accelerators/test_accelerator_backward.py @@ -8,17 +8,18 @@ import pytest import random import copy +import os +import deepspeed from torch import nn from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from deepspeed.accelerator import get_accelerator from unit.modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln from unit.modelingpreln import BertEncoder as BertEncoderPreln -from unit.common import DistributedTest +from unit.common import DistributedTest, is_rocm_pytorch +from deepspeed.ops.op_builder import TransformerBuilder -#if not deepspeed.ops.__installed_ops__['transformer']: -#pytest.skip( -# "transformer kernels are temporarily disabled because of unexplained failures", -# allow_module_level=True) +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) def check_equal(first, second, atol=1e-2, verbose=False): @@ -243,9 +244,7 @@ def run_backward(ds_config, seq_len, atol=1e-2, verbose=False): check_equal(base_grads, ds_grads, atol=atol, verbose=verbose) -#test_backward[3-1024-120-16-24-True-True-0.05] -#test_backward[3-1024-52-16-24-False-True-0.2] -# 3-128-54-2-24-False-True-0.2 +# NOTE: Keep these different params as they have helped find divergence in behavior between AMD and NVIDIA. @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', [ (64,160,128,2,24,False,True, 0.2), @@ -253,17 +252,16 @@ def run_backward(ds_config, seq_len, atol=1e-2, verbose=False): (8,1600,128,25,3,True,True, 0.05), (8,160,128,2,3,True,True, 0.1), (8,1600,128,2,3,True,True, 0.05), - #(3,1024,119,16,24,True,False, 0.05), - #(3,1024,115,16,24,True,True, 0.05), - #(1024,128,10,2,2,False,False, 0.1), - #(3,1024,52,16,24,False,True, 0.2), - #(3,128,51,2,24,False,False, 0.1), - #(3,128,54,2,24,False,True, 0.2), ]) # yapf: disable class TestCUDABackward(DistributedTest): world_size = 1 + if is_rocm_pytorch(): + #This is to flush denorms in forward pass. Please refer to https://github.com/pytorch/pytorch/blob/main/docs/source/notes/numerical_accuracy.rst#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + os.environ['ROCBLAS_INTERNAL_FP16_ALT_IMPL'] = '1' - def test_backward(self, batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol): + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[TransformerBuilder.NAME], + reason="TransformerBuilder has not been implemented on this system.") + def test_backward(self, is_preln, use_fp16, batch_size, hidden_size, seq_len, heads, num_layers, atol): # Only run fp16 test cases on devices with FP16 capability. if not get_accelerator().is_fp16_supported() and (use_fp16 is True or is_preln is False): return @@ -282,38 +280,3 @@ def test_backward(self, batch_size, hidden_size, seq_len, heads, num_layers, is_ ds_config.fp16 = use_fp16 run_backward(ds_config, seq_len, atol=atol, verbose=True) - - # [ - # (3,1024,128,16,24,True,False, 0.07), - # (3,1024,128,16,24,True,True, 0.05), - # (3,1024,128,16,24,False,False, 0.1), - # (3,1024,128,16,24,False,True, 0.2), - # ]) # yapf: disable - #def test_backward_stochastic(batch_size, - # hidden_size, - # seq_len, - # heads, - # num_layers, - # is_preln, - # use_fp16, - # atol): - # # Only run fp16 test cases on devices with FP16 capability. - # if not get_accelerator().is_fp16_supported() and use_fp16 is True: - # return - # - # ds_config = DeepSpeedTransformerConfig() - # ds_config.layer_id = None - # ds_config.batch_size = batch_size - # ds_config.hidden_size = hidden_size - # ds_config.intermediate_size = 4 * hidden_size - # ds_config.max_seq_length = seq_len - # ds_config.heads = heads - # ds_config.attn_dropout_ratio = 0.0 - # ds_config.hidden_dropout_ratio = 0.0 - # ds_config.num_hidden_layers = num_layers - # ds_config.pre_layer_norm = is_preln - # ds_config.initializer_range = 0.02 - # ds_config.fp16 = use_fp16 - # ds_config.stochastic_mode = True - # - # run_backward(ds_config, atol=atol) diff --git a/tests/unit/ops/accelerators/test_accelerator_forward.py b/tests/unit/ops/accelerators/test_accelerator_forward.py index 7c5580e4676a..e2f4ac177f1b 100644 --- a/tests/unit/ops/accelerators/test_accelerator_forward.py +++ b/tests/unit/ops/accelerators/test_accelerator_forward.py @@ -8,12 +8,17 @@ import pytest import random import copy +import deepspeed from torch import nn from unit.modelingpreln import BertEncoder as BertEncoderPreln from unit.modeling import BertLayerNorm, BertConfig, BertEncoder as BertEncoderPostln from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest +from deepspeed.ops.op_builder import TransformerBuilder + +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) def check_equal(first, second, atol=1e-2, verbose=False): @@ -224,6 +229,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): ]) # yapf: disable class TestCUDAForward(DistributedTest): world_size = 1 + reuse_dist_env = True def test_forward(self, batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16): # Only run fp16 test cases on devices with FP16 capability. @@ -256,6 +262,8 @@ def test_forward(self, batch_size, hidden_size, seq_len, heads, num_layers, is_p class TestCUDAForwardSmallBatchSize(DistributedTest): world_size = 1 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[TransformerBuilder.NAME], + reason="TransformerBuilder has not been implemented on this system.") def test_forward_with_small_bsz(self, batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16): # Only run fp16 test cases on devices with FP16 capability. diff --git a/tests/unit/ops/adagrad/test_cpu_adagrad.py b/tests/unit/ops/adagrad/test_cpu_adagrad.py index d38d42217872..0c675ecd6a85 100644 --- a/tests/unit/ops/adagrad/test_cpu_adagrad.py +++ b/tests/unit/ops/adagrad/test_cpu_adagrad.py @@ -18,8 +18,8 @@ def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() + x = first.detach().float().numpy() + y = second.detach().float().numpy() if verbose: print("x = {}".format(x.flatten())) print("y = {}".format(y.flatten())) @@ -34,17 +34,7 @@ class TestCPUAdagrad(DistributedTest): init_distributed = False set_dist_env = False - @pytest.mark.parametrize('model_size', - [ - (64), - (22), - (55), - (127), - (1024), - (1048576), - (30000000), - ]) # yapf: disable - def test_cpu_adagrad_opt(self, model_size): + def test_cpu_adagrad_opt(self, model_size=64): device = 'cpu' rng_state = torch.get_rng_state() param = torch.nn.Parameter(torch.randn(model_size, device=device)) @@ -65,14 +55,7 @@ def test_cpu_adagrad_opt(self, model_size): check_equal(param, param1, atol=1e-2, verbose=True) - - @pytest.mark.parametrize('model_size,vocabulary_size,dim', - [ - (16 * 2, 16 * 4, 16), - (16 * 32, 16 * 256, 16), - (16 * 256, 16 * 16384, 16), - ]) # yapf: disable - def test_cpu_adagrad_opt_sparse_embedding(self, model_size, vocabulary_size, dim): + def test_cpu_adagrad_opt_sparse_embedding(self, model_size=32, vocabulary_size=64, dim=16): device = 'cpu' rng_state = torch.get_rng_state() diff --git a/tests/unit/ops/adam/test_adamw.py b/tests/unit/ops/adam/test_adamw.py index 8b6f8101cb77..3b1b088766a5 100644 --- a/tests/unit/ops/adam/test_adamw.py +++ b/tests/unit/ops/adam/test_adamw.py @@ -11,7 +11,10 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam from unit.common import DistributedTest from unit.simple_model import SimpleModel +from deepspeed.accelerator import get_accelerator +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) # yapf: disable #'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer adam_configs = [["AdamW", False, False, False, (FusedAdam, True)], @@ -36,6 +39,7 @@ adam_configs) class TestAdamConfigs(DistributedTest): world_size = 1 + reuse_dist_env = True def test(self, optimizer, diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index a48b7c7f2839..003a6f8f6a46 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -11,7 +11,7 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.adam import FusedAdam -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder from unit.common import DistributedTest if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: @@ -21,8 +21,8 @@ def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() + x = first.detach().float().numpy() + y = second.detach().float().numpy() print("ATOL", atol) if verbose: print("x = {}".format(x.flatten())) @@ -43,7 +43,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) -@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize('model_size', [ (64), @@ -55,13 +55,19 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): ]) # yapf: disable class TestCPUAdam(DistributedTest): world_size = 1 + reuse_dist_env = True requires_cuda_env = False if not get_accelerator().is_available(): init_distributed = False set_dist_env = False @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test_fused_adam_equal(self, dtype, model_size): + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"dtype {dtype} not supported in current accelerator") + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") @@ -88,6 +94,8 @@ def test_fused_adam_equal(self, dtype, model_size): def test_torch_adamw_equal(self, dtype, model_size): if get_accelerator().is_available(): + if dtype == torch.half: + pytest.skip("torch.optim.AdamW with half precision inf/nan output.") if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") ref_param_device = get_accelerator().device_name() @@ -96,20 +104,75 @@ def test_torch_adamw_equal(self, dtype, model_size): pytest.skip("torch.optim.AdamW with half precision only supported in CUDA environments.") ref_param_device = 'cpu' - from deepspeed.ops.adam import DeepSpeedCPUAdam + from deepspeed.ops.adam import DeepSpeedCPUAdam + + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + cpu_param = torch.nn.Parameter(cpu_data) + ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device)) + + cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) + ref_optimizer = torch.optim.AdamW([ref_param]) + + _compare_optimizers(model_size=model_size, + param1=cpu_param, + optimizer1=cpu_optimizer, + param2=ref_param, + optimizer2=ref_optimizer) + + +class TestCPUAdamBf16OptimizerStates(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize('model_size', [64, 1024]) + def test_bf16_optimizer_states_dtype(self, model_size): + """fp32_optimizer_states=False keeps the Adam moments in the bf16 parameter precision.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + param = torch.nn.Parameter(torch.randn(model_size, device='cpu', dtype=torch.bfloat16)) + optimizer = DeepSpeedCPUAdam([param], fp32_optimizer_states=False) + param.grad = torch.randn(model_size, device='cpu', dtype=torch.bfloat16) + optimizer.step() + + state = optimizer.state[param] + assert state['exp_avg'].dtype == torch.bfloat16 + assert state['exp_avg_sq'].dtype == torch.bfloat16 + assert state['exp_avg'].device == torch.device('cpu') + assert state['exp_avg_sq'].device == torch.device('cpu') + + @pytest.mark.parametrize('model_size', [64, 1024]) + def test_bf16_optimizer_states_match_fp32(self, model_size): + """bf16 moments should track fp32 moments within bf16 tolerance over several steps.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + torch.manual_seed(0) + base = torch.randn(model_size, device='cpu', dtype=torch.float32).to(torch.bfloat16) + param_fp32_states = torch.nn.Parameter(base.clone()) + param_bf16_states = torch.nn.Parameter(base.clone()) - cpu_data = torch.randn(model_size, device='cpu').to(dtype) - cpu_param = torch.nn.Parameter(cpu_data) - ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device)) + opt_fp32_states = DeepSpeedCPUAdam([param_fp32_states], fp32_optimizer_states=True) + opt_bf16_states = DeepSpeedCPUAdam([param_bf16_states], fp32_optimizer_states=False) - cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) - ref_optimizer = torch.optim.AdamW([ref_param]) + for _ in range(10): + grad = torch.randn(model_size, device='cpu', dtype=torch.bfloat16) + param_fp32_states.grad = grad.clone() + param_bf16_states.grad = grad.clone() + opt_fp32_states.step() + opt_bf16_states.step() - _compare_optimizers(model_size=model_size, - param1=cpu_param, - optimizer1=cpu_optimizer, - param2=ref_param, - optimizer2=ref_optimizer) + assert opt_fp32_states.state[param_fp32_states]['exp_avg'].dtype == torch.float32 + assert opt_bf16_states.state[param_bf16_states]['exp_avg'].dtype == torch.bfloat16 + + # bf16 moments round every Adam update to an 8-bit mantissa, so over 10 steps they + # diverge from fp32 moments more than the same-precision comparison in _compare_optimizers + # (1e-2). A wider 5% band keeps this stable while still catching gross errors; the dtype + # assertions above guard the precision itself. Norm comparison follows _compare_optimizers. + tolerance = param_fp32_states.float().norm().detach().numpy() * 5e-2 + check_equal(param_fp32_states.float().norm(), param_bf16_states.float().norm(), atol=tolerance) class TestCPUAdamGPUError(DistributedTest): @@ -124,3 +187,239 @@ def test_cpu_adam_gpu_error(self): param.grad = torch.randn(model_size, device=device) with pytest.raises(AssertionError): optimizer.step() + + +class TestCPUAdamSubgroup(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + @pytest.mark.parametrize('model_size', [64, 128, 1024]) + def test_step_subgroup_basic(self, dtype, model_size): + """Test basic functionality of step_subgroup method.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + # Create parameters + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + # Set gradient + param.grad = torch.randn(model_size, device='cpu').to(dtype) + + # Store initial parameter values + initial_param = param.data.clone() + + # Test step_subgroup with subgroup_id=0 + subgroup_id = 0 + optimizer.step_subgroup(subgroup_id) + + # Verify parameter was updated + assert not torch.equal(param.data, initial_param), "Parameters should be updated after step_subgroup" + + # Verify optimizer state was created for subgroup + assert subgroup_id in optimizer.state, "Optimizer state should be created for subgroup" + assert optimizer.state[subgroup_id]['step'] == 1, "Step count should be 1" + assert 'exp_avg' in optimizer.state[subgroup_id], "exp_avg should be in state" + assert 'exp_avg_sq' in optimizer.state[subgroup_id], "exp_avg_sq should be in state" + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + def test_step_subgroup_multiple_calls(self, dtype): + """Test multiple calls to step_subgroup increment step count correctly.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + + # Perform multiple steps + for step in range(1, 4): + param.grad = torch.randn(model_size, device='cpu').to(dtype) + optimizer.step_subgroup(subgroup_id) + + # Verify step count increments + assert optimizer.state[subgroup_id]['step'] == step, f"Step count should be {step}" + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + def test_rollback_subgroup_basic(self, dtype): + """Test basic functionality of rollback_subgroup method.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + param.grad = torch.randn(model_size, device='cpu').to(dtype) + + # First, perform a step to initialize state + optimizer.step_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == 1 + + # Store parameter state after step + param_after_step = param.data.clone() + exp_avg_after_step = optimizer.state[subgroup_id]['exp_avg'].clone() + exp_avg_sq_after_step = optimizer.state[subgroup_id]['exp_avg_sq'].clone() + + # Now rollback + optimizer.rollback_subgroup(subgroup_id) + + # Verify step count decremented + assert optimizer.state[subgroup_id]['step'] == 0, "Step count should be decremented after rollback" + + def test_rollback_subgroup_uninitialized_error(self): + """Test that rollback_subgroup raises error for uninitialized subgroup.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + param = torch.nn.Parameter(torch.randn(model_size, device='cpu')) + optimizer = DeepSpeedCPUAdam([param]) + + # Try to rollback uninitialized subgroup + with pytest.raises(RuntimeError, match="Cannot rollback optimizer state for sub_group_id 0"): + optimizer.rollback_subgroup(0) + + def test_rollback_subgroup_zero_step_error(self): + """Test that rollback_subgroup raises error when step count is already 0.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + param = torch.nn.Parameter(torch.randn(model_size, device='cpu')) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + param.grad = torch.randn(model_size, device='cpu') + + # Initialize state by doing one step + optimizer.step_subgroup(subgroup_id) + + # Rollback once (step should become 0) + optimizer.rollback_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == 0 + + # Try to rollback again - should raise error + with pytest.raises(RuntimeError, match="Cannot rollback sub_group_id 0: step count is 0"): + optimizer.rollback_subgroup(subgroup_id) + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + def test_step_rollback_sequence(self, dtype): + """Test sequence of step_subgroup and rollback_subgroup operations.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + param.grad = torch.randn(model_size, device='cpu').to(dtype) + + # Perform multiple steps + for step in range(1, 4): + optimizer.step_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == step + + # Rollback steps one by one + for step in range(2, -1, -1): + optimizer.rollback_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == step + + def test_multiple_subgroups(self): + """Test that different subgroups maintain independent state.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + param = torch.nn.Parameter(torch.randn(model_size, device='cpu')) + optimizer = DeepSpeedCPUAdam([param]) + + param.grad = torch.randn(model_size, device='cpu') + + # Step different subgroups + optimizer.step_subgroup(0) + optimizer.step_subgroup(1) + optimizer.step_subgroup(0) # Step subgroup 0 again + + # Verify independent step counts + assert optimizer.state[0]['step'] == 2, "Subgroup 0 should have step count 2" + assert optimizer.state[1]['step'] == 1, "Subgroup 1 should have step count 1" + + # Rollback subgroup 0 only + optimizer.rollback_subgroup(0) + assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented" + assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged" + + def test_step_subgroup_same_step_idempotent_across_subgroups(self): + """Repeated same-step subgroup updates should remain bit-identical.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 128 + steps = 4 + base = torch.randn(model_size, device='cpu', dtype=torch.float32) + param_a = torch.nn.Parameter(base.clone()) + param_b = torch.nn.Parameter(base.clone()) + + optimizer = DeepSpeedCPUAdam([param_a]) + for logical_step in range(1, steps + 1): + grad = torch.randn(model_size, device='cpu', dtype=torch.float32) + + optimizer.param_groups[0]['params'] = [param_a] + param_a.grad = grad.clone() + optimizer.step_subgroup(0) + + optimizer.param_groups[0]['params'] = [param_b] + param_b.grad = grad.clone() + optimizer.step_subgroup(1) + + assert optimizer.state[0]['step'] == logical_step + assert optimizer.state[1]['step'] == logical_step + assert torch.equal(param_a.data, param_b.data) + assert torch.equal(optimizer.state[0]['exp_avg'], optimizer.state[1]['exp_avg']) + assert torch.equal(optimizer.state[0]['exp_avg_sq'], optimizer.state[1]['exp_avg_sq']) + + def test_step_same_step_idempotent_across_param_keys(self): + """Repeated optimizer.step() with swapped param keys should be deterministic.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 128 + steps = 4 + base = torch.randn(model_size, device='cpu', dtype=torch.float32) + param_a = torch.nn.Parameter(base.clone()) + param_b = torch.nn.Parameter(base.clone()) + + optimizer = DeepSpeedCPUAdam([param_a]) + for logical_step in range(1, steps + 1): + grad = torch.randn(model_size, device='cpu', dtype=torch.float32) + + optimizer.param_groups[0]['params'] = [param_a] + param_a.grad = grad.clone() + optimizer.step() + + optimizer.param_groups[0]['params'] = [param_b] + param_b.grad = grad.clone() + optimizer.step() + + assert optimizer.state[param_a]['step'] == logical_step + assert optimizer.state[param_b]['step'] == logical_step + assert torch.equal(param_a.data, param_b.data) + assert torch.equal(optimizer.state[param_a]['exp_avg'], optimizer.state[param_b]['exp_avg']) + assert torch.equal(optimizer.state[param_a]['exp_avg_sq'], optimizer.state[param_b]['exp_avg_sq']) diff --git a/tests/unit/ops/adam/test_hybrid_adam.py b/tests/unit/ops/adam/test_hybrid_adam.py new file mode 100644 index 000000000000..652090d5b9d5 --- /dev/null +++ b/tests/unit/ops/adam/test_hybrid_adam.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +import pytest + +from cpuinfo import get_cpu_info + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("hybrid-adam is not compatible", allow_module_level=True) + +pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().float().numpy() + y = second.detach().float().numpy() + print("ATOL", atol) + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) +@pytest.mark.parametrize('model_size', [8, 16]) +class TestHybridAdam(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") + def test_hybrid_adam_equal(self, dtype, model_size): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ref_data = torch.randn(model_size).to(dtype) + total_data = ref_data.clone().detach() + + ref_param = torch.nn.Parameter(ref_data) + ref_optimizer = DeepSpeedCPUAdam([ref_param]) + + cpu_data, cuda_data = total_data.chunk(2) + cpu_param = torch.nn.Parameter(cpu_data) + cuda_param = torch.nn.Parameter(cuda_data.to(get_accelerator().device_name())) + + cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) + cuda_optimizer = FusedAdam([cuda_param]) + + ref_grad = torch.randn(model_size).to(dtype) + cpu_grad, cuda_grad = ref_grad.clone().detach().chunk(2) + + ref_param.grad = ref_grad + cpu_param.grad = cpu_grad + cuda_param.grad = cuda_grad.to(get_accelerator().device_name()) + + ref_optimizer.step() + cpu_optimizer.step() + cuda_optimizer.step() + + cuda_param_copy = cuda_param.cpu() + + total_param = torch.cat((cpu_param, cuda_param_copy)) + + check_equal(ref_param, total_param) diff --git a/tests/unit/ops/adam/test_zf_torch_adam.py b/tests/unit/ops/adam/test_zf_torch_adam.py new file mode 100644 index 000000000000..faa44fe9f853 --- /dev/null +++ b/tests/unit/ops/adam/test_zf_torch_adam.py @@ -0,0 +1,187 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import numpy as np +from torch.nn import Parameter +from deepspeed.ops.adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3 + + +def make_param(Opt, shape, selected_indices=None): + param = Parameter(torch.randn(*shape)) + + if Opt is ZenFlowSelectiveAdamW_stage3: + if param.dim() == 2: + param.ds_shape = (param.shape[1], param.shape[0]) + param.ds_tensor = param.clone().T.contiguous().view(-1) + else: + param.ds_shape = tuple(param.shape) + param.ds_tensor = param.clone() + + param.complete_column_offset = 0 + param.complete_numel = param.numel() + param.group_id = 0 + + if selected_indices is not None: + param.selected_indices = selected_indices + if param.dim() == 2: + param.selected_grad = torch.randn( + param.shape[0], len(selected_indices)) if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn( + len(selected_indices), param.ds_shape[1]) + param.temp_selected_param = param.data[:, selected_indices].clone( + ) if Opt is not ZenFlowSelectiveAdamW_stage3 else param.ds_tensor.view( + param.ds_shape)[selected_indices, :].clone() + else: + param.selected_grad = torch.randn_like(param.data) + param.temp_selected_param = param.data.clone() + return param + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_init_methods(Opt): + opt1 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False) + assert opt1.step == opt1._step_without_offload + opt2 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True) + assert opt2.step == opt2._step_with_offload + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_step_without_offload(Opt): + param = make_param(Opt, (4, 6), torch.tensor([1, 3, 4])) + param.requires_grad_(True) + opt = Opt([param], lr=1e-3, offload=False) + + old_selected = param.data[:, param.selected_indices].clone( + ) if Opt is not ZenFlowSelectiveAdamW_stage3 else param.ds_tensor.view( + param.ds_shape)[param.selected_indices, :].clone() + opt.step() + new_selected = param.data[:, param. + selected_indices] if Opt is not ZenFlowSelectiveAdamW_stage3 else param.ds_tensor.view( + param.ds_shape)[param.selected_indices, :] + diff_norm = (old_selected - new_selected).abs().sum().item() + + assert diff_norm > 1e-5, "param was not updated" + assert param.temp_selected_param is None + assert param.selected_grad is None + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_step_with_offload_bucket_flush(Opt): + param1 = make_param(Opt, (2, 4), torch.tensor([1, 2])) + param2 = make_param(Opt, (2, 4), torch.tensor([0, 3])) + + param1.exp_avg = torch.zeros_like(param1.temp_selected_param) + param1.exp_avg_sq = torch.zeros_like(param1.temp_selected_param) + param1.exp_avg_cpu_data = param1.exp_avg.clone().cpu() + param1.exp_avg_sq_cpu_data = param1.exp_avg_sq.clone().cpu() + + param2.exp_avg = torch.zeros_like(param2.temp_selected_param) + param2.exp_avg_sq = torch.zeros_like(param2.temp_selected_param) + param2.exp_avg_cpu_data = param2.exp_avg.clone().cpu() + param2.exp_avg_sq_cpu_data = param2.exp_avg_sq.clone().cpu() + + opt = Opt([param1, param2], lr=1e-3, offload=True, bucket_size=1) + opt.step() + assert param1.temp_selected_param is None + assert param2.temp_selected_param is None + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_clear_selected_mv(Opt): + param = make_param(Opt, (2, 4), torch.tensor([0, 2])) + opt = Opt([param], lr=1e-3, offload=False) + opt.step() + state = opt.state[param] + assert "exp_avg" in state + opt.clear_selected_mv() + assert state["exp_avg"].abs().sum() == 0 + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_group_step_without_offload(Opt): + param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3])) + opt = Opt([param], lr=1e-3, offload=False) + group_to_paramlist = {0: [param]} if not Opt is ZenFlowSelectiveAdamW_stage3 else [param] + opt.group_step(group_to_paramlist) + assert param.selected_grad is None + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_group_step_with_offload(Opt): + param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3])) + opt = Opt([param], lr=1e-3, offload=True) + + state = opt.state.setdefault(param, {}) + state["step"] = torch.zeros((), dtype=param.dtype, device=param.device) + param.exp_avg = torch.zeros_like(param.data[:, param.selected_indices]) + param.exp_avg_sq = torch.zeros_like(param.data[:, param.selected_indices]) + param.exp_avg_cpu_data = param.exp_avg.clone().cpu() + param.exp_avg_sq_cpu_data = param.exp_avg_sq.clone().cpu() + + group_to_paramlist = {0: [param]} if Opt is not ZenFlowSelectiveAdamW_stage3 else [param] + opt.group_step(group_to_paramlist) + assert param.selected_grad is None + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_1d_param_support(Opt): + param = make_param(Opt, (10, ), torch.arange(10)) + opt = Opt([param], lr=1e-3, offload=False) + opt.step() + assert param.temp_selected_param is None + assert param.selected_grad is None + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_state_increment(Opt): + param = make_param(Opt, (2, 4), torch.arange(4)) + + opt = Opt([param], lr=1e-3, offload=False) + opt.step() + step1 = opt.state[param]['step'].item() + + param.selected_grad = torch.randn(2, 4) if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2) + param.temp_selected_param = param.data.clone() if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2) + param.selected_indices = torch.arange(4) + + opt.step() + step2 = opt.state[param]['step'].item() + assert step2 == step1 + 1 + + +def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4): + torch_param = torch.nn.Parameter(param.detach().clone()) + torch_opt = torch.optim.AdamW([torch_param], lr=zenflow_opt.param_groups[0]['lr']) + + for _ in range(10): + grad = torch.randn_like(param) + param.selected_indices = torch.arange(param.shape[1]) + param.selected_grad = grad if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3) else grad.T + param.temp_selected_param = param.data.clone() if not isinstance( + zenflow_opt, ZenFlowSelectiveAdamW_stage3) else param.ds_tensor.view(param.ds_shape).clone() + + torch_param.grad = grad.clone() + + zenflow_opt.step() + torch_opt.step() + + if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3): + np.testing.assert_allclose(torch_param.data.cpu().numpy(), + param.data.cpu().numpy(), + atol=atol, + err_msg="Mismatch with torch.AdamW") + else: + np.testing.assert_allclose(torch_param.data.cpu().numpy(), + param.ds_tensor.view(param.ds_shape).T.clone().data.cpu().numpy(), + atol=atol, + err_msg="Mismatch with torch.AdamW") + + +@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3]) +def test_against_torch_adamw(Opt): + param = make_param(Opt, (2, 4), torch.arange(4)) + opt = Opt([param], lr=1e-3, offload=False) + _compare_with_torch_adamw(param, opt) diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index a37bcd9c869b..6ceff7c289f4 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -23,12 +23,10 @@ pytest.skip('Skip tests since async-io is not compatible', allow_module_level=True) -def _skip_for_invalid_environment(use_cuda_device=True, use_cuda_pinned_tensor=True): - if not get_accelerator().is_available(): - if use_cuda_device: - pytest.skip("GPU tensors only supported in CUDA environments.") +def _skip_for_invalid_environment(use_cuda_pinned_tensor=True): + if get_accelerator().device_name() != 'cuda': if use_cuda_pinned_tensor: - pytest.skip("CUDA-pinned tensors only supported in CUDA environments.") + pytest.skip("torch.pin_memory is only supported in CUDA environments.") def _get_local_rank(): @@ -37,10 +35,14 @@ def _get_local_rank(): return 0 -def _do_ref_write(tmpdir, index=0): +def _get_file_path(tmpdir, file_prefix, index=0): file_suffix = f'{_get_local_rank()}_{index}' - ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') - ref_buffer = os.urandom(IO_SIZE) + return os.path.join(tmpdir, f'{file_prefix}_{file_suffix}.pt') + + +def _do_ref_write(tmpdir, index=0, num_bytes=IO_SIZE): + ref_file = _get_file_path(tmpdir, '_py_random', index) + ref_buffer = os.urandom(num_bytes) with open(ref_file, 'wb') as f: f.write(ref_buffer) @@ -48,17 +50,16 @@ def _do_ref_write(tmpdir, index=0): def _get_test_write_file(tmpdir, index): - file_suffix = f'{_get_local_rank()}_{index}' - return os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') + return _get_file_path(tmpdir, '_aio_write_random', index) -def _get_test_write_file_and_cuda_buffer(tmpdir, ref_buffer, index=0): +def _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffer, index=0): test_file = _get_test_write_file(tmpdir, index) test_buffer = get_accelerator().ByteTensor(list(ref_buffer)) return test_file, test_buffer -def _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, aio_handle=None, index=0): +def _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer, aio_handle=None, index=0): test_file = _get_test_write_file(tmpdir, index) if aio_handle is None: test_buffer = get_accelerator().pin_memory(torch.ByteTensor(list(ref_buffer))) @@ -73,7 +74,7 @@ def _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, aio_handle=None, ind def _validate_handle_state(handle, single_submit, overlap_events): assert handle.get_single_submit() == single_submit assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_intra_op_parallelism() == IO_PARALLEL assert handle.get_block_size() == BLOCK_SIZE assert handle.get_queue_depth() == QUEUE_DEPTH @@ -83,17 +84,21 @@ def _validate_handle_state(handle, single_submit, overlap_events): @pytest.mark.parametrize("overlap_events", [True, False]) class TestRead(DistributedTest): world_size = 1 + reuse_dist_env = True requires_cuda_env = False if not get_accelerator().is_available(): init_distributed = False set_dist_env = False - def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events): - _skip_for_invalid_environment(use_cuda_device=False, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) - if use_cuda_pinned_tensor: + if use_unpinned_tensor: + aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + elif use_cuda_pinned_tensor: aio_buffer = get_accelerator().pin_memory(torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu')) else: aio_buffer = h.new_cpu_locked_tensor(IO_SIZE, torch.empty(0, dtype=torch.uint8)) @@ -101,7 +106,7 @@ def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, over _validate_handle_state(h, single_submit, overlap_events) ref_file, _ = _do_ref_write(tmpdir) - read_status = h.sync_pread(aio_buffer, ref_file) + read_status = h.sync_pread(aio_buffer, ref_file, 0) assert read_status == 1 with open(ref_file, 'rb') as f: @@ -111,14 +116,14 @@ def test_parallel_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, over if not use_cuda_pinned_tensor: h.free_cpu_locked_tensor(aio_buffer) - @pytest.mark.parametrize("cuda_device", [True, False]) - def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) use_cpu_locked_tensor = False h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) - if cuda_device: + if use_unpinned_tensor: aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) elif use_cuda_pinned_tensor: aio_buffer = get_accelerator().pin_memory(torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu')) @@ -129,7 +134,7 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap _validate_handle_state(h, single_submit, overlap_events) ref_file, _ = _do_ref_write(tmpdir) - read_status = h.async_pread(aio_buffer, ref_file) + read_status = h.async_pread(aio_buffer, ref_file, 0) assert read_status == 0 wait_status = h.wait() @@ -148,25 +153,29 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap @pytest.mark.parametrize("overlap_events", [True, False]) class TestWrite(DistributedTest): world_size = 1 + reuse_dist_env = True requires_cuda_env = False if not get_accelerator().is_available(): init_distributed = False set_dist_env = False - def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events): - _skip_for_invalid_environment(use_cuda_device=False, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_file, ref_buffer = _do_ref_write(tmpdir) h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + if use_unpinned_tensor: + aio_file, aio_buffer = _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffer) if use_cuda_pinned_tensor: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer) else: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, h) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer, h) _validate_handle_state(h, single_submit, overlap_events) - write_status = h.sync_pwrite(aio_buffer, aio_file) + write_status = h.sync_pwrite(aio_buffer, aio_file, 0) assert write_status == 1 if not use_cuda_pinned_tensor: @@ -177,25 +186,25 @@ def test_parallel_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, ove filecmp.clear_cache() assert filecmp.cmp(ref_file, aio_file, shallow=False) - @pytest.mark.parametrize("cuda_device", [True, False]) - def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + @pytest.mark.parametrize("use_unpinned_tensor", [True, False]) + def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap_events, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_file, ref_buffer = _do_ref_write(tmpdir) h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) use_cpu_locked_tensor = False - if cuda_device: - aio_file, aio_buffer = _get_test_write_file_and_cuda_buffer(tmpdir, ref_buffer) + if use_unpinned_tensor: + aio_file, aio_buffer = _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffer) elif use_cuda_pinned_tensor: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer) else: - aio_file, aio_buffer = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffer, h) + aio_file, aio_buffer = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffer, h) use_cpu_locked_tensor = True _validate_handle_state(h, single_submit, overlap_events) - write_status = h.async_pwrite(aio_buffer, aio_file) + write_status = h.async_pwrite(aio_buffer, aio_file, 0) assert write_status == 0 wait_status = h.wait() @@ -212,7 +221,7 @@ def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overla @pytest.mark.sequential @pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) -@pytest.mark.parametrize("cuda_device", [True, False]) +@pytest.mark.parametrize("use_unpinned_tensor", [True, False]) class TestAsyncQueue(DistributedTest): world_size = 1 requires_cuda_env = False @@ -221,8 +230,8 @@ class TestAsyncQueue(DistributedTest): set_dist_env = False @pytest.mark.parametrize("async_queue", [2, 3]) - def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_files = [] for i in range(async_queue): @@ -234,7 +243,7 @@ def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) use_cpu_locked_tensor = False - if cuda_device: + if use_unpinned_tensor: aio_buffers = [ torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) for _ in range(async_queue) @@ -252,7 +261,7 @@ def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): _validate_handle_state(h, single_submit, overlap_events) for i in range(async_queue): - read_status = h.async_pread(aio_buffers[i], ref_files[i]) + read_status = h.async_pread(aio_buffers[i], ref_files[i], 0) assert read_status == 0 wait_status = h.wait() @@ -268,8 +277,8 @@ def test_read(self, tmpdir, async_queue, use_cuda_pinned_tensor, cuda_device): h.free_cpu_locked_tensor(t) @pytest.mark.parametrize("async_queue", [2, 3]) - def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, cuda_device): - _skip_for_invalid_environment(use_cuda_device=cuda_device, use_cuda_pinned_tensor=use_cuda_pinned_tensor) + def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, use_unpinned_tensor): + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) ref_files = [] ref_buffers = [] @@ -285,21 +294,21 @@ def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, cuda_device): aio_files = [] aio_buffers = [] for i in range(async_queue): - if cuda_device: - f, buf = _get_test_write_file_and_cuda_buffer(tmpdir, ref_buffers[i], i) + if use_unpinned_tensor: + f, buf = _get_test_write_file_and_unpinned_tensor(tmpdir, ref_buffers[i], i) elif use_cuda_pinned_tensor: - f, buf = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffers[i], None, i) + f, buf = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffers[i], None, i) else: - f, buf = _get_test_write_file_and_cpu_buffer(tmpdir, ref_buffers[i], h, i) + f, buf = _get_test_write_file_and_pinned_tensor(tmpdir, ref_buffers[i], h, i) aio_files.append(f) aio_buffers.append(buf) - use_cpu_locked_tensor = not (cuda_device or use_cuda_pinned_tensor) + use_cpu_locked_tensor = not (use_unpinned_tensor or use_cuda_pinned_tensor) _validate_handle_state(h, single_submit, overlap_events) for i in range(async_queue): - read_status = h.async_pwrite(aio_buffers[i], aio_files[i]) + read_status = h.async_pwrite(aio_buffers[i], aio_files[i], 0) assert read_status == 0 wait_status = h.wait() @@ -314,3 +323,91 @@ def test_write(self, tmpdir, use_cuda_pinned_tensor, async_queue, cuda_device): filecmp.clear_cache() assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False) + + +@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False]) +@pytest.mark.parametrize('file_partitions', [[1, 1, 1], [1, 1, 2], [1, 2, 1], [2, 1, 1]]) +class TestAsyncFileOffset(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('use_fd', [False, True]) + def test_offset_write(self, tmpdir, file_partitions, use_cuda_pinned_tensor, use_fd): + + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) + ref_file = _get_file_path(tmpdir, '_py_random') + aio_file = _get_file_path(tmpdir, '_aio_random') + partition_unit_size = BLOCK_SIZE + file_size = sum(file_partitions) * partition_unit_size + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + if use_cuda_pinned_tensor: + data_buffer = torch.ByteTensor(list(os.urandom(file_size))).pin_memory() + else: + data_buffer = h.new_cpu_locked_tensor(file_size, torch.empty(0, dtype=torch.uint8)) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + ref_fd = open(ref_file, 'wb') + for i in range(len(file_partitions)): + src_buffer = torch.narrow(data_buffer, 0, file_offsets[i], file_partitions[i] * partition_unit_size) + + ref_fd.write(src_buffer.numpy().tobytes()) + ref_fd.flush() + + if use_fd: + aio_fd = os.open(aio_file, flags=os.O_DIRECT | os.O_CREAT | os.O_WRONLY) + write_status = h.async_pwrite(buffer=src_buffer, fd=aio_fd, file_offset=file_offsets[i]) + else: + write_status = h.async_pwrite(buffer=src_buffer, filename=aio_file, file_offset=file_offsets[i]) + assert write_status == 0 + wait_status = h.wait() + assert wait_status == 1 + + if use_fd: + os.path.isfile(aio_fd) + os.close(aio_fd) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + ref_fd.close() + + if not use_cuda_pinned_tensor: + h.free_cpu_locked_tensor(data_buffer) + + def test_offset_read(self, tmpdir, file_partitions, use_cuda_pinned_tensor): + + _skip_for_invalid_environment(use_cuda_pinned_tensor=use_cuda_pinned_tensor) + partition_unit_size = BLOCK_SIZE + file_size = sum(file_partitions) * partition_unit_size + ref_file, _ = _do_ref_write(tmpdir, 0, file_size) + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + if use_cuda_pinned_tensor: + data_buffer = torch.zeros(file_size, dtype=torch.uint8, device='cpu').pin_memory() + else: + data_buffer = h.new_cpu_locked_tensor(file_size, torch.empty(0, dtype=torch.uint8)) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + with open(ref_file, 'rb') as ref_fd: + for i in range(len(file_partitions)): + ref_fd.seek(file_offsets[i]) + bytes_to_read = file_partitions[i] * partition_unit_size + ref_buf = list(ref_fd.read(bytes_to_read)) + + dst_tensor = torch.narrow(data_buffer, 0, 0, bytes_to_read) + assert 1 == h.sync_pread(dst_tensor, ref_file, file_offsets[i]) + assert dst_tensor.tolist() == ref_buf + + if not use_cuda_pinned_tensor: + h.free_cpu_locked_tensor(data_buffer) diff --git a/tests/unit/ops/aio/test_fast_file_writer_fd_close.py b/tests/unit/ops/aio/test_fast_file_writer_fd_close.py new file mode 100644 index 000000000000..7b63b8976ebb --- /dev/null +++ b/tests/unit/ops/aio/test_fast_file_writer_fd_close.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression test for FastFileWriter file-descriptor cleanup. + +Without proper os.close() in _fini(), every save through FastFileWriter +leaks one fd pointing at the just-written file. When the user later +unlinks the file (e.g. checkpoint rotation), the leaked fd holds the +inode in the filesystem's orphan list, so blocks are never freed and +the filesystem eventually reports ENOSPC even though `ls` shows only +N files on disk. + +These tests assert that after FastFileWriter.close() returns, no fd +in /proc/self/fd points at the (possibly already-unlinked) file. +""" +import os +import sys +import pytest +import torch + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import AsyncIOBuilder +from deepspeed.io import FastFileWriter, FastFileWriterConfig + +if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip("async_io op is not compatible on this system", allow_module_level=True) + +if not sys.platform.startswith("linux"): + pytest.skip("test uses /proc/self/fd which is Linux-only", allow_module_level=True) + +if get_accelerator().device_name() != 'cuda': + pytest.skip("FastFileWriter requires CUDA-pinned tensors", allow_module_level=True) + +BLOCK_SIZE = 1 * 1024 * 1024 +QUEUE_DEPTH = 8 +PINNED_BYTES = 8 * 1024 * 1024 +PAYLOAD_BYTES = 1 * 1024 * 1024 + + +def _count_deleted_fds(target_dir): + """How many fds in /proc/self/fd point at a now-deleted file located + under target_dir? Restricting to target_dir avoids false positives + from unrelated deleted fds in the test process.""" + pid = os.getpid() + n = 0 + for entry in os.listdir(f"/proc/{pid}/fd"): + try: + target = os.readlink(f"/proc/{pid}/fd/{entry}") + except OSError: + continue + if target.startswith(str(target_dir)) and target.endswith("(deleted)"): + n += 1 + return n + + +def _build_writer(file_path): + aio = AsyncIOBuilder().load(verbose=False).aio_handle(block_size=BLOCK_SIZE, + queue_depth=QUEUE_DEPTH, + single_submit=False, + overlap_events=False, + intra_op_parallelism=1) + pinned = torch.zeros(PINNED_BYTES, dtype=torch.uint8).pin_memory() + cfg = FastFileWriterConfig(dnvme_handle=aio, + pinned_tensor=pinned, + double_buffer=True, + num_parallel_writers=1, + writer_rank=0) + return FastFileWriter(file_path=str(file_path), config=cfg) + + +@pytest.mark.sequential +def test_close_releases_fd_after_unlink(tmp_path): + """Single save + unlink must not leave a deleted-fd reference.""" + target = tmp_path / "ckpt_single.pt" + buf = torch.zeros(PAYLOAD_BYTES, dtype=torch.uint8) + + before = _count_deleted_fds(tmp_path) + w = _build_writer(target) + torch.save(obj=buf, f=w) + w.close() + os.unlink(target) + after = _count_deleted_fds(tmp_path) + + assert after == before, (f"FastFileWriter leaked an fd: deleted-fd count went " + f"from {before} to {after} after a single save+close+unlink. " + f"This indicates _fini() did not os.close(self._aio_fd).") + + +@pytest.mark.sequential +@pytest.mark.parametrize("n_iters", [5, 20]) +def test_rotation_loop_does_not_leak(tmp_path, n_iters): + """N iterations of save+close+unlink should leave zero deleted-fds. + Mirrors the real checkpoint-rotation workload that originally + surfaced this bug.""" + buf = torch.zeros(PAYLOAD_BYTES, dtype=torch.uint8) + before = _count_deleted_fds(tmp_path) + + for i in range(n_iters): + path = tmp_path / f"ckpt_{i}.pt" + w = _build_writer(path) + torch.save(obj=buf, f=w) + w.close() + os.unlink(path) + + after = _count_deleted_fds(tmp_path) + assert after == before, (f"FastFileWriter leaked {after - before} fd(s) over {n_iters} " + f"save+close+unlink iterations (expected 0).") diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py new file mode 100644 index 000000000000..cebe9a635cdf --- /dev/null +++ b/tests/unit/ops/aio/test_gds.py @@ -0,0 +1,357 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import os +import filecmp +import torch +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import GDSBuilder +from unit.common import DistributedTest + +KILO_BYTE = 1024 * 256 +BLOCK_SIZE = KILO_BYTE +QUEUE_DEPTH = 2 +IO_SIZE = 4 * BLOCK_SIZE +IO_PARALLEL = 2 + +if not deepspeed.ops.__compatible_ops__[GDSBuilder.NAME]: + pytest.skip('Skip tests since gds is not compatible', allow_module_level=True) + + +def _get_local_rank(): + if get_accelerator().is_available(): + return dist.get_rank() + return 0 + + +def _do_ref_write(tmpdir, index=0, file_size=IO_SIZE): + file_suffix = f'{_get_local_rank()}_{index}' + ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') + ref_buffer = os.urandom(file_size) + with open(ref_file, 'wb') as f: + f.write(ref_buffer) + + return ref_file, ref_buffer + + +def _get_file_path(tmpdir, file_prefix, index=0): + file_suffix = f'{_get_local_rank()}_{index}' + return os.path.join(tmpdir, f'{file_prefix}_{file_suffix}.pt') + + +def _get_test_write_file(tmpdir, index): + file_suffix = f'{_get_local_rank()}_{index}' + return os.path.join(tmpdir, f'_gds_write_random_{file_suffix}.pt') + + +def _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, gds_handle, index=0): + test_file = _get_test_write_file(tmpdir, index) + test_buffer = get_accelerator().ByteTensor(list(ref_buffer)) + gds_handle.pin_device_tensor(test_buffer) + return test_file, test_buffer + + +def _validate_handle_state(handle, single_submit, overlap_events): + assert handle.get_single_submit() == single_submit + assert handle.get_overlap_events() == overlap_events + assert handle.get_intra_op_parallelism() == IO_PARALLEL + assert handle.get_block_size() == BLOCK_SIZE + assert handle.get_queue_depth() == QUEUE_DEPTH + + +@pytest.mark.parametrize("single_submit", [True, False]) +@pytest.mark.parametrize("overlap_events", [True, False]) +class TestRead(DistributedTest): + world_size = 1 + reuse_dist_env = True + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + def test_parallel_read(self, tmpdir, single_submit, overlap_events): + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + _validate_handle_state(h, single_submit, overlap_events) + + ref_file, _ = _do_ref_write(tmpdir) + read_status = h.sync_pread(gds_buffer, ref_file, 0) + assert read_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == gds_buffer.tolist() + + h.unpin_device_tensor(gds_buffer) + + def test_async_read(self, tmpdir, single_submit, overlap_events): + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + _validate_handle_state(h, single_submit, overlap_events) + + ref_file, _ = _do_ref_write(tmpdir) + read_status = h.async_pread(gds_buffer, ref_file, 0) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == gds_buffer.tolist() + + h.unpin_device_tensor(gds_buffer) + + +@pytest.mark.parametrize("single_submit", [True, False]) +@pytest.mark.parametrize("overlap_events", [True, False]) +class TestWrite(DistributedTest): + world_size = 1 + reuse_dist_env = True + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + def test_parallel_write(self, tmpdir, single_submit, overlap_events): + + ref_file, ref_buffer = _do_ref_write(tmpdir) + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_file, gds_buffer = _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, h) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.sync_pwrite(gds_buffer, gds_file, 0) + assert write_status == 1 + + h.unpin_device_tensor(gds_buffer) + + assert os.path.isfile(gds_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, gds_file, shallow=False) + + def test_async_write(self, tmpdir, single_submit, overlap_events): + ref_file, ref_buffer = _do_ref_write(tmpdir) + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + gds_file, gds_buffer = _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, h) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.async_pwrite(gds_buffer, gds_file, 0) + assert write_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + h.unpin_device_tensor(gds_buffer) + + assert os.path.isfile(gds_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, gds_file, shallow=False) + + +@pytest.mark.sequential +class TestAsyncQueue(DistributedTest): + world_size = 1 + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize("async_queue", [2, 3]) + def test_read(self, tmpdir, async_queue): + + ref_files = [] + for i in range(async_queue): + f, _ = _do_ref_write(tmpdir, i) + ref_files.append(f) + + single_submit = True + overlap_events = True + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_buffers = [ + torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) for _ in range(async_queue) + ] + for buf in gds_buffers: + h.pin_device_tensor(buf) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pread(gds_buffers[i], ref_files[i], 0) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for i in range(async_queue): + with open(ref_files[i], 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == gds_buffers[i].tolist() + + for t in gds_buffers: + h.unpin_device_tensor(t) + + @pytest.mark.parametrize("async_queue", [2, 3]) + def test_write(self, tmpdir, async_queue): + ref_files = [] + ref_buffers = [] + for i in range(async_queue): + f, buf = _do_ref_write(tmpdir, i) + ref_files.append(f) + ref_buffers.append(buf) + + single_submit = True + overlap_events = True + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL) + + gds_files = [] + gds_buffers = [] + for i in range(async_queue): + f, buf = _get_test_write_file_and_device_buffer(tmpdir, ref_buffers[i], h, i) + gds_files.append(f) + gds_buffers.append(buf) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pwrite(gds_buffers[i], gds_files[i], 0) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for t in gds_buffers: + h.unpin_device_tensor(t) + + for i in range(async_queue): + assert os.path.isfile(gds_files[i]) + + filecmp.clear_cache() + assert filecmp.cmp(ref_files[i], gds_files[i], shallow=False) + + +@pytest.mark.parametrize("use_new_api", [True, False]) +class TestLockDeviceTensor(DistributedTest): + world_size = 2 + reuse_dist_env = True + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + def test_pin_device_tensor(self, use_new_api): + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + unpinned_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) + if use_new_api: + pinned_buffer = h.new_pinned_device_tensor(unpinned_buffer.numel(), unpinned_buffer) + else: + pinned_buffer = torch.empty_like(unpinned_buffer) + h.pin_device_tensor(pinned_buffer) + + assert unpinned_buffer.device == pinned_buffer.device + assert unpinned_buffer.dtype == pinned_buffer.dtype + assert unpinned_buffer.numel() == pinned_buffer.numel() + + if use_new_api: + h.free_pinned_device_tensor(pinned_buffer) + else: + h.unpin_device_tensor(pinned_buffer) + + +@pytest.mark.parametrize('file_partitions', [[1, 1, 1], [1, 1, 2], [1, 2, 1], [2, 1, 1]]) +class TestAsyncFileOffset(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('use_fd', [False, True]) + def test_offset_write(self, tmpdir, use_fd, file_partitions): + ref_file = _get_file_path(tmpdir, '_py_random') + aio_file = _get_file_path(tmpdir, '_aio_random') + partition_unit_size = IO_SIZE + file_size = sum(file_partitions) * partition_unit_size + + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + gds_buffer = torch.empty(file_size, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + ref_fd = open(ref_file, 'wb') + for i in range(len(file_partitions)): + src_buffer = torch.narrow(gds_buffer, 0, file_offsets[i], + file_partitions[i] * partition_unit_size).to(device='cpu') + + ref_fd.write(src_buffer.numpy().tobytes()) + ref_fd.flush() + + if use_fd: + aio_fd = os.open(aio_file, flags=os.O_DIRECT | os.O_CREAT | os.O_WRONLY) + write_status = h.async_pwrite(buffer=src_buffer, fd=aio_fd, file_offset=file_offsets[i]) + else: + write_status = h.async_pwrite(buffer=src_buffer, filename=aio_file, file_offset=file_offsets[i]) + + assert write_status == 0 + wait_status = h.wait() + assert wait_status == 1 + + if use_fd: + assert os.path.isfile(aio_fd) + os.close(aio_fd) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + ref_fd.close() + + h.unpin_device_tensor(gds_buffer) + + def test_offset_read(self, tmpdir, file_partitions): + partition_unit_size = BLOCK_SIZE + file_size = sum(file_partitions) * partition_unit_size + ref_file, _ = _do_ref_write(tmpdir, 0, file_size) + h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL) + + gds_buffer = torch.empty(file_size, dtype=torch.uint8, device=get_accelerator().device_name()) + h.pin_device_tensor(gds_buffer) + + file_offsets = [] + next_offset = 0 + for i in range(len(file_partitions)): + file_offsets.append(next_offset) + next_offset += file_partitions[i] * partition_unit_size + + with open(ref_file, 'rb') as ref_fd: + for i in range(len(file_partitions)): + ref_fd.seek(file_offsets[i]) + bytes_to_read = file_partitions[i] * partition_unit_size + ref_buf = list(ref_fd.read(bytes_to_read)) + + dst_tensor = torch.narrow(gds_buffer, 0, 0, bytes_to_read) + read_status = h.async_pread(dst_tensor, ref_file, file_offsets[i]) + assert read_status == 0 + wait_status = h.wait() + assert wait_status == 1 + assert dst_tensor.tolist() == ref_buf + + h.unpin_device_tensor(gds_buffer) diff --git a/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py b/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py new file mode 100644 index 000000000000..307f59157220 --- /dev/null +++ b/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch +from torch.nn import functional as F +import deepspeed +from deepspeed.ops.op_builder import EvoformerAttnBuilder +from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +from deepspeed.accelerator import get_accelerator +from unit.util import skip_on_arch + +if not deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]: + pytest.skip("DS4Sci_EvoformerAttention ops are not available on this system", allow_module_level=True) + + +def attention_reference( + q_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + k_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + v_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + biases: List[torch.Tensor], + sm_scale: float) -> torch.Tensor: + q = q_input.transpose(-2, -3) + k = k_input.transpose(-2, -3) + v = v_input.transpose(-2, -3) + k_t = k.transpose(-1, -2) + a = torch.matmul(q, k_t) * sm_scale + + for b in biases: + a += b + + a = F.softmax(a, dim=-1) + a_v = torch.matmul(a, v) + o = a_v.transpose(-2, -3) + + return o + + +@pytest.mark.sequential +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)]) +def test_DS4Sci_EvoformerAttention(dtype, tensor_shape): + skip_on_arch(8 if dtype == torch.bfloat16 else 7) + batch, n, seq_len, heads, dim = tensor_shape + Q = torch.randn(batch, + n, + seq_len, + heads, + dim, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + K = torch.randn(batch, + n, + seq_len, + heads, + dim, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + V = torch.randn(batch, + n, + seq_len, + heads, + dim, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + mask = torch.randint(0, 2, (batch, n, 1, 1, seq_len), dtype=dtype, device=get_accelerator().device_name()) + mask_bias = 1e9 * (mask - 1) + bias = torch.randn(batch, + 1, + heads, + seq_len, + seq_len, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name()) + ref_out = attention_reference(Q, K, V, [mask_bias, bias], 1 / (dim**0.5)) + ref_out.backward(dummy_out) + ref_dv, V.grad = V.grad.clone(), None + ref_dk, K.grad = K.grad.clone(), None + ref_dq, Q.grad = Q.grad.clone(), None + ref_db, bias.grad = bias.grad.clone(), None + + out = DS4Sci_EvoformerAttention(Q, K, V, [mask_bias, bias]) + out.backward(dummy_out) + dv, v_grad = V.grad.clone(), None + dk, k_grad = K.grad.clone(), None + dq, q_grad = Q.grad.clone(), None + db, bias.grad = bias.grad.clone(), None + + eps = 1e-2 if dtype == torch.float16 else 5e-2 + + assert torch.max(torch.abs(ref_out - out)).item() < eps, f"out eps: {torch.max(torch.abs(ref_out - out))}" + assert torch.max(torch.abs(ref_dv - dv)) < eps, f"dv eps: {torch.max(torch.abs(ref_dv - dv))}" + assert torch.max(torch.abs(ref_dk - dk)) < eps, f"dk eps: {torch.max(torch.abs(ref_dk - dk))}" + assert torch.max(torch.abs(ref_dq - dq)) < eps, f"dq eps: {torch.max(torch.abs(ref_dq - dq))}" + assert torch.max(torch.abs(ref_db - db)) < 2 * eps, f"db eps: {torch.max(torch.abs(ref_db - db))}" diff --git a/tests/unit/ops/deepspeed4science/test_evoformer_attn_builder.py b/tests/unit/ops/deepspeed4science/test_evoformer_attn_builder.py new file mode 100644 index 000000000000..ff78839c8900 --- /dev/null +++ b/tests/unit/ops/deepspeed4science/test_evoformer_attn_builder.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from deepspeed.ops.op_builder.builder import CUDAOpBuilder +# Import the concrete builder class instead of the accelerator-dispatched alias. +from deepspeed.ops.op_builder.evoformer_attn import EvoformerAttnBuilder + + +def make_cutlass_checkout(path): + include_dir = path / "include" / "cutlass" + include_dir.mkdir(parents=True) + (include_dir / "cutlass.h").write_text("// cutlass marker\n") + util_dir = path / "tools" / "util" / "include" + util_dir.mkdir(parents=True) + return path + + +def test_filter_ccs_removes_below_70_and_keeps_ptx_suffix(): + builder = EvoformerAttnBuilder() + result = builder.filter_ccs(["6.0", "6.1", "7.0", "8.0+PTX"]) + + majors = [int(cc[0]) for cc in result] + assert 6 not in majors + assert 7 in majors + assert 8 in majors + + ptx_entries = [cc for cc in result if cc[1].endswith("+PTX")] + assert len(ptx_entries) == 1 + assert ptx_entries[0] == ["8", "0+PTX"] + + +def test_nvcc_args_deprecates_env_and_omits_gpu_arch_define(): + builder = EvoformerAttnBuilder() + with patch.dict("os.environ", {"DS_EVOFORMER_GPU_ARCH": "80"}, clear=False): + with patch.object(builder, "warning") as warn: + with patch.object(CUDAOpBuilder, "nvcc_args", return_value=["-O3", "-lineinfo"]): + args = builder.nvcc_args() + + warning_messages = [call.args[0] for call in warn.call_args_list if call.args] + assert any("DS_EVOFORMER_GPU_ARCH is deprecated and ignored" in msg for msg in warning_messages) + assert all("-DGPU_ARCH=" not in arg for arg in args) + + +def test_no_cuda_arch_in_checkarch(): + header = Path(__file__).resolve().parents[4] / "csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h" + text = header.read_text() + start = text.index("struct CheckArch") + end = text.index("};", start) + 2 + block = text[start:end] + assert "__CUDA_ARCH__" not in block + + +def test_include_paths_uses_cutlass_path_env(tmp_path): + cutlass_path = make_cutlass_checkout(tmp_path / "cutlass") + + with patch.dict("os.environ", {"CUTLASS_PATH": str(cutlass_path)}, clear=False): + builder = EvoformerAttnBuilder() + + assert builder.include_paths() == [ + str(cutlass_path / "include"), + str(cutlass_path / "tools" / "util" / "include"), + ] + + +def test_include_paths_finds_python_package_candidate_without_env(tmp_path): + cutlass_path = make_cutlass_checkout(tmp_path / "python_package_cutlass") + + with patch.dict("os.environ", {}, clear=True): + builder = EvoformerAttnBuilder() + + with patch.object(EvoformerAttnBuilder, "_python_package_cutlass_paths", return_value=[cutlass_path]): + assert builder.include_paths()[0] == str(cutlass_path / "include") + + +def test_include_paths_finds_cutlass_from_cmake_prefix_path(tmp_path): + cutlass_path = make_cutlass_checkout(tmp_path / "prefix") + + with patch.dict("os.environ", {"CMAKE_PREFIX_PATH": str(cutlass_path)}, clear=True): + builder = EvoformerAttnBuilder() + with patch.object(EvoformerAttnBuilder, "_python_package_cutlass_paths", return_value=[]): + assert builder.include_paths()[0] == str(cutlass_path / "include") + + +def test_include_paths_finds_cutlass_from_compiler_include_path(tmp_path): + cutlass_path = make_cutlass_checkout(tmp_path / "prefix") + + with patch.dict("os.environ", {"CPATH": str(cutlass_path / "include")}, clear=True): + builder = EvoformerAttnBuilder() + with patch.object(EvoformerAttnBuilder, "_python_package_cutlass_paths", return_value=[]): + assert builder.include_paths()[0] == str(cutlass_path / "include") + + +def test_include_paths_accepts_cutlass_include_dir_directly(tmp_path): + cutlass_path = make_cutlass_checkout(tmp_path / "cutlass") + + with patch.dict("os.environ", {"CUTLASS_PATH": str(cutlass_path / "include")}, clear=False): + builder = EvoformerAttnBuilder() + + assert builder.include_paths() == [ + str(cutlass_path / "include"), + str(cutlass_path / "tools" / "util" / "include"), + ] + + +def test_include_paths_reports_missing_cutlass(tmp_path): + with patch.dict("os.environ", {}, clear=True): + builder = EvoformerAttnBuilder() + + with patch.object(builder, "_candidate_cutlass_paths", return_value=[tmp_path / "missing"]): + with pytest.raises(RuntimeError, match="Unable to locate CUTLASS"): + builder.include_paths() diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py new file mode 100644 index 000000000000..d068a05b77bb --- /dev/null +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + +from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 + +from deepspeed import get_accelerator +from deepspeed.linear import QuantizationConfig + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +@pytest.mark.parametrize("q_bits", [8], ids=[ + "qbits8", +]) +@pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048]) +def test_fp_quant(dtype, q_bits, M): + device_name = get_accelerator().device_name() + quantization_group_size = 128 + quant_config = QuantizationConfig(q_dtype=FPQuantizerBuilder.get_default_quant_dtype(), + group_size=quantization_group_size) + fpq = FP_Quantize(quantization_config=quant_config) + + N = 8192 + H = 4096 + + x = torch.randn(M, H, dtype=dtype, device=device_name) + weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name) + + weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True) + scale = fpq.get_scales() + out = matmul_fp8(x, weight, scale, quantization_group_size, fpq) + + out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale)) + + error = ((out - out_q).abs() / (out.abs() + 1e-5)).sum() / out.numel() + assert 0.004 > error, f"failed on batch-size {M} with error {error}" diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py new file mode 100644 index 000000000000..0655b0ce26a3 --- /dev/null +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +from deepspeed.linear import QuantizationConfig + +import deepspeed + +from deepspeed.ops.fp_quantizer import FP_Quantize +from deepspeed.ops.op_builder import FPQuantizerBuilder +from deepspeed.accelerator import get_accelerator + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + +# warning: this import silently JIT builds a set of kernels and may take a minute +from qtorch.quant import float_quantize + + +def qtorch_quantize(input, exp_bits=4, man_bits=3, rounding="nearest", group_size=1024): + ori_dt = input.dtype + ori_shape = input.shape + last_dim = group_size + input = input.view(-1, last_dim) + + q_bits = exp_bits + man_bits + 1 + q_range = FPQuantizerBuilder.get_quant_range(q_bits) + input_to_float = input.float() + input_max = input_to_float.abs().amax(dim=-1, keepdim=True) + + return ((float_quantize(input_to_float / input_max * q_range, exp_bits, man_bits, rounding=rounding) * \ + input_max / q_range).to(ori_dt)).reshape(ori_shape) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_meta(dtype): + device_name = get_accelerator().device_name() + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = group_size + fpq = FP_Quantize(quantization_config=quant_config) + + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype) + + ds_x = x.clone().to(device_name) + x_quantized, meta_tensor = fpq.quantize(ds_x, q_bits=q_bits, return_meta_tensor=True) + x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits, scale=meta_tensor) + + qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + qtorch_error = (qtorch_out - x).abs().sum() / x.numel() + ds_error = (x_dequantized - ds_x).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_selective(dtype): + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + device_name = get_accelerator().device_name() + + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = group_size + fpq = FP_Quantize(quantization_config=quant_config) + + indexes = torch.zeros(2, dtype=torch.int32, device=device_name) + indexes[0] = 1 + indexes[1] = 3 + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device=device_name) + + x = x.reshape(4, 1, x.shape[-1]) + ds_x = x.clone() + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.selective_dequantize(x_quantized, indexes, q_bits=q_bits) + + qtorch_out = qtorch_quantize(x.index_select(0, indexes), + exp_bits=exp_bits, + man_bits=man_bits, + group_size=group_size) + qtorch_error = (qtorch_out - x.index_select(0, indexes)).abs().sum() / x.numel() + ds_error = (x_dequantized - x.index_select(0, indexes)).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +@pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) +def test_fp_quant(dtype, q_bits): + device_name = get_accelerator().device_name() + + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = 128 + fpq = FP_Quantize(quantization_config=quant_config) + + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype) + + ds_x = x.clone().to(device_name) + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits) + + if q_bits == 8: + exp_bits = 4 + man_bits = 3 + elif q_bits == 6: + exp_bits = 3 + man_bits = 2 + elif q_bits == 12: + exp_bits = 4 + man_bits = 7 + else: + raise ValueError(f"unknown {q_bits=}") + + qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=quant_config.group_size) + + qtorch_error = (qtorch_out - x).abs().sum() / x.numel() + ds_error = (x_dequantized - ds_x).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" diff --git a/tests/unit/ops/lion/test_cpu_lion.py b/tests/unit/ops/lion/test_cpu_lion.py new file mode 100644 index 000000000000..dce027e286fb --- /dev/null +++ b/tests/unit/ops/lion/test_cpu_lion.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +import pytest +from cpuinfo import get_cpu_info + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.lion import FusedLion +from deepspeed.ops.op_builder import CPULionBuilder +from unit.common import DistributedTest + +pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().float().numpy() + y = second.detach().float().numpy() + print("ATOL", atol) + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): + for i in range(10): + param1.grad = torch.randn(model_size, device=param1.device).to(param1.dtype) + param2.grad = param1.grad.clone().detach().to(device=param2.device, dtype=param2.dtype) + + optimizer1.step() + optimizer2.step() + + tolerance = param1.float().norm().detach().numpy() * 1e-2 + check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) + + +@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) +@pytest.mark.parametrize('model_size', + [ + (64), + (22), + #(55), + (128), + (1024), + (1048576), + ]) # yapf: disable +class TestCPULion(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME], + reason="CPULionBuilder has not been implemented on this system.") + def test_fused_lion_equal(self, dtype, model_size): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-lion with half precision not supported on AMD CPUs") + + from deepspeed.ops.lion import DeepSpeedCPULion + + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + cpu_param = torch.nn.Parameter(cpu_data) + cuda_param = torch.nn.Parameter(cpu_data.to(get_accelerator().device_name())) + + cpu_optimizer = DeepSpeedCPULion([cpu_param]) + cuda_optimizer = FusedLion([cuda_param]) + + _compare_optimizers(model_size=model_size, + param1=cpu_param, + optimizer1=cpu_optimizer, + param2=cuda_param, + optimizer2=cuda_optimizer) + + +class TestCPULionGPUError(DistributedTest): + + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME], + reason="CPULionBuilder has not been implemented on this system.") + def test_cpu_lion_gpu_error(self): + model_size = 64 + from deepspeed.ops.lion import DeepSpeedCPULion + device = get_accelerator().device_name(0) # 'cuda:0' or 'xpu:0' + param = torch.nn.Parameter(torch.randn(model_size, device=device)) + optimizer = DeepSpeedCPULion([param]) + + param.grad = torch.randn(model_size, device=device) + with pytest.raises(AssertionError): + optimizer.step() diff --git a/tests/unit/ops/lion/test_lion.py b/tests/unit/ops/lion/test_lion.py new file mode 100644 index 000000000000..507ff72ea51a --- /dev/null +++ b/tests/unit/ops/lion/test_lion.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +import pytest + +from deepspeed.ops.lion import FusedLion +from deepspeed.ops.lion import DeepSpeedCPULion +from unit.common import DistributedTest +from unit.simple_model import SimpleModel +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import CPULionBuilder + +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) +# yapf: disable +#'optimizer, zero_offload, resulting_optimizer +lion_configs = [["Lion", False, FusedLion], + ["Lion", True, DeepSpeedCPULion]] + +@pytest.mark.parametrize( + 'optimizer, zero_offload, resulting_optimizer', + lion_configs) +class TestLionConfigs(DistributedTest): + world_size = 1 + reuse_dist_env = True + + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME], reason="CPULionBuilder has not been implemented on this system.") + def test(self, + optimizer, + zero_offload, + resulting_optimizer): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": optimizer, + "params": { + "lr": 0.00015, + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": 2, + "cpu_offload": zero_offload + } + } + model = SimpleModel(10) + model, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=model.parameters()) + # get base optimizer under zero + ds_optimizer = model.optimizer.optimizer + opt_class = resulting_optimizer + assert isinstance(ds_optimizer, opt_class) diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py new file mode 100644 index 000000000000..84b06dd96265 --- /dev/null +++ b/tests/unit/ops/muon/test_muon.py @@ -0,0 +1,179 @@ +# Copyright (c) 2025 Peng Du and Zhipeng Wang +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +import pytest + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel +from deepspeed.accelerator import get_accelerator +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) + +# 'optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer, save_muon_momentum_buffer_in_memory' + +muon_configs = [] +for optimizer_name in ['muon', 'adam']: + for stage in [1, 2, 3]: + for lr in [0.01, 0.05]: + for model_dim in [32, 128]: + for nlayer in [5, 10]: + for offload_optimizer in [True, False]: + for save_in_mem in ([True, False] if stage == 3 else [False]): + muon_configs.append( + [optimizer_name, stage, lr, model_dim, nlayer, offload_optimizer, save_in_mem]) + + +@pytest.mark.parametrize( + 'optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer, save_muon_momentum_buffer_in_memory', + muon_configs) +class TestMuonConfigs(DistributedTest): + + def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer, + save_muon_momentum_buffer_in_memory): + optimizer_params = {"lr": lr} + batch_size = 8 + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": optimizer_type, + "params": optimizer_params + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": False, + "save_muon_momentum_buffer_in_memory": save_muon_momentum_buffer_in_memory, + }, + } + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + # Perform a few training steps to ensure the optimizer works correctly + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer) + initial_params = [p.clone().cpu() for p in model.parameters()] + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + assert optimizer_type in optimizer.optimizer.__class__.__name__.lower( + ), f"Expected optimizer type {optimizer_type}, got {optimizer.optimizer.__class__.__name__}" + steps = 5 + for _ in range(steps): + # Random inputs: (batch_size, hidden_dim) + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + # Random class labels: (batch_size,) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + # Forward + loss + loss = engine(x, y) + # Backward + engine.backward(loss) + engine.step() + + # Verify that parameters have been updated + after_training = [p.clone().cpu() for p in model.parameters()] + for initial, final in zip(initial_params, after_training): + assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training" + + +class TestGramNewtonSchulz(DistributedTest): + """Test Gram Newton-Schulz integration with Muon optimizer.""" + + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('ns_method', ['gram', 'standard']) + @pytest.mark.parametrize('zero_stage', [1, 2]) + def test_ns_method_training(self, ns_method, zero_stage): + """Verify both ns_method values work end-to-end with DeepSpeed.""" + hidden_dim = 64 + batch_size = 8 + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01, + "ns_method": ns_method, + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": False, + }, + } + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=3) + initial_params = [p.clone().cpu() for p in model.parameters()] + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() + + after_training = [p.clone().cpu() for p in model.parameters()] + for initial, final in zip(initial_params, after_training): + assert not torch.equal(initial, final), "Parameters should have been updated" + + @pytest.mark.parametrize('ns_method', ['gram', 'standard']) + def test_ns_method_stage3(self, ns_method): + """Verify ns_method works with ZeRO Stage 3.""" + hidden_dim = 64 + batch_size = 8 + config_dict = { + "train_batch_size": batch_size, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01, + "ns_method": ns_method, + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + }, + "zero_optimization": { + "stage": 3, + "reduce_scatter": False, + }, + } + + model = SimpleModel(hidden_dim=hidden_dim, nlayers=3) + engine, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False, + ) + + for _ in range(3): + x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) + y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) + loss = engine(x, y) + engine.backward(loss) + engine.step() diff --git a/tests/unit/ops/muon/test_muon_partial_training.py b/tests/unit/ops/muon/test_muon_partial_training.py new file mode 100644 index 000000000000..8e34e66389e1 --- /dev/null +++ b/tests/unit/ops/muon/test_muon_partial_training.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Test for PR #7869: Fix Muon optimizer with partial model training + +This test verifies that the fix for Muon optimizer parameter grouping works +correctly when only part of the model parameters are trainable. + +The bug occurred when: +1. Some parameters use Muon optimizer (p.use_muon = True) +2. Other parameters use AdamW optimizer (p.use_muon = False) +3. All trainable parameters happen to use the same optimizer type + +This caused one of the parameter groups to be empty, leading to: +ValueError: torch.cat(): expected a non-empty list of Tensors + +The fix filters parameters to only include those with requires_grad=True, +ensuring empty parameter groups are properly handled. +""" + +import torch.nn as nn +import deepspeed +from unit.common import DistributedTest + + +class PartialTrainableModel(nn.Module): + """ + A model where some parameters use Muon and some use AdamW. + + This simulates the scenario where: + - Hidden layers use Muon (ndim >= 2) + - Embeddings and biases use AdamW (ndim < 2) + """ + + def __init__(self, vocab_size=100, hidden_dim=64, num_layers=2): + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]) + self.output = nn.Linear(hidden_dim, vocab_size) + + # Set use_muon attribute for parameters + # Muon should be used for ndim >= 2 (matrices) + # AdamW should be used for ndim < 2 (embeddings, biases) + for name, param in self.named_parameters(): + if param.ndim >= 2: + param.use_muon = True + else: + param.use_muon = False + + +class TestMuonPartialModelTraining(DistributedTest): + """Test Muon optimizer with partial model training scenarios.""" + + world_size = 2 + reuse_dist_env = True + requires_cuda_env = False + + def test_muon_with_all_trainable_params(self): + """ + Test when all parameters are trainable. + + This should work fine as both Muon and AdamW parameter groups + will be non-empty. + """ + model = PartialTrainableModel() + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Muon", + "params": { + "lr": 0.02, + "weight_decay": 0.01 + } + }, + "zero_optimization": { + "stage": 2 + }, + } + + # This should not raise ValueError + model_engine, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) + + # Verify the model was initialized successfully + assert model_engine is not None + + def test_muon_with_partial_trainable_params_same_optimizer(self): + """ + Test the bug scenario: all trainable params use the same optimizer. + + This is the bug case where: + - All trainable parameters have use_muon=True (or all False) + - This causes one parameter group to be empty + - Without the fix, this raises: ValueError: torch.cat(): expected a non-empty list of Tensors + + The fix filters by requires_grad, so empty groups are properly handled. + """ + model = PartialTrainableModel() + + # Freeze all Linear layers (which have use_muon=True) + # Keep only embeddings and biases trainable (use_muon=False) + for name, param in model.named_parameters(): + if "layers" in name or "output" in name: + param.requires_grad = False + + # Now all trainable parameters have use_muon=False + # This would cause muon_params to be empty without the fix + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Muon", + "params": { + "lr": 0.02, + "weight_decay": 0.01 + } + }, + "zero_optimization": { + "stage": 2 + }, + } + + # This would raise ValueError without the fix + # With the fix, it should initialize successfully + model_engine, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) + + # Verify the model was initialized successfully + assert model_engine is not None + + def test_muon_with_mixed_trainable_params(self): + """ + Test when trainable parameters use both optimizers. + + This is the normal case where: + - Some trainable params have use_muon=True + - Some trainable params have use_muon=False + - Both parameter groups are non-empty + + This should work fine even without the fix. + """ + model = PartialTrainableModel() + + # Freeze only the first Linear layer + # This leaves both Muon and AdamW parameters trainable + for name, param in model.named_parameters(): + if "layers.0" in name: + param.requires_grad = False + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Muon", + "params": { + "lr": 0.02, + "weight_decay": 0.01 + } + }, + "zero_optimization": { + "stage": 2 + }, + } + + # This should work fine + model_engine, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) + + # Verify the model was initialized successfully + assert model_engine is not None diff --git a/tests/unit/ops/quantizer/test_dequantize.py b/tests/unit/ops/quantizer/test_dequantize.py deleted file mode 100644 index ed0b7d4fbce2..000000000000 --- a/tests/unit/ops/quantizer/test_dequantize.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import pytest -import torch -from deepspeed.ops import op_builder -from deepspeed.accelerator import get_accelerator - -quantize_module = None - - -def int4x2to2xint4(int4X2tensor): - high = int4X2tensor >> 4 - low = (int4X2tensor << 4) >> 4 - return torch.stack((high, low), dim=-1).flatten() - - -def run_quantize(data, num_groups, q_bits, is_symmetric_quant): - global quantize_module - if quantize_module is None: - quantize_module = op_builder.QuantizerBuilder().load() - - return quantize_module.quantize(data, num_groups, q_bits, - quantize_module.Symmetric if is_symmetric_quant else quantize_module.Asymmetric) - - -def run_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant): - global quantize_module - if quantize_module is None: - quantize_module = op_builder.QuantizerBuilder().load() - - return quantize_module.dequantize(quantized_data, params, num_groups, q_bits, - quantize_module.Symmetric if is_symmetric_quant else quantize_module.Asymmetric) - - -def run_ref_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant): - - if (q_bits == 4): - quantized_data = int4x2to2xint4(quantized_data) - - quantized_data = quantized_data.reshape(num_groups, -1).to(torch.float32) - - if is_symmetric_quant: - return (quantized_data * params).to(torch.float16) - else: - scales = params[:, 0].reshape(-1, 1) - offsets = params[:, 1].reshape(-1, 1) - return (quantized_data * scales + offsets).to(torch.float16) - - -@pytest.mark.inference_ops -@pytest.mark.parametrize("num_groups", [1, 13, 512]) -@pytest.mark.parametrize("num_elems", [8, 16, 32, 64, 128, 256, 4096, 8192, 12288, 16384]) -@pytest.mark.parametrize("is_symmetric_quant", [True, False]) -@pytest.mark.parametrize("q_bits", [4, 8]) -def test_dequantize(num_elems, num_groups, is_symmetric_quant, q_bits): - - activations = torch.randn((num_groups, num_elems), dtype=torch.float16, device=get_accelerator().device_name()) - quantized_data, params = run_quantize(activations, num_groups, q_bits, is_symmetric_quant) - - ds_dequant = run_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant) - ref_dequant = run_ref_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant) - - assert (torch.allclose(ds_dequant.flatten(), ref_dequant.flatten(), rtol=3e-2, atol=2e-3)) diff --git a/tests/unit/ops/quantizer/test_fake_quantization.py b/tests/unit/ops/quantizer/test_fake_quantization.py index a9c2993ef7c2..2549878ed541 100644 --- a/tests/unit/ops/quantizer/test_fake_quantization.py +++ b/tests/unit/ops/quantizer/test_fake_quantization.py @@ -5,8 +5,12 @@ import torch import pytest +import deepspeed from deepspeed.accelerator import get_accelerator -from deepspeed.ops import op_builder +from deepspeed.ops.op_builder import QuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) quantizer_cuda_module = None @@ -36,7 +40,7 @@ def run_quant_dequant(inputs, groups, bits): global quantizer_cuda_module if quantizer_cuda_module is None: - quantizer_cuda_module = op_builder.QuantizerBuilder().load() + quantizer_cuda_module = QuantizerBuilder().load() return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits) diff --git a/tests/unit/ops/quantizer/test_quantize.py b/tests/unit/ops/quantizer/test_quantize.py index 1ddc86ce09d1..1f5e3dc95721 100644 --- a/tests/unit/ops/quantizer/test_quantize.py +++ b/tests/unit/ops/quantizer/test_quantize.py @@ -5,21 +5,38 @@ import pytest import torch -from deepspeed.ops import op_builder +import deepspeed +from deepspeed.ops.op_builder import QuantizerBuilder from deepspeed.accelerator import get_accelerator +if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + inference_module = None def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant): global inference_module if inference_module is None: - inference_module = op_builder.QuantizerBuilder().load() + inference_module = QuantizerBuilder().load() return inference_module.quantize(activations, num_groups, q_bits, inference_module.Symmetric if is_symmetric_quant else inference_module.Asymmetric) +def run_dequantize_ds(activations, params, num_groups, q_bits, is_symmetric_quant): + global inference_module + if inference_module is None: + inference_module = QuantizerBuilder().load() + return inference_module.dequantize( + activations, + params, + num_groups, + q_bits, + inference_module.Symmetric if is_symmetric_quant else inference_module.Asymmetric, + ) + + def get_q_props(q_bits): q_range = 2**q_bits q_min = -(2**(q_bits - 1)) @@ -87,6 +104,22 @@ def run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups): return data_i8, params +def run_float_dequantize(q_bits, is_symmetric_quant, data_i8, params, num_groups): + data_f = data_i8.reshape(num_groups, -1).to(dtype=torch.float32) + + scales = params[:, 0].reshape(-1, 1) + offsets = params[:, 1].reshape(-1, 1) + + if not is_symmetric_quant: + data_f = data_f - offsets + else: + assert offsets.allclose(torch.zeros_like(offsets)) + + data_f = data_f * scales + + return data_f + + @pytest.mark.inference_ops @pytest.mark.parametrize("num_groups", [1, 13, 512]) @pytest.mark.parametrize("num_elems", [8, 16, 32, 64, 128, 256, 4096, 8192, 12288, 16384]) @@ -94,6 +127,8 @@ def run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups): @pytest.mark.parametrize("q_bits", [4, 8]) @pytest.mark.parametrize("directed_case", ["all_zeros", None]) def test_float_quantize(num_elems, num_groups, is_symmetric_quant, q_bits, directed_case): + # fix seed + torch.manual_seed(num_elems) if directed_case == "all_zeros": activations_ds = torch.zeros((num_groups, num_elems), @@ -106,16 +141,14 @@ def test_float_quantize(num_elems, num_groups, is_symmetric_quant, q_bits, direc activations_ref = activations_ds.clone().detach() ref_out_tensor, ref_params = run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups) + ref_dequantized_tensor = run_float_dequantize(q_bits, is_symmetric_quant, ref_out_tensor, ref_params, num_groups) + # we need to convert the tensor to float64 to avoid overflow + ref_quantization_error = torch.sum(torch.abs((activations_ref - ref_dequantized_tensor).to(torch.float64))) ds_out_tensor, ds_out_params = run_quantize_ds(activations_ds, num_groups, q_bits, is_symmetric_quant) + ds_dequantized_tensor = run_dequantize_ds(ds_out_tensor, ds_out_params, num_groups, q_bits, is_symmetric_quant) + assert torch.all(torch.isfinite(ds_dequantized_tensor)) - if (q_bits == 4): - ds_out_tensor = int4x2to2xint4(ds_out_tensor) + ds_quantization_error = torch.sum(torch.abs((activations_ds - ds_dequantized_tensor).to(torch.float64))) - # Allow a max difference of 1 to account for differences in rounding in pytorch implementation - assert (torch.all(torch.lt(torch.abs(ds_out_tensor.flatten() - ref_out_tensor.flatten()), 2))) - if is_symmetric_quant: - assert (torch.allclose(ds_out_params.flatten(), ref_params[:, 0].flatten())) - else: - assert (torch.allclose(ds_out_params[:, 0].flatten(), ref_params[:, 0].flatten())) - assert (torch.allclose(ds_out_params[:, 1].flatten(), ref_params[:, 1].flatten(), atol=5e-5, rtol=5e-5)) + assert (ds_quantization_error <= ref_quantization_error * 1.05) diff --git a/tests/unit/ops/spatial/test_nhwc_bias_add.py b/tests/unit/ops/spatial/test_nhwc_bias_add.py index f243e82f6d3b..3787b46e266a 100644 --- a/tests/unit/ops/spatial/test_nhwc_bias_add.py +++ b/tests/unit/ops/spatial/test_nhwc_bias_add.py @@ -5,9 +5,14 @@ import pytest import torch +import deepspeed +from deepspeed.ops.op_builder import SpatialInferenceBuilder from deepspeed.ops.transformer.inference.bias_add import nhwc_bias_add from deepspeed.accelerator import get_accelerator +if not deepspeed.ops.__compatible_ops__[SpatialInferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + def allclose(x, y): assert x.dtype == y.dtype diff --git a/tests/unit/ops/test_op_builder.py b/tests/unit/ops/test_op_builder.py new file mode 100644 index 000000000000..edbeeb469af7 --- /dev/null +++ b/tests/unit/ops/test_op_builder.py @@ -0,0 +1,218 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import importlib.util +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +BUILDER_PATH = Path(__file__).resolve().parents[3] / "op_builder" / "builder.py" +BUILDER_SPEC = importlib.util.spec_from_file_location("test_op_builder_module", BUILDER_PATH) +builder_module = importlib.util.module_from_spec(BUILDER_SPEC) +BUILDER_SPEC.loader.exec_module(builder_module) +CUDAOpBuilder = builder_module.CUDAOpBuilder + +BUILDER_MODULE = builder_module +CUDA_API = BUILDER_MODULE.torch.cuda #ignore-cuda + + +class _StubCUDAOpBuilder(CUDAOpBuilder): + BUILD_VAR = "STUB_BUILDER" + NAME = "stub" + + def __init__(self): + super().__init__(name="stub") + + def absolute_name(self): + return "deepspeed.ops.stub" + + def sources(self): + return [] + + def include_paths(self): + return [] + + +def make_builder(**overrides): + builder = _StubCUDAOpBuilder() + for key, value in overrides.items(): + setattr(builder, key, value) + return builder + + +def assert_jit_uses_explicit_arch_list(builder, expected_arch_list, env_updates=None): + env_updates = env_updates or {} + + with patch.dict(os.environ, env_updates, clear=False): + if "TORCH_CUDA_ARCH_LIST" not in env_updates: + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "device_count", + side_effect=AssertionError("probe should not be called")) as device_count: + with patch.object(CUDA_API, + "get_device_capability", + side_effect=AssertionError("probe should not be called")) as get_device_capability: + assert builder.compute_capability_args() == [] + assert os.environ["TORCH_CUDA_ARCH_LIST"] == expected_arch_list + + device_count.assert_not_called() + get_device_capability.assert_not_called() + + +def test_jit_mode_prefers_explicit_arch_lists_before_cuda_probe(): + assert_jit_uses_explicit_arch_list(make_builder(jit_mode=True, _jit_arch_list="8.0;8.9"), "8.0;8.9+PTX") + assert_jit_uses_explicit_arch_list(make_builder(jit_mode=True), "8.0;8.9+PTX", {"TORCH_CUDA_ARCH_LIST": "8.0 8.9"}) + + +def test_bad_fork_jit_without_arch_list_raises_actionable_error(): + builder = make_builder(jit_mode=True) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "_is_in_bad_fork", return_value=True): + with patch.object(CUDA_API, "device_count", + side_effect=AssertionError("probe should not be called")) as device_count: + with pytest.raises(RuntimeError, match="TORCH_CUDA_ARCH_LIST"): + builder.compute_capability_args() + + device_count.assert_not_called() + + +def test_jit_mode_probes_devices_when_safe_and_errors_without_visible_gpus(): + builder = make_builder(jit_mode=True) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "_is_in_bad_fork", return_value=False): + with patch.object(CUDA_API, "device_count", return_value=2) as device_count: + with patch.object(CUDA_API, "get_device_capability", side_effect=[(7, 0), + (8, 9)]) as get_device_capability: + assert builder.compute_capability_args() == [] + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "7.0;8.9+PTX" + assert builder.enable_bf16 is False + + device_count.assert_called_once_with() + assert get_device_capability.call_count == 2 + + builder = make_builder(jit_mode=True) + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "_is_in_bad_fork", return_value=False): + with patch.object(CUDA_API, "device_count", return_value=0): + with pytest.raises(RuntimeError, match="no CUDA devices"): + builder.compute_capability_args() + + +def test_jit_load_restores_env_and_state_after_failure(): + builder = make_builder() + + def fail_nvcc_args(): + assert getattr(builder, "_jit_arch_list", None) == "8.9" + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9+PTX" + raise RuntimeError("build failed") + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.9"}, clear=False): + with patch.object(builder, "is_compatible", return_value=True): + with patch.object(CUDAOpBuilder, "is_rocm_pytorch", return_value=False): + with patch.object(CUDA_API, "is_available", return_value=True): + with patch("torch.utils.cpp_extension.verify_ninja_availability", return_value=None): + with patch.object(builder, "nvcc_args", side_effect=fail_nvcc_args): + with pytest.raises(RuntimeError, match="build failed"): + builder.jit_load(verbose=False) + + assert getattr(builder, "_jit_arch_list", None) is None + assert builder.jit_mode is False + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "8.9" + + +def test_jit_load_restores_state_after_success(): + builder = make_builder() + op_module = MagicMock() + + def successful_nvcc_args(): + assert builder._jit_arch_list == "8.9" + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9+PTX" + return [] + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.9"}, clear=False): + with patch.object(builder, "is_compatible", return_value=True): + with patch.object(CUDAOpBuilder, "is_rocm_pytorch", return_value=False): + with patch.object(CUDA_API, "is_available", return_value=True): + with patch("torch.utils.cpp_extension.verify_ninja_availability", return_value=None): + with patch.object(builder, "nvcc_args", side_effect=successful_nvcc_args): + with patch.object(builder, "cxx_args", return_value=[]): + with patch("torch.utils.cpp_extension.load", return_value=op_module): + assert builder.jit_load(verbose=False) is op_module + + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "8.9" + assert getattr(builder, "_jit_arch_list", None) is None + assert builder.jit_mode is False + + +def test_non_jit_branch_unchanged(): + builder = make_builder(jit_mode=False) + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.0;8.9+PTX"}, clear=False): + args = builder.compute_capability_args() + + assert args == [ + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_89,code=compute_89", + ] + + +def test_non_jit_branch_sorts_and_dedupes_gencode_flags(): + builder = make_builder(jit_mode=False) + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.0;7.5;8.0;7.0"}, clear=False): + args = builder.compute_capability_args() + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "7.0;7.5;8.0" + + assert args == [ + "-gencode=arch=compute_70,code=sm_70", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + ] + + +def test_non_jit_branch_canonicalizes_mixed_ptx_variants_to_one_sm_and_one_ptx(): + # For mixed inputs such as "8.0;8.0+PTX" or "8.0+PTX;8.0", PyTorch + # canonicalizes the architecture to one sm_80 entry plus one compute_80 + # PTX entry. Dedupe by the canonical numeric arch so we match. + expected_arch_list = "7.5;8.0+PTX" + expected_args = [ + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_80,code=compute_80", + ] + + for arch_input in ("8.0;8.0+PTX;7.5", "7.5;8.0+PTX;8.0", "8.0+PTX;7.5;8.0", "8.0;7.5;8.0+PTX"): + builder = make_builder(jit_mode=False) + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": arch_input}, clear=False): + args = builder.compute_capability_args() + assert os.environ["TORCH_CUDA_ARCH_LIST"] == expected_arch_list, arch_input + assert args == expected_args, arch_input + + +def test_non_jit_branch_canonical_dedupe_mixed_ptx_combinations(): + # Lock in the four mixed-PTX combinations for a single arch so the dedupe + # behavior cannot regress on either ordering or duplication. + builder = make_builder(jit_mode=False) + cases = [ + ("8.0;8.0+PTX", "8.0+PTX", ["-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_80,code=compute_80"]), + ("8.0+PTX;8.0", "8.0+PTX", ["-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_80,code=compute_80"]), + ("8.0;8.0", "8.0", ["-gencode=arch=compute_80,code=sm_80"]), + ("8.0+PTX;8.0+PTX", "8.0+PTX", + ["-gencode=arch=compute_80,code=sm_80", "-gencode=arch=compute_80,code=compute_80"]), + ] + for arch_input, expected_arch_list, expected_args in cases: + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": arch_input}, clear=False): + args = builder.compute_capability_args() + assert os.environ["TORCH_CUDA_ARCH_LIST"] == expected_arch_list, arch_input + assert args == expected_args, arch_input diff --git a/tests/unit/ops/transformer/inference/__init__.py b/tests/unit/ops/transformer/inference/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/ops/transformer/inference/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py new file mode 100644 index 000000000000..d63c51267e51 --- /dev/null +++ b/tests/unit/ops/transformer/inference/inference_test_utils.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch +from deepspeed.accelerator import get_accelerator + +TOLERANCES = None + + +def get_tolerances(): + global TOLERANCES + if TOLERANCES is None: + TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)} + if get_accelerator().is_bf16_supported(): + # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs + # 10 (+1) bits) + TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2) + return TOLERANCES + + +DTYPES = None + + +def get_dtypes(include_float=True): + global DTYPES + if DTYPES is None: + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass + return DTYPES + + +def allclose(x, y, tolerances: Tuple[int, int] = None): + assert x.dtype == y.dtype + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances + return torch.allclose(x, y, rtol=rtol, atol=atol) diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py new file mode 100644 index 000000000000..1ef2da91f72e --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.accelerator import get_accelerator +from .inference_test_utils import allclose + + +# reference timplementation +def ref_torch_attention(q, k, v, mask, sm_scale): + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + p = torch.softmax(p.float() + mask, dim=-1).half() + ref_out = torch.matmul(p, v) + return ref_out + + +# test attention operator +@pytest.mark.inference_ops +@pytest.mark.parametrize("BATCH", [1]) # batch +@pytest.mark.parametrize("H", [12]) # heads +@pytest.mark.parametrize("N_CTX", [16, 128]) # sequence length +@pytest.mark.parametrize("D_HEAD", [64, 128]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_flash", [True, False]) +def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + if not deepspeed.HAS_TRITON: + pytest.skip("triton is not installed") + + minus_inf = -65504.0 + dev = deepspeed.accelerator.get_accelerator().device_name() + # skip autotune in testing + from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul + fp16_matmul.skip_autotune() + + from deepspeed.ops.transformer.inference.triton.attention import _triton_attention, _triton_packed_flash + torch.manual_seed(20) + q = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=dev).normal_(mean=0, std=.5) + k = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=dev).normal_(mean=0, std=.5) + v = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=dev).normal_(mean=0, std=.5) + sm_scale = 0.3 + + # reference implementation + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + score = p + mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device=dev) + M = torch.tril(torch.ones((N_CTX, N_CTX), device=dev)) + if causal: + for z in range(BATCH): + for h in range(H): + mask[:, :, M == 0] = minus_inf + p = torch.softmax(p.float() + mask, dim=-1).half() + softmax_out = p + ref_out = torch.matmul(p, v) + context = ref_out + + # adjust it to expected tensor format and run test + qkv = torch.randn((BATCH, N_CTX, 3 * H * D_HEAD), dtype=dtype, device=dev, requires_grad=False) + qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD)) + qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD)) + qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD)) + + if use_flash: + if not get_accelerator().is_triton_supported(): + pytest.skip("triton flash attention is supported when the compute capability > 8.0") + triton_mask = torch.zeros((BATCH, 1, 1, N_CTX), dtype=dtype, device=dev) + if not causal: + lengths = torch.randint(N_CTX - 8, N_CTX, (BATCH, 1), device=dev) + for i, l in enumerate(lengths): + triton_mask[i, ..., l:] = minus_inf + mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device=dev) + for b in range(BATCH): + mask[b, :, :, lengths[b]:] = minus_inf + ref_out = ref_torch_attention(q, k, v, mask, sm_scale) + tri_out = _triton_packed_flash(qkv, D_HEAD, triton_mask, sm_scale, causal=causal, add_mask=(not causal)) + else: + tri_out = _triton_attention(qkv, + input_mask=mask, + layer_past=None, + alibi=None, + scale=sm_scale, + head_size=D_HEAD, + triangular=False, + use_cuda_flash=False, + use_triton_flash=False, + use_ds_attention=False) + tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3) + assert (allclose(ref_out, tri_out)) diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py index 36a01f2be8e7..eb283924f73c 100644 --- a/tests/unit/ops/transformer/inference/test_bias_add.py +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -8,39 +8,28 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_add import BiasAddOp +from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None -torch_minor_version = None - - -def allclose(x, y): - assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] - return torch.allclose(x, y, rtol=rtol, atol=atol) - def run_bias_add_reference(activations, bias): return activations + bias def run_bias_add_ds(activations, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_add_fp16(activations, bias) - else: - return inference_module.bias_add_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasAddOp(config)(activations, bias) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 2]) @pytest.mark.parametrize("sequence", [1, 128, 255]) @pytest.mark.parametrize("channels", [512, 1232, 4096]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_add(batch, sequence, channels, dtype): activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name()) bias_ds = torch.randn((channels), dtype=dtype, device=get_accelerator().device_name()) @@ -50,4 +39,6 @@ def test_bias_add(batch, sequence, channels, dtype): ds_out = run_bias_add_ds(activations_ds, bias_ds) ref_out = run_bias_add_reference(activations_ref, bias_ref) - assert allclose(ds_out, ref_out) + if not allclose(ds_out, ref_out): + print((ds_out - ref_out).abs().max()) + assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/ops/transformer/inference/test_bias_geglu.py b/tests/unit/ops/transformer/inference/test_bias_geglu.py index 477c0a3bc7c7..c995d2a8c46d 100644 --- a/tests/unit/ops/transformer/inference/test_bias_geglu.py +++ b/tests/unit/ops/transformer/inference/test_bias_geglu.py @@ -8,19 +8,13 @@ import deepspeed from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.accelerator import get_accelerator +from deepspeed.ops.transformer.inference.op_binding.gated_activation import GatedActivationOp +from deepspeed.utils.types import ActivationFuncType +from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None -torch_minor_version = None - - -def allclose(x, y): - assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-3, 5e-4), torch.float16: (3e-2, 2e-3), torch.int8: (0, 0)}[x.dtype] - return torch.allclose(x, y, rtol=rtol, atol=atol) - def run_bias_geglu_reference(activations, bias): # Expected behavior is that of casting to float32 internally @@ -31,17 +25,14 @@ def run_bias_geglu_reference(activations, bias): def run_bias_geglu_ds(activation, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.bias_geglu(activation, bias) + return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_GELU) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 2]) @pytest.mark.parametrize("sequence", [1, 128, 255]) @pytest.mark.parametrize("channels", [512, 1232, 4096]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_geglu(batch, sequence, channels, dtype): activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name()) bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name()) @@ -49,3 +40,29 @@ def test_bias_geglu(batch, sequence, channels, dtype): ds_out = run_bias_geglu_ds(activation, bias) ref_out = run_bias_geglu_reference(activation, bias) assert (allclose(ds_out, ref_out)) + + +def run_gated_silu_reference(activations, bias): + # Expected behavior is that of casting to float32 internally + # Explicitly using the default GeLU + activations = activations + bias.reshape(1, 1, -1) + hidden_states, gate = activations.chunk(2, dim=-1) + return hidden_states * torch.nn.functional.silu(gate.to(torch.float32)).to(activations.dtype) + + +def run_gated_silu_ds(activation, bias): + return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_SILU) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 2]) +@pytest.mark.parametrize("sequence", [1, 128, 255]) +@pytest.mark.parametrize("channels", [512, 1232, 4096]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_gated_silu(batch, sequence, channels, dtype): + activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name()) + bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name()) + + ds_out = run_gated_silu_ds(activation, bias) + ref_out = run_gated_silu_reference(activation, bias) + assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index 1c5e7d58f85a..f0a09245e890 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -8,20 +8,14 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder -from packaging import version as pkg_version +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp +from deepspeed.utils.torch import required_torch_version +from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None -torch_minor_version = None - - -def allclose(x, y): - assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] - return torch.allclose(x, y, rtol=rtol, atol=atol) - def run_bias_gelu_reference(activations, bias): # Expected behavior is that of casting to float32 internally and using the tanh approximation @@ -30,22 +24,17 @@ def run_bias_gelu_reference(activations, bias): def run_bias_gelu_ds(activations, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_gelu_fp16(activations, bias) - else: - return inference_module.bias_gelu_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasGeluOp(config)(activations, bias) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 2]) @pytest.mark.parametrize("sequence", [1, 128, 255]) @pytest.mark.parametrize("channels", [512, 1232, 4096]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_gelu(batch, sequence, channels, dtype): - if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"): + if not required_torch_version(min_version=1.12): pytest.skip("gelu implementation matches only after torch 1.12") activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py index 50daa221f4cc..69078f9f7646 100644 --- a/tests/unit/ops/transformer/inference/test_bias_relu.py +++ b/tests/unit/ops/transformer/inference/test_bias_relu.py @@ -8,19 +8,13 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_relu import BiasReluOp +from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None -torch_minor_version = None - - -def allclose(x, y): - assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] - return torch.allclose(x, y, rtol=rtol, atol=atol) - def run_bias_relu_reference(activations, bias): # Expected behavior is that of casting to float32 internally @@ -28,20 +22,15 @@ def run_bias_relu_reference(activations, bias): def run_bias_relu_ds(activations, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_relu_fp16(activations, bias) - else: - return inference_module.bias_relu_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasReluOp(config)(activations, bias) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 2]) @pytest.mark.parametrize("sequence", [1, 128, 255]) @pytest.mark.parametrize("channels", [512, 1232, 4096]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_relu(batch, sequence, channels, dtype): activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name()) bias_ds = torch.randn((channels), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py new file mode 100644 index 000000000000..a58abfdb100c --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_gelu.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp +from deepspeed.utils.torch import required_torch_version + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def version_appropriate_gelu(activations): + # gelu behavior changes (correctly) in torch 1.12 + if required_torch_version(min_version=1.12): + return torch.nn.functional.gelu(activations, approximate='tanh') + else: + return torch.nn.functional.gelu(activations) + + +def run_gelu_reference(activations): + # Expected behavior is that of casting to float32 internally and using the tanh approximation + return version_appropriate_gelu(activations.to(torch.float32)).to(activations.dtype) + + +def run_gelu_ds(activations, use_triton_ops=False): + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import gelu + return gelu(activations) + + device = deepspeed.accelerator.get_accelerator().device_name() + channels = activations.shape[-1] + bias = torch.zeros((channels), dtype=activations.dtype, device=device) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasGeluOp(config)(activations, bias) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 2]) +@pytest.mark.parametrize("sequence", [1, 128, 255]) +@pytest.mark.parametrize("channels", [512, 1232, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("use_triton_ops", [True, False]) +def test_gelu(batch, sequence, channels, dtype, use_triton_ops): + device = deepspeed.accelerator.get_accelerator().device_name() + activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device) + activations_ref = activations_ds.clone().detach() + + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + ds_out = run_gelu_ds(activations_ds, use_triton_ops) + ref_out = run_gelu_reference(activations_ref) + assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index c765fd86744d..4a84add16046 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -8,72 +8,94 @@ import pytest from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp +from .inference_test_utils import allclose, get_dtypes +try: + import triton # noqa: F401 # type: ignore + from deepspeed.ops.transformer.inference.triton import ( + layer_norm, + layer_norm_residual, + ) +except ImportError: + print("triton import failed") if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - -def allclose(x, y): - assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] - return torch.allclose(x, y, rtol=rtol, atol=atol) - - -def ref_implementation(vals, gamma, beta, espilon, channels, dtype): +def ref_implementation(vals, gamma, beta, epsilon, channels, dtype): vals_f = vals.to(torch.float32) gamma_f = gamma.to(torch.float32) beta_f = beta.to(torch.float32) - return torch.nn.functional.layer_norm(vals_f, (channels, ), weight=gamma_f, bias=beta_f).to(dtype) + return torch.nn.functional.layer_norm(vals_f, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) def ds_implementation(vals, gamma, beta, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.layer_norm(vals, gamma, beta, epsilon) + return LayerNormOp()(vals, gamma, beta, epsilon) + + +def ds_triton_implementation(vals, gamma, beta, epsilon): + return layer_norm(vals, gamma, beta, epsilon) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 32]) @pytest.mark.parametrize("seq_len", [1, 128]) @pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_layer_norm(batch, seq_len, channels, dtype): +@pytest.mark.parametrize("dtype", get_dtypes()) +@pytest.mark.parametrize("use_triton_ops", [False, True]) +def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) epsilon = 1e-5 ref_output = ref_implementation(vals, gamma, beta, epsilon, channels, dtype) - new_output = ds_implementation(vals, gamma, beta, epsilon) + if use_triton_ops: + new_output = ds_triton_implementation(vals, gamma, beta, epsilon) + if dtype != torch.float16: # fp16 supported in triton + return + else: + new_output = ds_implementation(vals, gamma, beta, epsilon) - assert allclose(new_output, ref_output) + if not allclose(new_output, ref_output): + #print(new_output - ref_output) + assert allclose(new_output, ref_output) -def residual_ref_implementation(vals, bias, res, gamma, beta, espilon, channels, dtype): +def residual_ref_implementation(vals, bias, res, gamma, beta, epsilon, channels, dtype): vals_f = vals.to(torch.float32) bias_f = bias.to(torch.float32).reshape(1, 1, -1) res_f = res.to(torch.float32) gamma_f = gamma.to(torch.float32) beta_f = beta.to(torch.float32) - return torch.nn.functional.layer_norm(vals_f + bias_f + res_f, (channels, ), weight=gamma_f, bias=beta_f).to(dtype) + return torch.nn.functional.layer_norm(vals_f + bias_f + res_f, (channels, ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(dtype) def residual_ds_implementation(vals, bias, res, gamma, beta, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module._layer_norm_residual(vals, bias, res, gamma, beta, epsilon) + return LayerNormOp.layer_norm_residual(vals, bias, res, gamma, beta, epsilon) + + +def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon): + return layer_norm_residual(vals, bias, res, gamma, beta, epsilon) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 32]) @pytest.mark.parametrize("seq_len", [1, 128]) @pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_layer_norm_residual(batch, seq_len, channels, dtype): +@pytest.mark.parametrize("dtype", get_dtypes()) +@pytest.mark.parametrize("use_triton_ops", [False, True]) +def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) bias = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) @@ -81,35 +103,41 @@ def test_layer_norm_residual(batch, seq_len, channels, dtype): beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) epsilon = 1e-5 - new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon) + if use_triton_ops: + new_output = residual_ds_triton_implementation(vals, bias, residual, gamma, beta, epsilon) + if dtype != torch.float16: # fp16 supported in triton + return + else: + new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon) + ref_output = residual_ref_implementation(vals, bias, residual, gamma, beta, epsilon, channels, dtype) + print((new_output - ref_output).abs().max()) + assert allclose(new_output, ref_output) -def residual_store_ref_implementation(vals, bias, res, gamma, beta, espilon, channels, dtype): +def residual_store_ref_implementation(vals, bias, res, gamma, beta, epsilon, channels, dtype): vals_f = vals.to(torch.float32) bias_f = bias.to(torch.float32).reshape(1, 1, -1) res_f = res.to(torch.float32) gamma_f = gamma.to(torch.float32) beta_f = beta.to(torch.float32) res_output = vals_f + bias_f + res_f - norm_output = torch.nn.functional.layer_norm(res_output, (channels, ), weight=gamma_f, bias=beta_f).to(dtype) + norm_output = torch.nn.functional.layer_norm(res_output, (channels, ), weight=gamma_f, bias=beta_f, + eps=epsilon).to(dtype) return norm_output, res_output.to(dtype) def residual_store_ds_implementation(vals, bias, res, gamma, beta, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.layer_norm_residual_store_pre_ln_res(vals, bias, res, gamma, beta, epsilon) + return LayerNormOp.layer_norm_residual_store_pre_ln_res(vals, bias, res, gamma, beta, epsilon) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 32]) @pytest.mark.parametrize("seq_len", [1, 128]) @pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype): vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) @@ -126,3 +154,38 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype): assert allclose(ds_res_output, norm_res_output) assert allclose(ds_norm_output, ref_norm_output) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("M", [4]) +@pytest.mark.parametrize("N", [4]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("residual", [True, False]) +@pytest.mark.parametrize("input_bias", [True, False]) +def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + dev = get_accelerator().device_name() + torch.manual_seed(0) + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=dev, requires_grad=False) + bias = torch.rand(w_shape, dtype=dtype, device=dev, requires_grad=False) + x_bias = torch.rand(w_shape, dtype=dtype, device=dev, requires_grad=False) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=dev) + dy = .1 * torch.randn_like(x) + if residual: + res = torch.rand(x_shape, dtype=dtype, device=dev, requires_grad=False) + else: + res = torch.zeros(x_shape, dtype=dtype, device=dev, requires_grad=False) + x.requires_grad_(True) + # forward pass + if residual or input_bias: + y_tri = layer_norm_residual(x, x_bias if input_bias else None, res, weight, bias, eps) + else: + y_tri = layer_norm(x, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias, + eps).to(dtype) + # compare + assert (allclose(y_tri, y_ref)) diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py new file mode 100644 index 000000000000..11ca10baff42 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_matmul.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (5e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def run_matmul_ref(a, b): + return torch.matmul(a, b) + + +def run_matmul_ds(a, b, use_triton_ops=False): + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import matmul_4d as matmul + return matmul(a, b) + + assert use_triton_ops, "Only triton softmax is supported for now" + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("H", [1, 2, 16]) +@pytest.mark.parametrize("M", [1, 7, 8, 128]) +@pytest.mark.parametrize("K", [2, 5, 16, 128]) +@pytest.mark.parametrize("N", [1, 2, 8, 512]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("use_triton_ops", [True]) +def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + if not deepspeed.HAS_TRITON: + pytest.skip("triton is not installed") + + # skip autotune in testing + from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul + fp16_matmul.skip_autotune() + + a_ds = torch.randn((B, H, M, K), dtype=dtype, device='cuda') + b_ds = torch.randn((B, H, K, N), dtype=dtype, device='cuda') + a_ref = a_ds.clone().detach() + b_ref = b_ds.clone().detach() + + ds_out = run_matmul_ds(a_ds, b_ds, use_triton_ops) + ref_out = run_matmul_ref(a_ref, b_ref) + assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py index 79313bd68bdb..dcf9f16baaf1 100644 --- a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py +++ b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py @@ -8,35 +8,26 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer.inference.op_binding.moe_res_matmul import MoEResMatmulOp +from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - - -def allclose(x, y): - assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] - return torch.allclose(x, y, rtol=rtol, atol=atol) - def run_moe_res_matmul_reference(residual, coef1, coef2, output): return residual * coef1 + output * coef2 def run_moe_res_matmul_ds(residual, coef, output): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() coef_t = coef.transpose(-1, -2).contiguous() - return inference_module.moe_res_matmul(residual, coef_t, output) + return MoEResMatmulOp()(residual, coef_t, output) @pytest.mark.inference_ops @pytest.mark.parametrize("hidden_dim", [16, 64]) @pytest.mark.parametrize("c", [1, 4]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_moe_residual_matmul(hidden_dim, c, dtype): residual_ds = torch.randn((c, hidden_dim * c, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) coeff1 = torch.randn((1, 1, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py index f5571d33b7bc..cab64d1d0555 100644 --- a/tests/unit/ops/transformer/inference/test_residual_add.py +++ b/tests/unit/ops/transformer/inference/test_residual_add.py @@ -8,22 +8,36 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding import ResidualAddOp +from .inference_test_utils import get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) +TOLERANCES = None + + +def get_tolerances(): + global TOLERANCES + if TOLERANCES is None: + # Residual add, as a sequence of casted additions, currently requires a higher tolerance + # than the other operators for FP16. We should instead better align the behaviors + # of the reference to match our kernel implementation (TODO(cmikeh2)) + TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 4e-3)} + if get_accelerator().is_bf16_supported(): + # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs + # 10 (+1) bits) + TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2) + return TOLERANCES + def allclose(x, y): assert x.dtype == y.dtype - rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-2)}[x.dtype] + rtol, atol = get_tolerances()[x.dtype] return torch.allclose(x, y, rtol=rtol, atol=atol) -@pytest.fixture(scope="module") -def inference_module(): - return InferenceBuilder().load() - - def res_add_bias_ref(hidden_state, residual, attn_output, attn_bias, final_bias, mp_size=1, pre_attn_norm=True): if pre_attn_norm: hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size @@ -52,13 +66,16 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f @pytest.mark.parametrize("batch", [1, 2]) @pytest.mark.parametrize("sequence", [1, 128, 255]) @pytest.mark.parametrize("hidden_dim", [512, 1232, 4096]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", get_dtypes()) @pytest.mark.parametrize("mlp_after_attn", [True, False]) @pytest.mark.parametrize("add_bias", [True, False]) @pytest.mark.parametrize("mp_size", [1, 2]) @pytest.mark.parametrize("pre_attn_norm", [True, False]) -def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, - pre_attn_norm): +@pytest.mark.parametrize("use_triton_ops", [True, False]) +def test_residual_add(batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm, + use_triton_ops): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) @@ -73,11 +90,17 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_ ds_out, residual, attn_output, attn_bias, final_bias, mp_size, mlp_after_attn, add_bias, pre_attn_norm ] - if dtype == torch.float16: - ds_out = inference_module.residual_add_bias_fp16(*res_add_args) - elif dtype == torch.float32: - ds_out = inference_module.residual_add_bias_fp32(*res_add_args) + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import residual_add_bias + ds_out = residual_add_bias(*res_add_args) else: - raise ValueError(f"Unsupported dtype: {dtype}") + config = DeepSpeedInferenceConfig(dtype=dtype) + ds_out = ResidualAddOp(config).residual_add_func(*res_add_args) + + if not allclose(ds_out, ref_out): + print((ds_out - ref_out).abs().max()) + print((ds_out - ref_out).abs().mean()) + print((ds_out - ref_out)) + assert (allclose(ds_out, ref_out)) assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/ops/transformer/inference/test_rms_norm.py b/tests/unit/ops/transformer/inference/test_rms_norm.py new file mode 100644 index 000000000000..fde9c9510771 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_rms_norm.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +import pytest +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import InferenceBuilder # type: ignore +from deepspeed.ops.transformer.inference.op_binding.pre_rms_norm import PreRMSNormOp +from deepspeed.ops.transformer.inference.op_binding.rms_norm import RMSNormOp +from .inference_test_utils import allclose, get_dtypes + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +def ref_implementation(vals, gamma, epsilon): + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals + + +def ds_implementation(vals, gamma, epsilon): + return RMSNormOp()(vals, gamma, epsilon) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 32]) +@pytest.mark.parametrize("seq_len", [1, 128]) +@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_rms_norm(batch, seq_len, channels, dtype): + device = get_accelerator().current_device_name() + vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=device) + gamma = torch.randn((channels), dtype=dtype, device=device) + epsilon = 1e-5 + + ref_output = ref_implementation(vals, gamma, epsilon) + new_output = ds_implementation(vals, gamma, epsilon) + + assert allclose(new_output, ref_output) + + +def pre_ds_implementation(vals, residual, gamma, epsilon): + return PreRMSNormOp()(vals, residual, gamma, epsilon) + + +def pre_ref_implementation(vals, residual, gamma, epsilon): + residual = vals.to(torch.float32) + residual.to(torch.float32) + vals = residual + + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals, residual.to(gamma.dtype) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 32]) +@pytest.mark.parametrize("seq_len", [1, 128]) +@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_pre_norm(batch, seq_len, channels, dtype): + device = get_accelerator().current_device_name() + vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=device) + residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=device) + gamma = torch.randn((channels), dtype=dtype, device=device) + epsilon = 1e-5 + + ref_output = pre_ref_implementation(vals, residual, gamma, epsilon) + new_output = pre_ds_implementation(vals, residual, gamma, epsilon) + + assert allclose(new_output[0], ref_output[0]) + #assert allclose(new_output[1], ref_output[1]) diff --git a/tests/unit/ops/transformer/inference/test_rope.py b/tests/unit/ops/transformer/inference/test_rope.py new file mode 100644 index 000000000000..1f0ca0578e04 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_rope.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("num_heads", [64, 32, 16, 8]) +def test_rope_warp_size_alignment(num_heads): + if get_accelerator().device_name() != "cuda": + pytest.skip("This test runs only on GPU") + + batch = 1 + head = 8 + seq_len = 1024 + head_dim = 32 + rotary_dim = 32 + offset = 8 + rotate_half = False + rope_theta = 2 + + cuda0 = torch.device('cuda:0') + query = torch.randn(batch, head, seq_len, head_dim, device=cuda0) + key = torch.randn(batch, head, seq_len, head_dim, device=cuda0) + + inference = InferenceBuilder().load() + # For num_heads values of 64, 32, 16, 8 + # corresponding threads_per_head (defined in apply_rotary_pos_emb.cu) values are 4, 8, 16, 32 + inference.apply_rotary_pos_emb(query, key, rotary_dim, offset, num_heads, rotate_half, rope_theta) diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py new file mode 100644 index 000000000000..83785ac38ebb --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_softmax.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def run_softmax_reference(input): + return torch.nn.functional.softmax(input, dim=-1) + + +def run_softmax_ds(input, use_triton_ops=False): + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import softmax + # return torch.empty_like(input) + return softmax(input) + + assert use_triton_ops, "Only triton softmax is supported for now" + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 2]) +@pytest.mark.parametrize("sequence", [1, 128, 255, 1232]) +@pytest.mark.parametrize("channels", [512, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("use_triton_ops", [True]) +def test_softmax(batch, sequence, channels, dtype, use_triton_ops): + if not deepspeed.get_accelerator().is_triton_supported(): + pytest.skip("triton is not supported on this system") + + device = deepspeed.accelerator.get_accelerator().device_name() + input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device) + input_ref = input_ds.clone().detach() + + ds_out = run_softmax_ds(input_ds, use_triton_ops) + ref_out = run_softmax_reference(input_ref) + assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py index 58e5392418fa..2a8a4b9b7d82 100644 --- a/tests/unit/pipe/test_pipe_module.py +++ b/tests/unit/pipe/test_pipe_module.py @@ -34,7 +34,7 @@ def sequential_model(): @pytest.fixture def simple_config(): config_dict = { - "train_batch_size": 1, + "train_batch_size": 2, "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, "optimizer": { @@ -60,8 +60,12 @@ def batch_input(): class TestPipeModuleSequential(DistributedTest): world_size = 2 + # needs to be set for torch.compile: running torch.compile with daemonic process causes an error + non_daemonic_procs = True - def test(self, sequential_model, simple_config, batch_input): + @pytest.mark.parametrize("activation_checkpoints", [False, True]) + @pytest.mark.parametrize("use_compile", [False, True]) + def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile): base_model = copy.deepcopy(sequential_model) base_input = batch_input.clone().detach() base_output = base_model(base_input) @@ -70,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input): pipe_model = copy.deepcopy(sequential_model) pipe_model = PipelineModule(layers=pipe_model, num_stages=2) - + if (use_compile): + pipe_model.compile() # Ensure all parameters are accounted for. my_params = sum(p.numel() for p in pipe_model.parameters()) total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name()) @@ -82,6 +87,13 @@ def test(self, sequential_model, simple_config, batch_input): model=pipe_model, model_parameters=[p for p in pipe_model.parameters()]) + if activation_checkpoints: + deepspeed.checkpointing.configure(None, + deepspeed_config=pipe_model.config, + partition_activations=True, + contiguous_checkpointing=True, + num_checkpoints=9) + if pipe_model.is_first_stage or pipe_model.is_last_stage: pipe_input = base_input.clone().detach().to(get_accelerator().device_name()) # label 0 is meaningless diff --git a/tests/unit/profiling/flops_profiler/test_flops_profiler.py b/tests/unit/profiling/flops_profiler/test_flops_profiler.py index 04a63195f5a4..c72deecf287f 100644 --- a/tests/unit/profiling/flops_profiler/test_flops_profiler.py +++ b/tests/unit/profiling/flops_profiler/test_flops_profiler.py @@ -9,9 +9,13 @@ from deepspeed.profiling.flops_profiler import get_model_profile from unit.simple_model import SimpleModel, random_dataloader from unit.common import DistributedTest -from unit.util import required_minimum_torch_version +from deepspeed.utils.torch import required_torch_version +from deepspeed.accelerator import get_accelerator -pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=3), +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) + +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.3), reason='requires Pytorch version 1.3 or above') diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py index 0232457a4f9c..dd3bcd7fb6bd 100644 --- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py @@ -8,6 +8,7 @@ import pytest import torch import deepspeed +from deepspeed.pipe import PipelineModule, LayerSpec from deepspeed.accelerator import get_accelerator from copy import deepcopy from unit.common import DistributedTest @@ -62,6 +63,8 @@ def _match_outputs(ref, tgt): def _test_activation_checkpoint(module, *inputs): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") # Move to device module.to(get_accelerator().device_name()) @@ -82,6 +85,8 @@ def _test_activation_checkpoint(module, *inputs): def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") # Move to device module.to(get_accelerator().device_name()) @@ -255,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output): else: ordering += [torch.is_tensor(non_tensor_output)] _test_activation_checkpoint_ordering(module, ordering, inputs) + + +class TestCheckpointableLayersConfig(DistributedTest): + world_size = 1 + + def test_gpt2_checkpointable_layers(self): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + + # Create a simple topology for testing + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1) + + # Create test classes that we want to checkpoint + class TestTransformerLayer(torch.nn.Module): + + def forward(self, x): + return x + + class ParallelTransformerLayerPipe(TestTransformerLayer): + pass + + class GMLPBlock(TestTransformerLayer): + pass + + # Create a mock GPT2 model with different layer types + class TestGPT2ModelPipe(PipelineModule): + + def __init__(self): + self.layers_spec = [ + LayerSpec(ParallelTransformerLayerPipe), + LayerSpec(GMLPBlock), + LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed + ] + + super().__init__(layers=self.layers_spec, + topology=topo, + checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"]) + + model = TestGPT2ModelPipe() + model.to(get_accelerator().device_name()) + + # Build layers manually for testing + layers = [spec.build() for spec in model.layers_spec] + + # Test that _is_checkpointable returns correct values + assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe + assert model._is_checkpointable([layers[1]]) == True # GMLPBlock + assert model._is_checkpointable([layers[2]]) == False # Linear layer diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing_non_reentrant.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing_non_reentrant.py new file mode 100644 index 000000000000..06e40655e75d --- /dev/null +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing_non_reentrant.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# TODO: add tests with model parallelism for activation partitioning and other features. + +import sys +import torch +import pytest +from importlib import util + +from deepspeed.runtime.activation_checkpointing.checkpointing import non_reentrant_checkpoint +from unit.common import DistributedTest + +# the hack to clone the module `test_activation_checkpointing` and inject +# `non_reentrant_checkpoint` as the `ckpt` of the origin test module +ORG_SPEC = util.find_spec('test_activation_checkpointing') +test_act_ckpt = util.module_from_spec(ORG_SPEC) +ORG_SPEC.loader.exec_module(test_act_ckpt) +sys.modules['test_act_ckpt'] = test_act_ckpt +test_act_ckpt.ckpt = non_reentrant_checkpoint + +HIDDEN_DIM = test_act_ckpt.HIDDEN_DIM + +MaskedLinear = test_act_ckpt.MaskedLinear +MaskedLinearSeq = test_act_ckpt.MaskedLinearSeq +MaskedLinearSeqDup = test_act_ckpt.MaskedLinearSeqDup +DropMaskLinear = test_act_ckpt.DropMaskLinear +LinearNonTensorInput = test_act_ckpt.LinearNonTensorInput +LinearNonTensorOutput = test_act_ckpt.LinearNonTensorOutput + +_test_activation_checkpoint = test_act_ckpt._test_activation_checkpoint +_mixed_mask = test_act_ckpt._mixed_mask +_bool_to_float = test_act_ckpt._bool_to_float +_test_activation_checkpoint_ordering = test_act_ckpt._test_activation_checkpoint_ordering + + +class TestActivationCheckpointWithGrad(test_act_ckpt.TestActivationCheckpoint): + """test `non_reentrant_checkpoint` can still checkpoint activations for inputs with grad""" + pass + + +class TestCheckpointNonTensorWithGrad(test_act_ckpt.TestCheckpointNonTensor): + """test `non_reentrant_checkpoint` can still checkpoint activations for inputs with grad""" + pass + + +class TestCheckpointNonTensorOutputOrderingWithGrad(test_act_ckpt.TestCheckpointNonTensorOutputOrdering): + """test `non_reentrant_checkpoint` can still checkpoint activations for inputs with grad""" + pass + + +# below classes are used to test the graph with inputs have no grad and parameters has grad, namely partial graph? +@pytest.mark.parametrize('mask', [ + _mixed_mask(), + _bool_to_float(_mixed_mask()), +]) +class TestActivationCheckpointWithoutGrad(DistributedTest): + """test all input tensors without grad""" + world_size = 1 + + def test_ckpt_inputs1_outputs1(self, mask): + module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM) + inputs = torch.rand(HIDDEN_DIM) + _test_activation_checkpoint(module, inputs) + + def test_ckpt_inputs2_outputs1(self, mask): + module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM) + inputs = torch.rand(HIDDEN_DIM) + _test_activation_checkpoint(module, inputs, mask) + + def test_ckpt_inputs2_outputs2(self, mask): + module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM) + inputs = torch.rand(HIDDEN_DIM) + _test_activation_checkpoint(module, inputs, mask) + + def test_ckpt_inputs2_outputs3(self, mask): + module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM) + inputs = torch.rand(HIDDEN_DIM) + _test_activation_checkpoint(module, inputs, mask) + + def test_ckpt_arg_none(self, mask): + module = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM) + inputs = (torch.rand(HIDDEN_DIM), None) + _test_activation_checkpoint(module, *inputs) + + +@pytest.mark.parametrize('non_tensor', [None, 2, True, (None, 2.5), (None, True, torch.randn(HIDDEN_DIM))]) +class TestCheckpointNonTensorWithoutGrad(DistributedTest): + """test all input tensors without grad""" + world_size = 1 + + def test_ckpt_non_tensor_input(self, non_tensor): + module = LinearNonTensorInput(HIDDEN_DIM, HIDDEN_DIM) + inputs = torch.rand(HIDDEN_DIM) + _test_activation_checkpoint(module, inputs, non_tensor) + + def test_ckpt_non_tensor_output(self, non_tensor): + module = LinearNonTensorOutput(non_tensor) + inputs = torch.rand(HIDDEN_DIM) + _test_activation_checkpoint(module, inputs) + + +@pytest.mark.parametrize('non_tensor_output', [ + None, (torch.randn(HIDDEN_DIM), 2.5), (None, torch.randn(HIDDEN_DIM), True), (None, True, torch.randn(HIDDEN_DIM)) +]) +class TestCheckpointNonTensorOutputOrderingWithoutGrad(DistributedTest): + """test all input tensors without grad""" + world_size = 1 + + def test_ckpt_non_tensor_output_ordering(self, non_tensor_output): + module = LinearNonTensorOutput(non_tensor_output) + inputs = torch.rand(HIDDEN_DIM) + + # First return is a tensor + ordering = [True] + if type(non_tensor_output) in [list, tuple]: + ordering += [torch.is_tensor(t) for t in non_tensor_output] + else: + ordering += [torch.is_tensor(non_tensor_output)] + _test_activation_checkpoint_ordering(module, ordering, inputs) diff --git a/tests/unit/runtime/comm/test_coalesced_collectives.py b/tests/unit/runtime/comm/test_coalesced_collectives.py index 8e736c1eaaa6..2d5db192f2ca 100644 --- a/tests/unit/runtime/comm/test_coalesced_collectives.py +++ b/tests/unit/runtime/comm/test_coalesced_collectives.py @@ -7,9 +7,11 @@ """ import torch +import deepspeed import deepspeed.comm as dist -from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce from deepspeed.accelerator import get_accelerator +import pytest from unit.common import DistributedTest @@ -59,3 +61,101 @@ def test(self): assert torch.allclose(output, torch.zeros_like(output)) elif dist.get_rank() == 1: assert output.shape == (0, ) + + +# Currently we cannot test all_to_all_quant_reduce in non-fallback cases because we don't support multinodes tests. +class TestAllToAllQuantReduceFallback(DistributedTest): + world_size = 2 + + def test_1d_tensor(self): + # case 1: 1D tensor + input = torch.zeros((10, ), dtype=torch.half, device=get_accelerator().current_device_name()) + from deepspeed.ops.op_builder import QuantizerBuilder + if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("QuantizerBuilder is not implemented") + output = all_to_all_quant_reduce([input], {})[0] + + if dist.get_rank() == 0: + assert output.shape == (5, ) + assert torch.allclose(output, torch.zeros_like(output)) + elif dist.get_rank() == 1: + assert output.shape == (5, ) + assert torch.allclose(output, torch.zeros_like(output)) + + def test_non_divisible(self): + # case 2: tensor size not divisible by global_world_size + input = torch.zeros((7, 7), dtype=torch.half, device=get_accelerator().current_device_name()) + from deepspeed.ops.op_builder import QuantizerBuilder + if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("QuantizerBuilder is not implemented") + output = all_to_all_quant_reduce([input], {})[0] + + if dist.get_rank() == 0: + assert output.shape == (25, ) + assert torch.allclose(output, torch.zeros_like(output)) + elif dist.get_rank() == 1: + assert output.shape == (24, ) + assert torch.allclose(output, torch.zeros_like(output)) + + +class TestLocoQuantized(DistributedTest): + + world_size = 1 + + @pytest.mark.parametrize("num_bits", [4, 8]) + @pytest.mark.parametrize("tensor_size", [(16, 16), (64, 64)]) + @pytest.mark.parametrize("devices_per_node", [4, 8]) + def test_loco_quantized_reduction(self, num_bits, tensor_size, devices_per_node): + from deepspeed.ops.op_builder import QuantizerBuilder + if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]: + pytest.skip("QuantizerBuilder is not implemented") + + quantizer_module = QuantizerBuilder().load() + + tensor = torch.randn(tensor_size, device='cuda', dtype=torch.half) + + num_nodes = 2 # Fake world size + total_elements = tensor.numel() + total_devices = devices_per_node * num_nodes + num_groups = max(tensor.shape[0], tensor.shape[1], total_devices) + + # Initialize error_feedback tensor + error_feedback = torch.randn(tensor_size, device=tensor.device, dtype=tensor.dtype) + error_feedback_ori = error_feedback.clone() + # Swizzle the original tensor + tensor_reshaped = tensor.reshape(num_nodes, devices_per_node, total_elements // total_devices) + swizzled_tensor = tensor_reshaped.permute(1, 0, 2).reshape(tensor.size()) + + # Perform loco_swizzle_quant + output, scales = quantizer_module.loco_swizzle_quant(tensor, error_feedback, 0.0, num_groups, num_bits, + quantizer_module.Symmetric, 1, num_nodes, + devices_per_node) + + # Compare swizzled_tensor with the output of loco_swizzle_quant + dequantized = quantizer_module.dequantize(output, scales, scales.numel(), num_bits, + quantizer_module.Symmetric).view(tensor.size()) + + assert torch.allclose(swizzled_tensor + error_feedback_ori, dequantized + error_feedback) + + # Calculate elements per group and groups per partition + elements_per_group = total_elements // num_groups + groups_per_partition = num_groups // devices_per_node + + # Reshape dequantized data to match the grouping in loco_quantized_reduction + dequantized_reshaped = dequantized.view(devices_per_node, groups_per_partition, elements_per_group) + + # Perform reduction across devices_per_node dimension + reduced_dequantized = dequantized_reshaped.cumsum(dim=0)[-1] + # Initialize error_feedback tensor + error_feedback = torch.randn(reduced_dequantized.shape, device=tensor.device, dtype=dequantized.dtype) + error_feedback_ori = error_feedback.clone() + + # perform loco_quantized_reduction + output, scales = quantizer_module.loco_quantized_reduction(output, scales, error_feedback, 0.0, num_groups, + num_groups // devices_per_node, num_bits, + quantizer_module.Symmetric, devices_per_node) + + dequantized_reduced = quantizer_module.dequantize(output, scales, scales.numel(), num_bits, + quantizer_module.Symmetric).view(error_feedback.size()) + + assert torch.allclose(reduced_dequantized + error_feedback_ori, dequantized_reduced + error_feedback) diff --git a/tests/unit/runtime/half_precision/onebit/test_onebit.py b/tests/unit/runtime/half_precision/onebit/test_onebit.py index d3b0a90e2fa5..cc8180563db5 100644 --- a/tests/unit/runtime/half_precision/onebit/test_onebit.py +++ b/tests/unit/runtime/half_precision/onebit/test_onebit.py @@ -8,7 +8,6 @@ import deepspeed.comm as dist import deepspeed import pytest -import copy import os import numpy as np @@ -18,12 +17,12 @@ from unit.common import DistributedTest from unit.simple_model import SimpleModel, random_dataloader from unit.alexnet_model import AlexNetPipe, train_cifar -from unit.util import required_minimum_torch_version +from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator PipeTopo = PipeDataParallelTopology -if not required_minimum_torch_version(major_version=1, minor_version=8): +if not required_torch_version(min_version=1.8): pytest.skip( "NCCL-based 1-bit compression requires torch 1.8 or higher", allow_module_level=True, @@ -34,12 +33,18 @@ pytest.skip("NCCL-based 1-bit compression is not yet supported w. ROCm 5 until cupy supports ROCm 5", allow_module_level=True) +if get_accelerator().device_name() == 'hpu': + pytest.skip("1-bit compression is not supported by HPU.", allow_module_level=True) + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) class TestOneBitAdamBasic(DistributedTest): world_size = 2 def test(self, dtype): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -81,6 +86,8 @@ class TestOneBitAdamExpAvgMask(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -126,7 +133,11 @@ def test(self): model=model, model_parameters=optimizer_grouped_parameters, ) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -138,13 +149,15 @@ def test(self): v["exp_avg"], v["exp_avg"].mul_(mask1.to(device=v["exp_avg"].device)), atol=1e-07, - ), f"Momentum mask is not working properly" + ), "Momentum mask is not working properly" class TestOneBitAdamCheckpointing(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -212,17 +225,14 @@ def test(self, tmpdir): }, ] - model_1, optimizer_1, _, _ = deepspeed.initialize( - config=config_dict, - model=model, - model_parameters=optimizer_grouped_parameters_1, - ) - data_loader = random_dataloader( - model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device, - ) + model_1, optimizer_1, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model_1(batch[0], batch[1]) model_1.backward(loss) @@ -231,11 +241,11 @@ def test(self, tmpdir): assert optimizer_1.optimizer.adam_freeze_key is True mask1 = mask1.to(device=optimizer_1.param_groups[0]["exp_avg_mask"].device) assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"], mask1, - atol=1e-07), f"Incorrect momentum mask" + atol=1e-07), "Incorrect momentum mask" save_folder = os.path.join(tmpdir, "saved_checkpoint") model_1.save_checkpoint(save_folder, tag=None) assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"], mask1, - atol=1e-07), f"Momentum mask should not change after saving checkpoint" + atol=1e-07), "Momentum mask should not change after saving checkpoint" model_2, optimizer_2, _, _ = deepspeed.initialize( config=config_dict, @@ -245,7 +255,7 @@ def test(self, tmpdir): # Test whether momentum mask stays the same after loading checkpoint mask2 = mask2.to(device=optimizer_2.param_groups[0]["exp_avg_mask"].device) assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"], mask2, - atol=1e-07), f"Incorrect momentum mask" + atol=1e-07), "Incorrect momentum mask" model_2.load_checkpoint( save_folder, tag=None, @@ -253,11 +263,11 @@ def test(self, tmpdir): load_lr_scheduler_states=True, ) assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"], mask2, - atol=1e-07), f"Momentum mask should not change after loading checkpoint" + atol=1e-07), "Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset for v in optimizer_2.state.values(): - assert "worker_error" not in v, f"Incorrect worker error" - assert "server_error" not in v, f"Incorrect server error" + assert "worker_error" not in v, "Incorrect worker error" + assert "server_error" not in v, "Incorrect server error" assert optimizer_2.optimizer.adam_freeze_key is True model_3, optimizer_3, _, _ = deepspeed.initialize( @@ -266,19 +276,18 @@ def test(self, tmpdir): model_parameters=optimizer_grouped_parameters_3, ) optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader( - model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device, - ) + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model_3(batch[0], batch[1]) model_3.backward(loss) model_3.step() assert optimizer_3.optimizer.adam_freeze_key is True # Test whether momentum mask stays the same after loading checkpoint - assert ("exp_avg_mask" not in optimizer_3.param_groups[0]), f"Incorrect momentum mask" + assert ("exp_avg_mask" not in optimizer_3.param_groups[0]), "Incorrect momentum mask" model_3.load_checkpoint( save_folder, tag=None, @@ -286,14 +295,16 @@ def test(self, tmpdir): load_lr_scheduler_states=True, ) assert ("exp_avg_mask" - not in optimizer_3.param_groups[0]), f"Momentum mask should not change after loading checkpoint" + not in optimizer_3.param_groups[0]), "Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset for v in optimizer_3.state.values(): - assert "worker_error" not in v, f"Incorrect worker error" - assert "server_error" not in v, f"Incorrect server error" + assert "worker_error" not in v, "Incorrect worker error" + assert "server_error" not in v, "Incorrect server error" assert optimizer_3.optimizer.adam_freeze_key is False def test_overflow(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -318,7 +329,11 @@ def test_overflow(self, tmpdir): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=100, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) save_folder = os.path.join(tmpdir, "saved_checkpoint") for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) @@ -334,27 +349,21 @@ def test_overflow(self, tmpdir): @pytest.mark.parametrize( "topo_config", [ - { - "num_pp": 1, - "num_dp": 4 - }, { "num_pp": 2, "num_dp": 2 }, - { - "num_pp": 4, - "num_dp": 1 - }, ], ) class TestOneBitAdamFP16Pipeline(DistributedTest): world_size = 4 def test(self, topo_config): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, + "train_batch_size": 4, + "grandient_accumulation_steps": 1, "steps_per_print": 20, "optimizer": { "type": "OneBitAdam", @@ -384,20 +393,12 @@ def test(self, topo_config): } topo = PipeTopo(**topo_config) - steps = 500 # Must be >=100 + steps = 100 - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() - - test_net = copy.deepcopy(init_net) + # TODO: Add correctness tests/asserts comparing with baseline? + test_net = AlexNetPipe() test_model = PipelineModule(layers=test_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar( - test_model, - config=config_dict, - num_steps=steps, - fp16=config_dict["fp16"]["enabled"], - ) + test_losses = train_cifar(test_model, config=config_dict, num_steps=steps, fp16=config_dict['fp16']['enabled']) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @@ -405,6 +406,8 @@ class TestZeroOneAdamBasic(DistributedTest): world_size = 2 def test(self, dtype): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -449,6 +452,8 @@ class TestZeroOneAdamExpAvgMask(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -497,7 +502,11 @@ def test(self): model=model, model_parameters=optimizer_grouped_parameters, ) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -509,13 +518,15 @@ def test(self): v["exp_avg"], v["exp_avg"].mul_(mask1.to(device=v["exp_avg"].device)), atol=1e-07, - ), f"Momentum mask is not working properly" + ), "Momentum mask is not working properly" class TestZeroOneAdamCheckpointing(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -591,12 +602,11 @@ def test(self, tmpdir): model=model, model_parameters=optimizer_grouped_parameters_1, ) - data_loader = random_dataloader( - model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device, - ) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model_1(batch[0], batch[1]) model_1.backward(loss) @@ -604,11 +614,11 @@ def test(self, tmpdir): # Test whether momentum mask still exist after saving checkpoint mask1 = mask1.to(device=optimizer_1.param_groups[0]["exp_avg_mask"].device) assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"], mask1, - atol=1e-07), f"Incorrect momentum mask" + atol=1e-07), "Incorrect momentum mask" save_folder = os.path.join(tmpdir, "saved_checkpoint") model_1.save_checkpoint(save_folder, tag=None) assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"], mask1, - atol=1e-07), f"Momentum mask should not change after saving checkpoint" + atol=1e-07), "Momentum mask should not change after saving checkpoint" model_2, optimizer_2, _, _ = deepspeed.initialize( config=config_dict, @@ -618,7 +628,7 @@ def test(self, tmpdir): # Test whether momentum mask stays the same after loading checkpoint mask2 = mask2.to(device=optimizer_2.param_groups[0]["exp_avg_mask"].device) assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"], mask2, - atol=1e-07), f"Incorrect momentum mask" + atol=1e-07), "Incorrect momentum mask" model_2.load_checkpoint( save_folder, tag=None, @@ -626,11 +636,11 @@ def test(self, tmpdir): load_lr_scheduler_states=True, ) assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"], mask2, - atol=1e-07), f"Momentum mask should not change after loading checkpoint" + atol=1e-07), "Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset for v in optimizer_2.state.values(): - assert "worker_error" not in v, f"Incorrect worker error" - assert "server_error" not in v, f"Incorrect server error" + assert "worker_error" not in v, "Incorrect worker error" + assert "server_error" not in v, "Incorrect server error" model_3, optimizer_3, _, _ = deepspeed.initialize( config=config_dict, @@ -638,18 +648,17 @@ def test(self, tmpdir): model_parameters=optimizer_grouped_parameters_3, ) optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader( - model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device, - ) + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model_3(batch[0], batch[1]) model_3.backward(loss) model_3.step() # Test whether momentum mask stays the same after loading checkpoint - assert ("exp_avg_mask" not in optimizer_3.param_groups[0]), f"Incorrect momentum mask" + assert ("exp_avg_mask" not in optimizer_3.param_groups[0]), "Incorrect momentum mask" model_3.load_checkpoint( save_folder, tag=None, @@ -657,13 +666,15 @@ def test(self, tmpdir): load_lr_scheduler_states=True, ) assert ("exp_avg_mask" - not in optimizer_3.param_groups[0]), f"Momentum mask should not change after loading checkpoint" + not in optimizer_3.param_groups[0]), "Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset for v in optimizer_3.state.values(): - assert "worker_error" not in v, f"Incorrect worker error" - assert "server_error" not in v, f"Incorrect server error" + assert "worker_error" not in v, "Incorrect worker error" + assert "server_error" not in v, "Incorrect server error" def test_overflow(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -691,7 +702,11 @@ def test_overflow(self, tmpdir): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=100, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) save_folder = os.path.join(tmpdir, "saved_checkpoint") for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) @@ -707,27 +722,21 @@ def test_overflow(self, tmpdir): @pytest.mark.parametrize( "topo_config", [ - { - "num_pp": 1, - "num_dp": 4 - }, { "num_pp": 2, "num_dp": 2 }, - { - "num_pp": 4, - "num_dp": 1 - }, ], ) class TestZeroOneAdamFP16Pipeline(DistributedTest): world_size = 4 def test(self, topo_config): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, + "train_batch_size": 4, + "grandient_accumulation_steps": 1, "steps_per_print": 20, "optimizer": { "type": "ZeroOneAdam", @@ -760,20 +769,12 @@ def test(self, topo_config): } topo = PipeTopo(**topo_config) - steps = 500 # Must be >=100 - - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() + steps = 100 - test_net = copy.deepcopy(init_net) + # TODO: Add correctness tests/asserts comparing with baseline? + test_net = AlexNetPipe() test_model = PipelineModule(layers=test_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar( - test_model, - config=config_dict, - num_steps=steps, - fp16=config_dict["fp16"]["enabled"], - ) + test_losses = train_cifar(test_model, config=config_dict, num_steps=steps, fp16=config_dict['fp16']['enabled']) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @@ -781,6 +782,8 @@ class TestOneBitLambBasic(DistributedTest): world_size = 2 def test(self, dtype): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -828,6 +831,8 @@ class TestOneBitLampExpAvgMask(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -878,7 +883,11 @@ def test(self): model=model, model_parameters=optimizer_grouped_parameters, ) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -890,13 +899,15 @@ def test(self): v["exp_avg"], v["exp_avg"].mul_(mask1.to(device=v["exp_avg"].device)), atol=1e-07, - ), f"Momentum mask is not working properly" + ), "Momentum mask is not working properly" class TestOneBitLambCheckpointing(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -973,12 +984,11 @@ def test(self, tmpdir): model=model, model_parameters=optimizer_grouped_parameters_1, ) - data_loader = random_dataloader( - model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device, - ) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model_1(batch[0], batch[1]) model_1.backward(loss) @@ -987,15 +997,15 @@ def test(self, tmpdir): assert optimizer_1.optimizer.lamb_freeze_key is True mask1 = mask1.to(device=optimizer_1.param_groups[0]["exp_avg_mask"].device) assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"], mask1, - atol=1e-07), f"Incorrect momentum mask" + atol=1e-07), "Incorrect momentum mask" scaling_coeff_1 = [] for v in optimizer_1.state.values(): - assert "scaling_coeff" in v, f"Incorrect scaling_coeff" + assert "scaling_coeff" in v, "Incorrect scaling_coeff" scaling_coeff_1.append(v["scaling_coeff"]) save_folder = os.path.join(tmpdir, "saved_checkpoint") model_1.save_checkpoint(save_folder, tag=None) assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"], mask1, - atol=1e-07), f"Momentum mask should not change after saving checkpoint" + atol=1e-07), "Momentum mask should not change after saving checkpoint" model_2, optimizer_2, _, _ = deepspeed.initialize( config=config_dict, @@ -1005,7 +1015,7 @@ def test(self, tmpdir): # Test whether momentum mask stays the same after loading checkpoint mask2 = mask2.to(device=optimizer_2.param_groups[0]["exp_avg_mask"].device) assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"], mask2, - atol=1e-07), f"Incorrect momentum mask" + atol=1e-07), "Incorrect momentum mask" model_2.load_checkpoint( save_folder, tag=None, @@ -1013,16 +1023,16 @@ def test(self, tmpdir): load_lr_scheduler_states=True, ) assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"], mask2, - atol=1e-07), f"Momentum mask should not change after loading checkpoint" + atol=1e-07), "Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset - assert len(optimizer_2.optimizer.worker_errors) == 0, f"Incorrect worker error" - assert len(optimizer_2.optimizer.server_errors) == 0, f"Incorrect server error" + assert len(optimizer_2.optimizer.worker_errors) == 0, "Incorrect worker error" + assert len(optimizer_2.optimizer.server_errors) == 0, "Incorrect server error" # Test whether scaling_coeffs is loaded correctly scaling_coeff_2 = [] for v in optimizer_2.state.values(): - assert "scaling_coeff" in v, f"Incorrect scaling_coeff" + assert "scaling_coeff" in v, "Incorrect scaling_coeff" scaling_coeff_2.append(v["scaling_coeff"]) - assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs" + assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), "Incorrect scaling_coeffs" assert optimizer_2.optimizer.lamb_freeze_key is True model_3, optimizer_3, _, _ = deepspeed.initialize( @@ -1031,19 +1041,18 @@ def test(self, tmpdir): model_parameters=optimizer_grouped_parameters_3, ) optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader( - model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device, - ) + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model_3(batch[0], batch[1]) model_3.backward(loss) model_3.step() assert optimizer_3.optimizer.lamb_freeze_key is True # Test whether momentum mask stays the same after loading checkpoint - assert ("exp_avg_mask" not in optimizer_3.param_groups[0]), f"Incorrect momentum mask" + assert ("exp_avg_mask" not in optimizer_3.param_groups[0]), "Incorrect momentum mask" model_3.load_checkpoint( save_folder, tag=None, @@ -1051,18 +1060,20 @@ def test(self, tmpdir): load_lr_scheduler_states=True, ) assert ("exp_avg_mask" - not in optimizer_3.param_groups[0]), f"Momentum mask should not change after loading checkpoint" + not in optimizer_3.param_groups[0]), "Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset - assert len(optimizer_3.optimizer.worker_errors) == 0, f"Incorrect worker error" - assert len(optimizer_3.optimizer.server_errors) == 0, f"Incorrect server error" + assert len(optimizer_3.optimizer.worker_errors) == 0, "Incorrect worker error" + assert len(optimizer_3.optimizer.server_errors) == 0, "Incorrect server error" # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are reset for v in optimizer_3.state.values(): - assert v["lamb_coeff_freeze"] == 0.0, f"Incorrect lamb_coeff_freeze" - assert v["last_factor"] == 1.0, f"Incorrect last_factor" - assert "scaling_coeff" not in v, f"Incorrect scaling_coeff" + assert v["lamb_coeff_freeze"] == 0.0, "Incorrect lamb_coeff_freeze" + assert v["last_factor"] == 1.0, "Incorrect last_factor" + assert "scaling_coeff" not in v, "Incorrect scaling_coeff" assert optimizer_3.optimizer.lamb_freeze_key is False def test_overflow(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -1093,7 +1104,11 @@ def test_overflow(self, tmpdir): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=100, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) save_folder = os.path.join(tmpdir, "saved_checkpoint") for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) @@ -1109,27 +1124,21 @@ def test_overflow(self, tmpdir): @pytest.mark.parametrize( "topo_config", [ - { - "num_pp": 1, - "num_dp": 4 - }, { "num_pp": 2, "num_dp": 2 }, - { - "num_pp": 4, - "num_dp": 1 - }, ], ) class TestOneBitLambFP16Pipeline(DistributedTest): world_size = 4 def test(self, topo_config): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, + "train_batch_size": 4, + "grandient_accumulation_steps": 1, "steps_per_print": 20, "optimizer": { "type": "OneBitLamb", @@ -1159,20 +1168,12 @@ def test(self, topo_config): } topo = PipeTopo(**topo_config) - steps = 500 # Must be >=100 + steps = 100 - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() - - test_net = copy.deepcopy(init_net) + # TODO: Add correctness tests/asserts comparing with baseline? + test_net = AlexNetPipe() test_model = PipelineModule(layers=test_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar( - test_model, - config=config_dict, - num_steps=steps, - fp16=config_dict["fp16"]["enabled"], - ) + test_losses = train_cifar(test_model, config=config_dict, num_steps=steps, fp16=config_dict['fp16']['enabled']) @pytest.mark.sequential @@ -1180,6 +1181,8 @@ class TestCompressedAllReduceBasic(DistributedTest): world_size = 2 def test(self, tmpdir): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") from deepspeed.runtime.comm.nccl import NcclBackend size = dist.get_world_size() @@ -1241,3 +1244,95 @@ def torch_sim(a): if torch.sum(check_mag_mask) != 0: print("Fails at {} of positions".format(torch.sum(check_mag_mask))) assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0 + + +class TestOneBitLambEmptyParameters(DistributedTest): + world_size = 2 + + def test(self): + """Test that OnebitLamb correctly filters out empty parameters (numel=0)""" + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + # Create a model with normal and empty parameters + class ModelWithEmptyParam(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + # Empty parameter (0 elements) + self.empty_param = torch.nn.Parameter(torch.empty(0, 10)) + + def forward(self, x, y): + return self.cross_entropy_loss(self.linear(x), y) + + model = ModelWithEmptyParam() + model.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + # Create parameter groups including empty parameter + param_groups = [ + { + 'params': [model.linear.weight, model.linear.bias], + 'weight_decay': 0.01 + }, + { + 'params': [model.empty_param], + 'weight_decay': 0.0 + } # Empty parameter + ] + + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": get_accelerator().communication_backend_name(), + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1, + }, + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16, + }, + } + + # Verify empty parameter is filtered out + model, optimizer, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=param_groups, + ) + + # Check that empty parameter is not in optimizer param_groups + for group in optimizer.optimizer.param_groups: + for p in group['params']: + assert p.numel() > 0, "Empty parameters should be filtered out" + + # Run a few training steps to ensure no NaN + data_loader = random_dataloader( + model=model, + total_samples=20, + hidden_dim=10, + device=model.device, + dtype=torch.float16, + ) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Verify no NaN in parameters + for group in optimizer.optimizer.param_groups: + for p in group['params']: + assert not torch.isnan(p).any(), "Parameters should not contain NaN" diff --git a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py index 2a58fd6b4a57..0daf195b3fe2 100644 --- a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py +++ b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py @@ -5,9 +5,12 @@ import torch import deepspeed +from deepspeed.accelerator import get_accelerator +import pytest import numpy as np from unit.common import DistributedTest from unit.simple_model import SimpleModel +from deepspeed.ops.op_builder import FusedLambBuilder def run_model_step(model, gradient_list): @@ -22,6 +25,9 @@ class TestFused(DistributedTest): world_size = 1 def test_no_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -45,18 +51,20 @@ def test_no_overflow(self): expected_loss_scale = 2**8 expected_scale_window = 2 # Ensure the dynamic loss scaler is correctly configured. - assert optim.dynamic_loss_scale == True - assert optim.cur_scale == expected_loss_scale - assert optim.scale_window == expected_scale_window + assert optim.loss_scale_config.dynamic_loss_scale == True + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.scale_window == expected_scale_window for i, value in enumerate(np.random.uniform(-0.1, 0.1, 10)): run_model_step(model, [value]) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == (i + 1) - if optim.cur_iter % expected_scale_window == 0: + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == (i + 1) + if optim.loss_scale_config.cur_iter % expected_scale_window == 0: expected_loss_scale *= 2 def test_all_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -79,17 +87,19 @@ def test_all_overflow(self): expected_loss_scale = 2**4 # Ensure the dynamic loss scaler is correctly configured. - assert optim.dynamic_loss_scale == True - assert optim.cur_scale == expected_loss_scale + assert optim.loss_scale_config.dynamic_loss_scale == True + assert optim.loss_scale_config.cur_scale == expected_loss_scale overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 for i, value in enumerate(overflow_gradients): run_model_step(model, [value]) expected_loss_scale = max(expected_loss_scale / 2, 1) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == (i + 1) + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == (i + 1) def test_some_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -114,39 +124,43 @@ def test_some_overflow(self): expected_scale_window = 2 expected_iteration = 0 # Ensure the dynamic loss scaler is correctly configured. - assert optim.dynamic_loss_scale == True - assert optim.cur_scale == expected_loss_scale - assert optim.scale_window == expected_scale_window + assert optim.loss_scale_config.dynamic_loss_scale == True + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.scale_window == expected_scale_window # Run model with overflows to decrease scale overflow_gradients = [float('inf'), float('nan')] expected_iteration += len(overflow_gradients) run_model_step(model, overflow_gradients) expected_loss_scale /= (2**len(overflow_gradients)) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == expected_iteration + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == expected_iteration # Run model scale_window + 1 times to increase scale once normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1) expected_iteration += len(normal_gradients) run_model_step(model, normal_gradients) expected_loss_scale *= 2 - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == expected_iteration + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == expected_iteration # Run model with overflows to decrease scale overflow_gradients = [float('inf')] expected_iteration += len(overflow_gradients) run_model_step(model, overflow_gradients) expected_loss_scale /= (2**len(overflow_gradients)) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == expected_iteration + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == expected_iteration +@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") class TestUnfused(DistributedTest): world_size = 1 def test_no_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -169,18 +183,23 @@ def test_no_overflow(self): expected_loss_scale = 2**8 expected_scale_window = 2 # Ensure the dynamic loss scaler is correctly configured. - assert optim.dynamic_loss_scale == True - assert optim.cur_scale == expected_loss_scale - assert optim.scale_window == expected_scale_window + assert optim.loss_scale_config.dynamic_loss_scale == True + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.scale_window == expected_scale_window for i, value in enumerate(np.random.uniform(-0.1, 0.1, 10)): run_model_step(model, [value]) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == (i + 1) - if optim.cur_iter % expected_scale_window == 0: + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == (i + 1) + if optim.loss_scale_config.cur_iter % expected_scale_window == 0: expected_loss_scale *= 2 def test_all_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + min_loss_scale_value = 2.0 + config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -195,7 +214,7 @@ def test_all_overflow(self): "loss_scale": 0, "initial_scale_power": 4, "loss_scale_window": 2, - "min_loss_scale": 0.25 + "min_loss_scale": min_loss_scale_value } } hidden_dim = 1 @@ -203,20 +222,22 @@ def test_all_overflow(self): model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) expected_loss_scale = 2**4 - expected_min_loss_scale = 0.25 + expected_min_loss_scale = min_loss_scale_value # Ensure the dynamic loss scaler is correctly configured. - assert optim.dynamic_loss_scale == True - assert optim.cur_scale == expected_loss_scale - assert optim.min_loss_scale == expected_min_loss_scale + assert optim.loss_scale_config.dynamic_loss_scale == True + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.min_loss_scale == expected_min_loss_scale overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 for i, value in enumerate(overflow_gradients): run_model_step(model, [value]) expected_loss_scale = max(expected_loss_scale / 2, expected_min_loss_scale) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == (i + 1) + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == (i + 1) def test_some_overflow(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -241,30 +262,30 @@ def test_some_overflow(self): expected_scale_window = 2 expected_iteration = 0 # Ensure the dynamic loss scaler is correctly configured. - assert optim.dynamic_loss_scale == True - assert optim.cur_scale == expected_loss_scale - assert optim.scale_window == expected_scale_window + assert optim.loss_scale_config.dynamic_loss_scale == True + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.scale_window == expected_scale_window # Run model with overflows to decrease scale overflow_gradients = [float('inf'), float('nan')] expected_iteration += len(overflow_gradients) run_model_step(model, overflow_gradients) expected_loss_scale /= (2**len(overflow_gradients)) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == expected_iteration + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == expected_iteration # Run model scale_window + 1 times to increase scale once normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1) expected_iteration += len(normal_gradients) run_model_step(model, normal_gradients) expected_loss_scale *= 2 - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == expected_iteration + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == expected_iteration # Run model with overflows to decrease scale overflow_gradients = [float('inf')] expected_iteration += len(overflow_gradients) run_model_step(model, overflow_gradients) expected_loss_scale /= (2**len(overflow_gradients)) - assert optim.cur_scale == expected_loss_scale - assert optim.cur_iter == expected_iteration + assert optim.loss_scale_config.cur_scale == expected_loss_scale + assert optim.loss_scale_config.cur_iter == expected_iteration diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py index 6d88af00078a..ac4cefed1db6 100644 --- a/tests/unit/runtime/half_precision/test_fp16.py +++ b/tests/unit/runtime/half_precision/test_fp16.py @@ -10,22 +10,23 @@ from deepspeed.ops.adam import FusedAdam from unit.common import DistributedTest from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, SimpleMoEModel, sequence_dataloader -from unit.util import required_torch_version +from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedLambBuilder +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer -try: - from apex import amp # noqa: F401 - _amp_available = True -except ImportError: - _amp_available = False -amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed") +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) class TestLambFP32GradClip(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -55,7 +56,11 @@ def test(self): class TestLambFP16(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test__basic(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -74,13 +79,21 @@ def test__basic(self): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test_empty_grad(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -99,7 +112,11 @@ def test_empty_grad(self): model = SimpleModel(hidden_dim, empty_grad=True) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -143,13 +160,19 @@ class TestAdamwFP16Basic(DistributedTest): world_size = 1 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}} hidden_dim = 10 model = SimpleModel(hidden_dim) optimizer = torch.optim.AdamW(params=model.parameters()) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -160,7 +183,9 @@ class TestFP16OptimizerForMoE(DistributedTest): world_size = 2 def test_unfused_gradnorm(self, monkeypatch): - if not required_torch_version(): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = {"train_batch_size": 2, "steps_per_print": 1, "fp16": {"enabled": True}} @@ -181,14 +206,20 @@ def mock_unscale_and_clip_grads(total_norm, apply_scale=True): optimizer=optimizer, dist_init_required=False) monkeypatch.setattr(optimizer, 'unscale_and_clip_grads', mock_unscale_and_clip_grads) - data_loader = sequence_dataloader(model=engine, total_samples=50, hidden_dim=hidden_dim, device=engine.device) + data_loader = sequence_dataloader(model=engine, + total_samples=50, + hidden_dim=hidden_dim, + device=engine.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = engine(batch[0], batch[1]) engine.backward(loss) engine.step() def test_fused_gradnorm(self, monkeypatch): - if not required_torch_version(): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = {"train_batch_size": 2, "steps_per_print": 1, "fp16": {"enabled": True}} @@ -203,22 +234,32 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True) # initialize MoE model = SimpleMoEModel(hidden_dim, ep_size=2) + param_group = {'params': [p for p in model.parameters()], 'name': 'random-unique-name'} + params = split_params_into_different_moe_groups_for_optimizer(param_group) # optimizer = torch.optim.AdamW(params=model.parameters()) - optimizer = FusedAdam(params=model.parameters()) + optimizer = FusedAdam(params=params) engine, optimizer, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer, dist_init_required=False) monkeypatch.setattr(optimizer, 'unscale_and_clip_grads', mock_unscale_and_clip_grads) - data_loader = sequence_dataloader(model=engine, total_samples=50, hidden_dim=hidden_dim, device=engine.device) + data_loader = sequence_dataloader(model=engine, + total_samples=50, + hidden_dim=hidden_dim, + device=engine.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = engine(batch[0], batch[1]) engine.backward(loss) engine.step() @pytest.mark.parametrize("fused_lamb_legacy", [(False), (True)]) + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + reason="FusedLambBuilder has not been implemented on this system.") def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool): - if not required_torch_version(): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + if not required_torch_version(min_version=1.8): pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly") config_dict = { @@ -251,7 +292,11 @@ def mock_unscale_and_clip_grads(total_norm, apply_scale=True): dist_init_required=False) monkeypatch.setattr(optimizer, 'unscale_and_clip_grads', mock_unscale_and_clip_grads) optimizer.fused_lamb_legacy = fused_lamb_legacy - data_loader = sequence_dataloader(model=engine, total_samples=50, hidden_dim=hidden_dim, device=engine.device) + data_loader = sequence_dataloader(model=engine, + total_samples=50, + hidden_dim=hidden_dim, + device=engine.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = engine(batch[0], batch[1]) engine.backward(loss) @@ -262,13 +307,19 @@ class TestAdamwFP16EmptyGrad(DistributedTest): world_size = 1 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}} hidden_dim = 10 model = SimpleModel(hidden_dim) optimizer = torch.optim.AdamW(params=model.parameters()) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -281,6 +332,8 @@ class TestAdamFP16ZeroOneCycleCompatibility(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -319,20 +372,27 @@ def test(self, zero_stage, use_cpu_offload): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @pytest.mark.parametrize("use_cpu_offload", [True, False]) -@pytest.mark.parametrize("hidden_dim", [9, 10]) class TestZeroStaticScale(DistributedTest): world_size = 1 - def test(self, zero_stage, use_cpu_offload, hidden_dim): + def test(self, zero_stage, use_cpu_offload, hidden_dim=4): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -363,12 +423,18 @@ def test(self, zero_stage, use_cpu_offload, hidden_dim): assert optim.loss_scaler.loss_scale == 138. # Now make sure things work.. - data_loader = random_dataloader(model=model, total_samples=10, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @pytest.mark.parametrize("use_cpu_offload", [True, False]) @@ -376,6 +442,8 @@ class TestZeroAllowUntestedOptimizer(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -401,6 +469,7 @@ def test(self, zero_stage, use_cpu_offload): model=model, optimizer=optimizer, model_parameters=model.parameters()) + model.destroy() @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @@ -409,6 +478,8 @@ class TestZeroEmptyPartition(DistributedTest): world_size = 3 def test(self, zero_stage, use_cpu_offload): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") @@ -443,106 +514,17 @@ def test(self, zero_stage, use_cpu_offload): model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) # Now make sure things work.. - data_loader = random_dataloader(model=model, total_samples=1, hidden_dim=hidden_dim, device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - -@amp_available -class TestAmp(DistributedTest): - world_size = 2 - - def test_adam_basic(self): - config_dict = {"train_batch_size": 2, "steps_per_print": 1, "amp": {"enabled": True}} - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - optimizer = torch.optim.Adam(params=model.parameters()) - model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - def test_lamb_basic(self): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "Lamb", - "params": { - "lr": 0.00015 - } - }, - "gradient_clipping": 1.0, - "amp": { - "enabled": True, - } - } - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - def test_adam_O2(self): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, - "gradient_clipping": 1.0, - "amp": { - "enabled": True, - "opt_level": "O2" - } - } - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=1, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() - def test_adam_O2_empty_grad(self): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, - "gradient_clipping": 1.0, - "amp": { - "enabled": True, - "opt_level": "O2" - } - } - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() + model.destroy() @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @@ -551,6 +533,8 @@ class TestZeroSupportedClientOptimizer(DistributedTest): world_size = 1 def test(self, zero_stage, optimizer_constructor): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -566,12 +550,15 @@ def test(self, zero_stage, optimizer_constructor): model = SimpleModel(hidden_dim) client_optimizer = optimizer_constructor(params=model.parameters()) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=client_optimizer) + model.destroy() class TestZero2ReduceScatterOff(DistributedTest): world_size = 2 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -598,7 +585,11 @@ def test(self): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) @@ -611,6 +602,8 @@ class TestFP16AdamTypes(DistributedTest): world_size = 1 def test(self, adam_type, torch_impl): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -631,7 +624,11 @@ def test(self, adam_type, torch_impl): model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=10, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for _, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) @@ -643,6 +640,8 @@ class TestZero3LazyScatter(DistributedTest): world_size = 1 def test(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -663,21 +662,33 @@ def test(self): hidden_dim = 10 model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + model, _, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + ) - data_loader = random_dataloader(model=model, total_samples=10, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for _, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize('stage', [1, 2, 3]) class TestZeroEmptyGrad(DistributedTest): world_size = 1 def test(self, stage): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -693,8 +704,14 @@ def test(self, stage): model = SimpleModel(hidden_dim) optimizer = torch.optim.Adam(model.parameters()) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer) - data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + model.destroy() diff --git a/tests/unit/runtime/half_precision/test_fp8.py b/tests/unit/runtime/half_precision/test_fp8.py new file mode 100644 index 000000000000..a54833824bac --- /dev/null +++ b/tests/unit/runtime/half_precision/test_fp8.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +import pytest +from unit.common import DistributedTest, is_rocm_pytorch +from unit.util import skip_on_arch + +try: + import transformer_engine.pytorch as transformer_engine + from transformer_engine.common import recipe +except ImportError: + pytest.skip("Transformer Engine package is missing, skipping tests", allow_module_level=True) + + +@pytest.mark.parametrize("base_datatype", ["fp16", "bf16", "fp32"]) +class TestFp8ComposabilityAcrossZero(DistributedTest): + world_size = 1 + + def test(self, base_datatype): + skip_on_arch(min_arch=9) + + def run_zero(stage, model_dtype): + num_batches = 128 + batch_size = 16 + hidden_dim = 768 + # Have to set seed before model + torch.random.manual_seed(42) + enable_fp16 = model_dtype == torch.float16 + enable_bf16 = model_dtype == torch.bfloat16 + # TransformerEngine Model + model = transformer_engine.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=model_dtype) + + # Create FP8 recipe. Note: All input args are optional. + fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max") + config = { + "train_batch_size": batch_size, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00001 + } + }, + "zero_optimization": { + "stage": stage + }, + "fp16": { + "enabled": enable_fp16, + "loss_scale": 0.1 + }, + "bf16": { + "enabled": enable_bf16 + } + } + # Init DeepSpeed + model, optimizer, _, _ = deepspeed.initialize(args=None, + model=model, + model_parameters=model.parameters(), + config=config) + + batches = torch.randn(num_batches, batch_size, hidden_dim, device=model.device, dtype=model_dtype) + for batch in batches: + # Enables autocasting for the forward pass + with transformer_engine.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + out = model(batch) + loss = out.mean() + model.backward(loss) + model.step() + return loss + + if base_datatype == "fp16": + model_dtype = torch.float16 + elif base_datatype == "bf16": + model_dtype = torch.bfloat16 + else: + model_dtype = torch.float32 + + # Set default tolerances + rtol, atol = 1e-07, 1e-05 + + # Relax tolerance only for ROCm + FP16 + if is_rocm_pytorch() and base_datatype in ["fp16", "bf16"]: + rtol, atol = 1e-07, 1e-04 + + # config + zero_stage = [0, 1, 2, 3] + losses = [] + for stage in zero_stage: + loss = run_zero(stage, model_dtype) + losses.append(loss) + all_equal = all(torch.allclose(loss, losses[0], rtol, atol) for loss in losses) + assert (all_equal) diff --git a/tests/unit/runtime/half_precision/test_zero_optim_overflow.py b/tests/unit/runtime/half_precision/test_zero_optim_overflow.py new file mode 100644 index 000000000000..62995fbe104d --- /dev/null +++ b/tests/unit/runtime/half_precision/test_zero_optim_overflow.py @@ -0,0 +1,365 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +from deepspeed.accelerator import get_accelerator +import pytest +import numpy as np +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.utils import safe_set_full_grad + + +def has_inf_or_nan(x): + float_x = x.float() + nan = float_x.isnan() + inf = float_x.isinf() + inf_or_nan = nan.logical_or(inf) + return inf_or_nan.float().max() + + +def run_model_step(model, x_sample, y_label, grad_value): + loss = model(x_sample, y_label) + model.backward(loss) + for p in model.parameters(): + grad = torch.empty_like(p, dtype=p.dtype) + grad.fill_(grad_value) + safe_set_full_grad(p, grad) + model.step() + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("offload_optimizer", [False, True]) +class TestZeROFloat16(DistributedTest): + world_size = 2 + + def test_no_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 8, + "loss_scale_window": 2 + }, + "zero_optimization": { + "stage": zero_stage + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + expected_loss_scale = 2**8 + expected_scale_window = 2 + # Ensure the dynamic loss scaler is correctly configured. + loss_scaler = optim.loss_scaler + + assert optim.dynamic_loss_scale == True + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.scale_window == expected_scale_window + + num_iterations = 10 + grad_values = np.random.uniform(-0.1, 0.1, num_iterations) + data_loader = random_dataloader(model=model, + total_samples=num_iterations, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for i, (batch, grad_value) in enumerate(zip(data_loader, grad_values)): + run_model_step(model, batch[0], batch[1], grad_value) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == (i + 1) + + if loss_scaler.cur_iter % expected_scale_window == 0: + expected_loss_scale *= 2 + + def test_all_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 + initial_scale_power = len(overflow_gradients) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": initial_scale_power, + "loss_scale_window": 2, + "hysteresis": 1, + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + expected_loss_scale = 2**initial_scale_power + expected_scale_window = 2 + # Ensure the dynamic loss scaler is correctly configured. + loss_scaler = optim.loss_scaler + + assert optim.dynamic_loss_scale == True + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.scale_window == expected_scale_window + + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for i, (batch, grad_value) in enumerate(zip(data_loader, overflow_gradients)): + run_model_step(model, batch[0], batch[1], grad_value) + expected_loss_scale = max(expected_loss_scale / 2, 1) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == (i + 1) + + def test_some_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + initial_scale_power = 8 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": initial_scale_power, + "loss_scale_window": 2, + "hysteresis": 1, + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + expected_loss_scale = 2**initial_scale_power + expected_scale_window = 2 + # Ensure the dynamic loss scaler is correctly configured. + loss_scaler = optim.loss_scaler + + assert optim.dynamic_loss_scale == True + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.scale_window == expected_scale_window + + expected_iteration = 0 + + # Run model with overflows to decrease scale + overflow_gradients = [float('inf'), float('nan')] + expected_iteration += len(overflow_gradients) + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for batch, grad_value in zip(data_loader, overflow_gradients): + run_model_step(model, batch[0], batch[1], grad_value) + + expected_loss_scale /= (2**len(overflow_gradients)) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == expected_iteration + + # Run model scale_window + 1 times to increase scale once + normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1) + expected_iteration += len(normal_gradients) + data_loader = random_dataloader(model=model, + total_samples=len(normal_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for batch, grad_value in zip(data_loader, normal_gradients): + run_model_step(model, batch[0], batch[1], grad_value) + + expected_loss_scale *= 2 + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == expected_iteration + + # Run model with overflows to decrease scale + overflow_gradients = [float('inf')] + expected_iteration += len(overflow_gradients) + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + for batch, grad_value in zip(data_loader, overflow_gradients): + run_model_step(model, batch[0], batch[1], grad_value) + + expected_loss_scale /= (2**len(overflow_gradients)) + assert loss_scaler.cur_scale == expected_loss_scale + assert loss_scaler.cur_iter == expected_iteration + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("offload_optimizer", [False, True]) +class TestZeROBFloat16(DistributedTest): + world_size = 2 + + def test_no_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "bf16": { + "enabled": True, + }, + "zero_optimization": { + "stage": zero_stage + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + num_iterations = 10 + grad_values = np.random.uniform(-0.1, 0.1, num_iterations) + data_loader = random_dataloader(model=model, + total_samples=num_iterations, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for i, (batch, grad_value) in enumerate(zip(data_loader, grad_values)): + run_model_step(model, batch[0], batch[1], grad_value) + + assert model.skipped_steps == 0 + assert all([not has_inf_or_nan(p) for p in model.parameters()]) + + def test_detect_grad_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "bf16": { + "enabled": True, + "check_grad_overflow": True + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + + for i, (batch, grad_value) in enumerate(zip(data_loader, overflow_gradients)): + run_model_step(model, batch[0], batch[1], grad_value) + assert model.skipped_steps == (i + 1) + + assert all([not has_inf_or_nan(p) for p in model.parameters()]) + + def test_ignore_grad_overflow(self, zero_stage, offload_optimizer): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "bf16": { + "enabled": True, + "check_grad_overflow": False + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + model, optim, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + + overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6 + data_loader = random_dataloader(model=model, + total_samples=len(overflow_gradients), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + + for i, (batch, grad_value) in enumerate(zip(data_loader, overflow_gradients)): + run_model_step(model, batch[0], batch[1], grad_value) + + assert model.skipped_steps == 0 + assert all([has_inf_or_nan(p) for p in model.parameters()]) diff --git a/tests/unit/runtime/pipe/test_pipe.py b/tests/unit/runtime/pipe/test_pipe.py index c4958b721f2c..f198762c5fcc 100644 --- a/tests/unit/runtime/pipe/test_pipe.py +++ b/tests/unit/runtime/pipe/test_pipe.py @@ -7,15 +7,43 @@ import torch.nn as nn import pytest +import torch + +import deepspeed import deepspeed.comm as dist from deepspeed.runtime.pipe.topology import PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule from unit.alexnet_model import AlexNetPipe, train_cifar from unit.common import DistributedTest -from unit.util import skip_on_arch +from unit.util import skip_on_arch, no_child_process_in_deepspeed_io PipeTopo = PipeDataParallelTopology +config_dict = { + "train_batch_size": 4, + "grandient_accumulation_steps": 1, + "steps_per_print": 20, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7 + } + }, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": False + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } +} + def rel_diff(A, B): return abs(A - B) / abs(A) @@ -38,36 +66,10 @@ def rel_diff(A, B): class TestPipeCifar10(DistributedTest): world_size = 4 - def test(self, topo_config): + def test_pipe_base(self, topo_config): skip_on_arch(min_arch=7) - - config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, - "steps_per_print": 20, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [0.9, 0.999], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "zero_optimization": { - "stage": 0 - }, - "fp16": { - "enabled": False - }, - "pipeline": { - "seed_layers": True, - "activation_checkpoint_interval": 1 - } - } - topo = PipeTopo(**topo_config) - steps = 500 # must be >=100 + steps = 100 # must be >=100 # Allocate model for consistent initial weights. init_net = AlexNetPipe() @@ -103,3 +105,148 @@ def test(self, topo_config): test = test_losses[-lastX:] test_avg = sum(test) / len(test) assert rel_diff(base_avg, test_avg) < 0.05 # Originally 0.03, but seeing instability with AMD results + + # def _check_model_params_equal(self, model1, model2): + # for p1, p2 in zip(model1.parameters(), model2.parameters()): + # if p1.data.ne(p2.data).sum() > 0: + # assert False, f"model params not equal" + + def test_pipe_use_reentrant(self, topo_config): + skip_on_arch(min_arch=7) + + topo = PipeTopo(**topo_config) + steps = 100 # must be >=100 + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + # Train with not set use_reentrant, default: True + base_net = copy.deepcopy(init_net) + base_model = PipelineModule(layers=base_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss()) + base_losses = train_cifar(base_model, config=config_dict, num_steps=steps, fp16=config_dict['fp16']['enabled']) + + # Train with set use_reentrant=False, this will use ``non_reentrant_checkpoint`` + test_config_dict = copy.deepcopy(config_dict) + test_config_dict['pipeline']['use_reentrant'] = False + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss()) + test_losses = train_cifar(test_model, + config=test_config_dict, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + abs_diffs = [l0 - l1 for l0, l1 in zip(base_losses, test_losses)] + rel_diffs = [rel_diff(l0, l1) for l0, l1 in zip(base_losses, test_losses)] + if dist.get_rank() == 0: + print(f'abs min={min(abs_diffs)} max={max(abs_diffs)} avg={sum(abs_diffs)/len(abs_diffs)}') + print(f'rel min={min(rel_diffs)} max={max(rel_diffs)} avg={sum(rel_diffs)/len(rel_diffs)}') + print(f'first: base={base_losses[0]} test={test_losses[0]} abs={abs_diffs[0]} rel={rel_diffs[0]}') + + for lastX in [1, 10, 100]: + base_avg = sum(base_losses[-lastX:]) / lastX + test_avg = sum(test_losses[-lastX:]) / lastX + print( + f'last-{lastX}: base={base_avg} test={test_avg} abs={base_avg - test_avg} rel={rel_diff(base_avg, test_avg)}' + ) + lastX = 100 + base = base_losses[-lastX:] + base_avg = sum(base) / len(base) + test = test_losses[-lastX:] + test_avg = sum(test) / len(test) + assert rel_diff(base_avg, test_avg) < 0.05 + + # the following check could passed on higher version docker: nvcr.io/nvidia/pytorch:23.07-py3(torch2.1.0 cuda12.1) + # Check if models have same weights after training + # self._check_model_params_equal(base_model, test_model) + + +class DynamicShapeTestLayer(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.fc = nn.Linear(hidden_size, hidden_size) + self.shapes = set() + + def forward(self, x): + self.shapes.add(x.shape) + y = self.fc(x) + return y + + +class DynamicShapeTestModel(nn.Module): + + def __init__(self, n_layers, hidden_size): + super().__init__() + self.layers = nn.ModuleList([DynamicShapeTestLayer(hidden_size) for _ in range(n_layers)]) + + +@pytest.mark.parametrize('topo_config', [ + { + "num_pp": 1, + "num_dp": 4 + }, + { + "num_pp": 2, + "num_dp": 2 + }, + { + "num_pp": 4, + "num_dp": 1 + }, +]) +class TestPipeDynamicShape(DistributedTest): + world_size = 4 + + def test_pipe_base(self, topo_config): + """This test checks if the pipeline engine can handle dynamic shapes correctly. + We pass inputs of different shapes to the pipeline engine. + """ + + n_iter = 10 + n_layers = 4 + n_samples = 1024 + batch_size = 4 + channel_dims = [8, 16, 32, 64] + hidden_size = 16 + + topo = PipeTopo(**topo_config) + + model = DynamicShapeTestModel(n_layers, hidden_size) + model = PipelineModule(layers=model.layers, topology=topo, loss_fn=nn.MSELoss(), dynamic_shape=True) + + # Each batch has different channel dim but we use the same channel dim in the same batch + xs = [ + torch.randn(channel_dims[(i // batch_size) % len(channel_dims)], hidden_size, dtype=torch.float32) + for i in range(n_samples) + ] + ys = [torch.randn_like(x) for x in xs] + + class CustomDataset(torch.utils.data.Dataset): + + def __init__(self, xs, ys): + self.xs = xs + self.ys = ys + + def __len__(self): + return len(self.xs) + + def __getitem__(self, idx): + return self.xs[idx], self.ys[idx] + + dataset = CustomDataset(xs, ys) + + config_dict["train_batch_size"] = batch_size + + with no_child_process_in_deepspeed_io(): + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=[p for p in model.parameters()], + training_data=dataset) + + for _ in range(n_iter): + _ = engine.train_batch() + + # Check if all layers have seen different shapes + for layer in model.modules(): + if isinstance(layer, DynamicShapeTestLayer): + assert len(layer.shapes) > 1 diff --git a/tests/unit/runtime/sparse_tensor/test_averaging_sparse_gradients.py b/tests/unit/runtime/sparse_tensor/test_averaging_sparse_gradients.py index 92da2257bdb0..badd0bcee549 100644 --- a/tests/unit/runtime/sparse_tensor/test_averaging_sparse_gradients.py +++ b/tests/unit/runtime/sparse_tensor/test_averaging_sparse_gradients.py @@ -4,9 +4,14 @@ # DeepSpeed Team import torch +import pytest import deepspeed from unit.common import DistributedTest from unit.util import skip_on_arch +from deepspeed.accelerator import get_accelerator + +if get_accelerator().device_name() == 'hpu': + pytest.skip("sparse_gradients not supported by HPU.", allow_module_level=True) class Model(torch.nn.Module): diff --git a/tests/unit/runtime/sparse_tensor/test_sparse_grads.py b/tests/unit/runtime/sparse_tensor/test_sparse_grads.py index 0689adc08670..6338a16b8dbb 100644 --- a/tests/unit/runtime/sparse_tensor/test_sparse_grads.py +++ b/tests/unit/runtime/sparse_tensor/test_sparse_grads.py @@ -4,11 +4,15 @@ # DeepSpeed Team import torch +import pytest import deepspeed from unit.common import DistributedTest - +from deepspeed.accelerator import get_accelerator import deepspeed.utils.groups as groups +if get_accelerator().device_name() == 'hpu': + pytest.skip("sparse_gradients not supported by HPU.", allow_module_level=True) + class Model(torch.nn.Module): diff --git a/tests/unit/runtime/tensor_parallel/test_autotp_universal_checkpoint.py b/tests/unit/runtime/tensor_parallel/test_autotp_universal_checkpoint.py new file mode 100644 index 000000000000..f0f7315562f2 --- /dev/null +++ b/tests/unit/runtime/tensor_parallel/test_autotp_universal_checkpoint.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.checkpoint.constants import (PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, PARAMETER_WITH_SUB_PARAMS, + TP_REPLICATED_PARAMETER_PATTERNS, DS_AUTOTP_UC_META) +from deepspeed.module_inject.layers import (_build_param_uc_restore_meta, _get_param_uc_conversion_meta, + LinearAllreduce, LinearLayer, SubParamLinearLayer, + collect_autotp_universal_checkpoint_info) + + +def test_collect_autotp_universal_checkpoint_info_row_parallel(): + layer = LinearAllreduce(torch.nn.Linear(16, 8, bias=True), mp_group=None, name="proj") + model = torch.nn.Module() + model.proj = layer + + uc_info = collect_autotp_universal_checkpoint_info(model) + + # collect_autotp_universal_checkpoint_info() stores regex patterns like r"^proj\.weight$" + assert r"^proj\.weight$" in uc_info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] + # bias in LinearAllreduce is marked replicated, so it should appear in replicated patterns + assert r"^proj\.bias$" in uc_info[TP_REPLICATED_PARAMETER_PATTERNS] + + +def test_collect_autotp_universal_checkpoint_info_subparams(): + layer = SubParamLinearLayer(torch.nn.Linear(12, 12, bias=True), + mp_group=None, + shape=(3, -1), + partition_dim=0, + name="qkv") + model = torch.nn.Module() + model.qkv = layer + + uc_info = collect_autotp_universal_checkpoint_info(model) + + assert len(uc_info[PARAMETER_WITH_SUB_PARAMS]) == 1 + assert uc_info[PARAMETER_WITH_SUB_PARAMS][0]["partition_dim"] == 0 + + +def test_collect_autotp_universal_checkpoint_info_column_parallel_bias_not_replicated(): + layer = LinearLayer(torch.nn.Linear(16, 8, bias=True), mp_group=None, name="dense") + model = torch.nn.Module() + model.dense = layer + + uc_info = collect_autotp_universal_checkpoint_info(model) + + assert not any("dense.weight" in p for p in uc_info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS]) + assert not any("dense.bias" in p for p in uc_info[TP_REPLICATED_PARAMETER_PATTERNS]) + + +def test_collect_autotp_universal_checkpoint_info_subparams_preserves_shape_metadata(): + layer = SubParamLinearLayer(torch.nn.Linear(12, 12, bias=True), + mp_group=None, + shape=((2, 10), 12), + partition_dim=0, + name="fused") + model = torch.nn.Module() + model.fused = layer + + uc_info = collect_autotp_universal_checkpoint_info(model) + + assert uc_info[PARAMETER_WITH_SUB_PARAMS][0]["shape"] == [(2, 10), 12] + + +def test_subparam_layer_marks_standardized_param_metadata(): + layer = SubParamLinearLayer(torch.nn.Linear(12, 12, bias=True), + mp_group=None, + shape=(3, -1), + partition_dim=0, + name="packed") + + weight_meta = getattr(layer.weight, DS_AUTOTP_UC_META) + bias_meta = getattr(layer.bias, DS_AUTOTP_UC_META) + + assert weight_meta["sub_param_sizes"] == (4, 4, 4) + assert tuple(weight_meta["target_partition_shape"]) == tuple(layer.weight.shape) + assert tuple(bias_meta["target_partition_shape"]) == tuple(layer.bias.shape) + + +def test_universal_checkpoint_info_excludes_param_level_recovery_fields(): + layer = SubParamLinearLayer(torch.nn.Linear(12, 12, bias=True), + mp_group=None, + shape=(3, -1), + partition_dim=0, + name="packed") + model = torch.nn.Module() + model.packed = layer + + uc_info = collect_autotp_universal_checkpoint_info(model) + subparam_entry = uc_info[PARAMETER_WITH_SUB_PARAMS][0] + + assert "shape" in subparam_entry + assert "partition_dim" in subparam_entry + assert "patterns" in subparam_entry + assert "sub_param_sizes" not in subparam_entry + assert "target_partition_shape" not in subparam_entry + + +def test_collect_uses_conversion_view_not_recovery_fields(): + layer = SubParamLinearLayer(torch.nn.Linear(12, 12, bias=True), + mp_group=None, + shape=(3, -1), + partition_dim=0, + name="packed") + model = torch.nn.Module() + model.packed = layer + + meta = getattr(layer.weight, "ds_autotp_universal_checkpoint_meta") + meta["partition_dim"] = 99 + meta["sub_param_shape"] = (999, -1) + + uc_info = collect_autotp_universal_checkpoint_info(model) + subparam_entry = uc_info[PARAMETER_WITH_SUB_PARAMS][0] + + assert subparam_entry["partition_dim"] == 0 + assert subparam_entry["shape"] == [3, -1] + + +def test_param_uc_restore_builder_normalizes_shapes_and_nests_conversion_view(): + restore_meta = _build_param_uc_restore_meta(partition_type="column", + partition_dim=0, + logical_shape=[12, 8], + output_shape=[12], + sub_param_shape=[3, -1], + sub_param_sizes=[4, 4, 4], + target_partition_shape=torch.Size([4, 8]), + original_shape=torch.Size([12, 8]), + is_bias=False, + replicated=False) + + assert restore_meta["logical_shape"] == (12, 8) + assert restore_meta["output_shape"] == (12, ) + assert restore_meta["sub_param_shape"] == (3, -1) + assert restore_meta["sub_param_sizes"] == (4, 4, 4) + assert restore_meta["target_partition_shape"] == (4, 8) + assert restore_meta["original_shape"] == (12, 8) + assert restore_meta["conversion"] == { + "partition_type": "column", + "partition_dim": 0, + "sub_param_shape": (3, -1), + "original_shape": (12, 8), + "is_bias": False, + "replicated": False, + } + + +def test_conversion_helper_reads_builder_nested_view(): + param = torch.nn.Parameter(torch.zeros(4, 8)) + param.ds_autotp_universal_checkpoint_meta = _build_param_uc_restore_meta(partition_type="row", + partition_dim=1, + logical_shape=[4, 16], + output_shape=[4], + original_shape=[4, 16], + is_bias=False, + replicated=False) + + assert _get_param_uc_conversion_meta(param) == param.ds_autotp_universal_checkpoint_meta["conversion"] diff --git a/tests/unit/runtime/test_autocast.py b/tests/unit/runtime/test_autocast.py index 9176770afda7..21ffc9bfbb4d 100644 --- a/tests/unit/runtime/test_autocast.py +++ b/tests/unit/runtime/test_autocast.py @@ -3,10 +3,14 @@ # DeepSpeed Team +import functools + import pytest import torch +import deepspeed.runtime.zero.linear as zero_linear from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3 from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version from unit.common import DistributedTest @@ -26,8 +30,6 @@ def test_missing_amp_autocast(self, half_op): assert output.dtype == ds_linear.weight.dtype def test_disable_autocast_linear(self, half_op): - amp = get_accelerator().amp() - hidden_dim = 4 if half_op: input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half() @@ -36,18 +38,15 @@ def test_disable_autocast_linear(self, half_op): input = torch.randn(hidden_dim).to(get_accelerator().device_name()) ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()) - with amp.autocast(False): + with torch.amp.autocast(device_type=get_accelerator().device_name(), enabled=False): output = ds_linear(input) assert output.dtype == ds_linear.weight.dtype -@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed') @pytest.mark.parametrize('half_input, half_weight', [(False, False), (False, True), (True, False), (True, True)]) class TestAutoCastEnable(DistributedTest): def test_autocast_linear(self, tmpdir, half_input, half_weight): - amp = get_accelerator().amp() - hidden_dim = 4 input = torch.randn(hidden_dim).to(get_accelerator().device_name()) ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()) @@ -58,6 +57,33 @@ def test_autocast_linear(self, tmpdir, half_input, half_weight): if half_weight: ds_linear = ds_linear.half() - with amp.autocast(): + with torch.amp.autocast(device_type=get_accelerator().device_name()): output = ds_linear(input) assert output.dtype == torch.half or output.dtype == torch.bfloat16 + + +def test_get_autocast_decorators_use_torch_amp_on_torch_2_4_or_newer(): + if not required_torch_version(min_version=2.4): + pytest.skip('torch.amp.custom_fwd/custom_bwd are only available on torch >= 2.4') + + device_type = get_accelerator().device_name() + + assert isinstance(zero_linear.autocast_custom_fwd, functools.partial) + assert isinstance(zero_linear.autocast_custom_bwd, functools.partial) + assert zero_linear.autocast_custom_fwd.func is torch.amp.custom_fwd + assert zero_linear.autocast_custom_bwd.func is torch.amp.custom_bwd + assert zero_linear.autocast_custom_fwd.keywords == {'device_type': device_type} + assert zero_linear.autocast_custom_bwd.keywords == {'device_type': device_type} + + +def test_get_autocast_decorators_use_legacy_amp_or_noop_before_torch_2_4(): + if required_torch_version(min_version=2.4): + pytest.skip('legacy AMP fallback only applies on torch < 2.4') + + device_type = get_accelerator().device_name() + legacy_amp = getattr(getattr(torch, device_type, None), 'amp', None) + expected_custom_fwd = getattr(legacy_amp, 'custom_fwd', zero_linear.noop_decorator) + expected_custom_bwd = getattr(legacy_amp, 'custom_bwd', zero_linear.noop_decorator) + + assert zero_linear.autocast_custom_fwd is expected_custom_fwd + assert zero_linear.autocast_custom_bwd is expected_custom_bwd diff --git a/tests/unit/runtime/test_data.py b/tests/unit/runtime/test_data.py index 8f71ca979b4d..7ae0814c823a 100644 --- a/tests/unit/runtime/test_data.py +++ b/tests/unit/runtime/test_data.py @@ -42,6 +42,7 @@ def test(self, train_batch_size, drop_last): model=model, training_data=train_dataset, optimizer=optimizer) + training_dataloader.num_local_io_workers = 0 # We can't do nested mp.pool for n, batch in enumerate(training_dataloader): x = batch[0].to(get_accelerator().current_device_name()) y = batch[1].to(get_accelerator().current_device_name()) diff --git a/tests/unit/runtime/test_data_efficiency.py b/tests/unit/runtime/test_data_efficiency.py index b9bd9c3aa56e..61a91c9abbe9 100644 --- a/tests/unit/runtime/test_data_efficiency.py +++ b/tests/unit/runtime/test_data_efficiency.py @@ -7,6 +7,7 @@ import os import deepspeed from deepspeed.accelerator import get_accelerator +import pytest from unit.common import DistributedTest from unit.simple_model import Curriculum_SimpleModel, SimpleModel, random_dataloader, random_dataset @@ -49,10 +50,16 @@ def get_model_parallel_group(self): return self.tp_group +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) class TestDataEfficiency(DistributedTest): world_size = 2 - def test_curriculum_learning(self): + def test_curriculum_learning(self, dtype): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"This test does not support {dtype=}.") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -64,11 +71,6 @@ def test_curriculum_learning(self): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, "data_efficiency": { "enabled": True, "seed": 1234, @@ -99,13 +101,18 @@ def test_curriculum_learning(self): } } + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 8} + else: + config_dict["bf16"] = {"enabled": True} + def data_post_process(data, data_sampler_state_dict): assert 'dummy_metric' in data_sampler_state_dict['current_difficulties'] return data hidden_dim = 10 model = SimpleModel(hidden_dim) - dataset = random_dataset(20, hidden_dim, torch.device('cpu'), dtype=torch.half) + dataset = random_dataset(20, hidden_dim, torch.device('cpu'), dtype=dtype) model, _, data_loader, _ = deepspeed.initialize(config=config_dict, model=model, training_data=dataset, @@ -124,10 +131,16 @@ def data_post_process(data, data_sampler_state_dict): break +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) class TestLegacyCurriculumScheduler(DistributedTest): world_size = 2 - def test_fixed_discrete(self): + def test_fixed_discrete(self, dtype): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"This test does not support {dtype=}.") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -139,11 +152,6 @@ def test_fixed_discrete(self): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, "curriculum_learning": { "enabled": True, "curriculum_type": "seqlen", @@ -156,12 +164,20 @@ def test_fixed_discrete(self): } } } + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 8} + else: + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 ground_truths = {1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4} model = Curriculum_SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) for n, batch in enumerate(data_loader): loss, seqlen = model(batch[0], batch[1]) model.backward(loss) @@ -169,9 +185,14 @@ def test_fixed_discrete(self): true_seqlen = 5 if n + 1 in ground_truths: true_seqlen = ground_truths[n + 1] - assert seqlen == true_seqlen, f"Incorrect curriculum schedule" + assert seqlen == true_seqlen, f"Incorrect curriculum schedule {n=}, {seqlen=}, {true_seqlen=}" + + def test_fixed_linear(self, dtype): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"This test does not support {dtype=}.") - def test_fixed_linear(self): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -183,11 +204,6 @@ def test_fixed_linear(self): } }, "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, "curriculum_learning": { "enabled": True, "curriculum_type": "seqlen", @@ -200,16 +216,24 @@ def test_fixed_linear(self): } } } + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 8} + else: + config_dict["bf16"] = {"enabled": True} hidden_dim = 10 ground_truths = {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10} model = Curriculum_SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) for n, batch in enumerate(data_loader): loss, seqlen = model(batch[0], batch[1]) model.backward(loss) model.step() if n + 1 in ground_truths: true_seqlen = ground_truths[n + 1] - assert seqlen == true_seqlen, f"Incorrect curriculum schedule" + assert seqlen == true_seqlen, "Incorrect curriculum schedule" diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py index 6cd01644fad5..789818b05bef 100644 --- a/tests/unit/runtime/test_ds_config_dict.py +++ b/tests/unit/runtime/test_ds_config_dict.py @@ -9,6 +9,7 @@ import json import hjson import argparse +import torch from deepspeed.runtime.zero.config import DeepSpeedZeroConfig from deepspeed.accelerator import get_accelerator @@ -19,7 +20,8 @@ # A test on its own import deepspeed -from deepspeed.runtime.config import DeepSpeedConfig, get_bfloat16_enabled +from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.runtime.precision_config import get_bfloat16_config class TestBasicConfig(DistributedTest): @@ -47,9 +49,6 @@ def base_config(): "lr": 0.00015 } }, - "fp16": { - "enabled": True - } } return config_dict @@ -70,13 +69,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success): if not success: assert not status - print("Failed but All is well") return assert ds_config.train_batch_size == batch assert ds_config.train_micro_batch_size_per_gpu == micro_batch assert ds_config.gradient_accumulation_steps == gas - print("All is well") #Tests different batch config provided in deepspeed json file @@ -90,7 +87,7 @@ class TestBatchConfig(DistributedTest): def test(self, num_ranks, batch, micro_batch, gas, success): assert dist.get_world_size() == num_ranks, \ - 'The test assumes a world size of f{num_ranks}' + f'The test assumes a world size of {num_ranks}' ds_batch_config = get_test_path('ds_batch_config.json') ds_config = DeepSpeedConfig(ds_batch_config) @@ -156,18 +153,26 @@ def test_get_bfloat16_enabled(bf16_key): "enabled": True, }, } - assert get_bfloat16_enabled(cfg) == True + assert get_bfloat16_config(cfg).enabled == True class TestConfigLoad(DistributedTest): world_size = 1 def test_dict(self, base_config): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=base_config, model=model, model_parameters=model.parameters()) def test_json(self, base_config, tmpdir): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} config_path = os.path.join(tmpdir, "config.json") with open(config_path, 'w') as fp: json.dump(base_config, fp) @@ -176,6 +181,10 @@ def test_json(self, base_config, tmpdir): model, _, _, _ = deepspeed.initialize(config=config_path, model=model, model_parameters=model.parameters()) def test_hjson(self, base_config, tmpdir): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} config_path = os.path.join(tmpdir, "config.json") with open(config_path, 'w') as fp: hjson.dump(base_config, fp) @@ -188,6 +197,10 @@ class TestDeprecatedDeepScaleConfig(DistributedTest): world_size = 1 def test(self, base_config, tmpdir): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} config_path = create_config_from_dict(tmpdir, base_config) parser = argparse.ArgumentParser() args = parser.parse_args(args='') @@ -209,6 +222,10 @@ class TestDistInit(DistributedTest): world_size = 1 def test(self, base_config): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -227,13 +244,27 @@ class TestInitNoOptimizer(DistributedTest): world_size = 1 def test(self, base_config): + if get_accelerator().device_name() == "cpu": + pytest.skip("This test timesout with CPU accelerator") + + # XXX: the bf16 path w/ no optimizer needs to be fixed + # if get_accelerator().is_bf16_supported(): + # base_config["bf16"] = {"enabled": True} + dtype = torch.float + if get_accelerator().is_fp16_supported(): + dtype = torch.float16 + base_config["fp16"] = {"enabled": True} + del base_config["optimizer"] hidden_dim = 10 model = SimpleModel(hidden_dim=hidden_dim) - model, _, _, _ = deepspeed.initialize(config=base_config, model=model) - data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=5, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) with pytest.raises(AssertionError): @@ -246,6 +277,10 @@ class TestArgs(DistributedTest): world_size = 1 def test_none_args(self, base_config): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} model = SimpleModel(hidden_dim=10) model, _, _, _ = deepspeed.initialize(args=None, model=model, config=base_config) data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device) @@ -253,6 +288,10 @@ def test_none_args(self, base_config): loss = model(batch[0], batch[1]) def test_no_args(self, base_config): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} model = SimpleModel(hidden_dim=10) model, _, _, _ = deepspeed.initialize(model=model, config=base_config) data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device) @@ -264,6 +303,10 @@ class TestNoModel(DistributedTest): world_size = 1 def test(self, base_config): + if get_accelerator().is_bf16_supported(): + base_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + base_config["fp16"] = {"enabled": True} model = SimpleModel(hidden_dim=10) with pytest.raises(AssertionError): model, _, _, _ = deepspeed.initialize(model=None, config=base_config) diff --git a/tests/unit/runtime/test_ds_config_model.py b/tests/unit/runtime/test_ds_config_model.py index b9c67c9a30dd..4d184b2858a8 100644 --- a/tests/unit/runtime/test_ds_config_model.py +++ b/tests/unit/runtime/test_ds_config_model.py @@ -4,18 +4,25 @@ # DeepSpeed Team import pytest -import os import json +import os +from typing import List, Optional + from pydantic import Field, ValidationError -from typing import List + from deepspeed.runtime import config as ds_config from deepspeed.runtime.config_utils import DeepSpeedConfigModel class SimpleConf(DeepSpeedConfigModel): param_1: int = 0 - param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x])) - param_2: List[str] = None + param_2_old: Optional[str] = Field(None, + json_schema_extra={ + "deprecated": True, + "new_param": "param_2", + "new_param_fn": (lambda x: [x]) + }) + param_2: Optional[List[str]] = None param_3: int = Field(0, alias="param_3_alias") diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py index 4ff64dea96ef..80fd622d5534 100644 --- a/tests/unit/runtime/test_ds_initialize.py +++ b/tests/unit/runtime/test_ds_initialize.py @@ -11,13 +11,24 @@ from unit.simple_model import SimpleModel, random_dataloader from unit.common import DistributedTest -from unit.util import required_torch_version, bf16_required_version_check, required_amp_check +from unit.util import bf16_required_version_check, required_amp_check import deepspeed from deepspeed.ops.adam import FusedAdam from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR from deepspeed.runtime.config import ADAM_OPTIMIZER +from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.utils import see_memory_usage +from deepspeed.utils.torch import required_torch_version +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import FusedAdamBuilder + + +# Ensure client multiprocessing is not broken by deepspeed import +@pytest.mark.parametrize('method', ['spawn', 'fork', 'forkserver']) +def test_start_method_safety(method): + import torch.multiprocessing as mp + mp.set_start_method(method, force=True) @pytest.mark.parametrize('zero_stage', [0, 3]) @@ -25,14 +36,11 @@ class TestNoOptim(DistributedTest): world_size = 1 def test(self, zero_stage): - if zero_stage == 3 and not required_torch_version(): + if zero_stage == 3 and not required_torch_version(min_version=1.8): pytest.skip("zero-3 param offload requires at least torch 1.8") ds_config = { 'train_batch_size': self.world_size, - 'fp16': { - 'enabled': True - }, 'zero_optimization': { "stage": zero_stage, "offload_param": { @@ -40,6 +48,10 @@ def test(self, zero_stage): } } } + if get_accelerator().is_bf16_supported(): + ds_config["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + ds_config["fp16"] = {"enabled": True} # 20B test #hidden_dim = 16 * 1024 hidden_dim = 4 @@ -49,11 +61,7 @@ def test(self, zero_stage): see_memory_usage('pre-init', force=True) model, _, _, _ = deepspeed.initialize(model=model, config=ds_config) see_memory_usage('post-init', force=True) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device, - dtype=torch.half) + data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) for batch in data_loader: model(batch[0], batch[1]) see_memory_usage('post-fwds', force=True) @@ -68,6 +76,9 @@ def test(self, optimizer_type): def _optimizer_callable(params) -> Optimizer: return AdamW(params=params) + if (optimizer_type is None) and (not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]): + pytest.skip("FusedAdam is not compatible") + hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -96,6 +107,8 @@ def _optimizer_callable(params) -> Optimizer: class TestConfigOptimizer(DistributedTest): world_size = 1 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test(self, client_parameters): ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}} @@ -112,22 +125,28 @@ def test(self, client_parameters): assert isinstance(ds_optimizer, FusedAdam) -@pytest.mark.parametrize('optimizer_extension', ['zero1', 'zero2', 'amp', None]) +@pytest.mark.parametrize('optimizer_extension', ['zero1', 'zero2', 'zero3', 'amp', None]) @pytest.mark.parametrize('model_dtype', ['fp16', 'bf16', 'fp32']) @pytest.mark.parametrize('grad_accum_dtype', [None, 'fp16', 'bf16', 'fp32']) class TestOptimizerImplementation(DistributedTest): world_size = 1 + reuse_dist_env = True def test(self, optimizer_extension, model_dtype, grad_accum_dtype): + if not get_accelerator().is_fp16_supported(): + if model_dtype == 'fp16' or grad_accum_dtype == 'fp16': + pytest.skip("fp16 is not supported") if optimizer_extension == 'zero1': zero_stage = 1 elif optimizer_extension == 'zero2': zero_stage = 2 + elif optimizer_extension == 'zero3': + zero_stage = 3 else: zero_stage = 0 - amp = True if optimizer_extension == 'amp' else False - fp16 = True if model_dtype == 'fp16' else False - bf16 = True if model_dtype == 'bf16' else False + amp = (optimizer_extension == 'amp') + fp16 = (model_dtype == 'fp16') + bf16 = (model_dtype == 'bf16') # Skip checks if bf16 and not bf16_required_version_check(): pytest.skip( @@ -168,18 +187,42 @@ def test(self, optimizer_extension, model_dtype, grad_accum_dtype): # ZeRO 1 Wrapper is_supported[('zero1', 'fp16', None)] = True is_supported[('zero1', 'fp16', 'fp16')] = True + is_supported[('zero1', 'fp16', 'bf16')] = True + is_supported[('zero1', 'fp16', 'fp32')] = True is_supported[('zero1', 'bf16', None)] = True + is_supported[('zero1', 'bf16', 'fp16')] = True is_supported[('zero1', 'bf16', 'bf16')] = True is_supported[('zero1', 'bf16', 'fp32')] = True is_supported[('zero1', 'fp32', None)] = True + is_supported[('zero1', 'fp32', 'fp16')] = True + is_supported[('zero1', 'fp32', 'bf16')] = True is_supported[('zero1', 'fp32', 'fp32')] = True # ZeRO 2 Wrapper is_supported[('zero2', 'fp16', None)] = True is_supported[('zero2', 'fp16', 'fp16')] = True + is_supported[('zero2', 'fp16', 'bf16')] = True + is_supported[('zero2', 'fp16', 'fp32')] = True is_supported[('zero2', 'bf16', None)] = True + is_supported[('zero2', 'bf16', 'fp16')] = True is_supported[('zero2', 'bf16', 'bf16')] = True + is_supported[('zero2', 'bf16', 'fp32')] = True is_supported[('zero2', 'fp32', None)] = True + is_supported[('zero2', 'fp32', 'fp16')] = True + is_supported[('zero2', 'fp32', 'bf16')] = True is_supported[('zero2', 'fp32', 'fp32')] = True + # ZeRO 3 Wrapper + is_supported[('zero3', 'fp16', None)] = True + is_supported[('zero3', 'fp16', 'fp16')] = True + is_supported[('zero3', 'fp16', 'bf16')] = True + is_supported[('zero3', 'fp16', 'fp32')] = True + is_supported[('zero3', 'bf16', None)] = True + is_supported[('zero3', 'bf16', 'fp16')] = True + is_supported[('zero3', 'bf16', 'bf16')] = True + is_supported[('zero3', 'bf16', 'fp32')] = True + is_supported[('zero3', 'fp32', None)] = True + is_supported[('zero3', 'fp32', 'fp16')] = True + is_supported[('zero3', 'fp32', 'bf16')] = True + is_supported[('zero3', 'fp32', 'fp32')] = True # Amp Wrapper is_supported[('amp', 'fp32', None)] = True is_supported[('amp', 'fp32', 'fp32')] = True @@ -187,7 +230,7 @@ def test(self, optimizer_extension, model_dtype, grad_accum_dtype): is_supported[(None, 'fp16', None)] = True is_supported[(None, 'fp16', 'fp16')] = True # BF16 Wrapper - is_supported[(None, 'bf16', 'fp32')] = True + is_supported[(None, 'bf16', 'bf16')] = True is_supported[(None, 'bf16', None)] = True # No Wrapper is_supported[(None, 'fp32', None)] = True @@ -209,6 +252,55 @@ def test(self, optimizer_extension, model_dtype, grad_accum_dtype): model_parameters=model_parameters) +class TestBf16ZeRO0UnfusedOptimizer(DistributedTest): + world_size = 1 + reuse_dist_env = True + + def test_static_scale_and_zero_grad_after_step(self): + if not bf16_required_version_check(): + pytest.skip( + "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + hidden_dim = 16 + model = SimpleModel(hidden_dim) + client_optimizer = AdamW(model.parameters(), lr=1e-4) + ds_config = { + "train_batch_size": 1, + "train_micro_batch_size_per_gpu": 1, + "bf16": { + "enabled": True + }, + "zero_optimization": { + "stage": 0 + }, + } + + engine, _, _, _ = deepspeed.initialize(config=ds_config, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer) + + assert isinstance(engine.optimizer, FP16_UnfusedOptimizer) + assert engine.optimizer.low_precision_dtype == torch.bfloat16 + assert engine.optimizer.loss_scale_config.dynamic_loss_scale is False + assert engine.optimizer.loss_scale_config.cur_scale == 1 + + data_loader = random_dataloader(model=engine, + total_samples=1, + hidden_dim=hidden_dim, + device=engine.device, + dtype=torch.bfloat16) + batch = next(iter(data_loader)) + + loss = engine(batch[0], batch[1]) + engine.backward(loss) + assert any(param.grad is not None for param in engine.module.parameters() if param.requires_grad) + + engine.step() + assert all(param.grad is None for param in engine.module.parameters() if param.requires_grad) + + @pytest.mark.parametrize("scheduler_type", [None, _LRScheduler, Callable]) @pytest.mark.parametrize("optimizer_type", [None, Optimizer, Callable]) class TestClientLrScheduler(DistributedTest): @@ -270,3 +362,132 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler: assert ds_lr_scheduler == client_scheduler else: assert isinstance(ds_lr_scheduler, LambdaLR) + + +@pytest.mark.parametrize("scheduler_type", [None, _LRScheduler, Callable]) +class TestClientLrSchedulerInit(DistributedTest): + world_size = 1 + + def test_same_lrscheler_and_callable(self, scheduler_type): + """ + Expect behavior + + if lr scheduler is defined in code and passed into initialize as arg, + it will be used even this is a lr scheduler has been defined in config. + + Initialize lr scheduler from config when no lr scheduler is defined in code. + """ + + def _my_lambda(epoch): + return epoch // 10 + + def _lr_scheduler_callable(optimizer) -> _LRScheduler: + return LambdaLR(optimizer, _my_lambda) + + config_dict = {'train_batch_size': 1} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + if scheduler_type is None: + config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}} + client_scheduler = None + elif scheduler_type == _LRScheduler: + client_scheduler = LambdaLR(client_optimizer, _my_lambda) + else: + client_scheduler = _lr_scheduler_callable + + _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer, + lr_scheduler=client_scheduler) + if scheduler_type is None: + # in this case, we initialize from config + assert not isinstance(ds_lr_scheduler, LambdaLR) + assert isinstance(ds_lr_scheduler, WarmupLR) + else: + # in this case, we initialize from passed-in scheduler + assert isinstance(ds_lr_scheduler, LambdaLR) + assert not isinstance(ds_lr_scheduler, WarmupLR) + + def test_diff_lrscheler_and_callable(self, scheduler_type): + """ + In this test, + the LambdaLR will be used for lrscheduler type + and the StepLR will be used for callable type + """ + + from torch.optim.lr_scheduler import StepLR + + def _my_lambda(epoch): + return epoch // 10 + + def _lr_scheduler_callable(optimizer) -> _LRScheduler: + return StepLR(optimizer, step_size=30) + + config_dict = {'train_batch_size': 1} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + if scheduler_type is None: + config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}} + client_scheduler = None + elif scheduler_type == _LRScheduler: + client_scheduler = LambdaLR(client_optimizer, _my_lambda) + else: + client_scheduler = _lr_scheduler_callable + + _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer, + lr_scheduler=client_scheduler) + if scheduler_type is None: + assert isinstance(ds_lr_scheduler, WarmupLR) + elif scheduler_type == _LRScheduler: + assert isinstance(ds_lr_scheduler, LambdaLR) + else: + # callable + assert isinstance(ds_lr_scheduler, StepLR) + + def test_diff_lrscheler_and_callable_onecyclelr_steplr(self, scheduler_type): + + from deepspeed.runtime.lr_schedules import OneCycle, ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR + from torch.optim.lr_scheduler import OneCycleLR, StepLR + + def _lr_scheduler_callable(optimizer) -> _LRScheduler: + return OneCycleLR(optimizer, max_lr=0.01, total_steps=200) + + config_dict = {'train_batch_size': 1} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + if scheduler_type is None: + config_dict['scheduler'] = {'type': ONE_CYCLE, 'params': {CYCLE_MIN_LR: 0, CYCLE_MAX_LR: 0.1}} + client_scheduler = None + elif scheduler_type == _LRScheduler: + client_scheduler = StepLR(client_optimizer, step_size=30) + else: + client_scheduler = _lr_scheduler_callable + + _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer, + lr_scheduler=client_scheduler) + if scheduler_type is None: + assert isinstance(ds_lr_scheduler, OneCycle) + elif scheduler_type == _LRScheduler: + assert isinstance(ds_lr_scheduler, StepLR) + else: + # callable + assert isinstance(ds_lr_scheduler, OneCycleLR) diff --git a/tests/unit/runtime/test_lr_schedulers.py b/tests/unit/runtime/test_lr_schedulers.py index 2393891e28df..1dfa853dbe05 100644 --- a/tests/unit/runtime/test_lr_schedulers.py +++ b/tests/unit/runtime/test_lr_schedulers.py @@ -3,6 +3,8 @@ # DeepSpeed Team +import math + import torch import deepspeed import pytest @@ -13,6 +15,7 @@ from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS +from deepspeed.runtime.lr_schedules import WARMUP_COSINE_LR, WARMUP_MIN_RATIO, COS_MIN_RATIO, WarmupCosineLR def _verify_continuous_decrease(values): @@ -36,6 +39,9 @@ def _verify_staircase_increase(values, step_size): (WARMUP_DECAY_LR, { WARMUP_NUM_STEPS: 10, TOTAL_NUM_STEPS: 20 + }), (WARMUP_COSINE_LR, { + WARMUP_NUM_STEPS: 10, + TOTAL_NUM_STEPS: 20 }), (ONE_CYCLE, { CYCLE_MIN_LR: 0, CYCLE_MAX_LR: 0.1 @@ -70,6 +76,11 @@ def test(self, scheduler_type, params): hidden_dim=hidden_dim, device=model.device, dtype=torch.float) + + true_lrs = lr_scheduler.get_lr() + for group, true_lr in zip(model.optimizer.param_groups, true_lrs): + assert group['lr'] == true_lr, f"True lr {true_lr}, optimizer lr {group['lr']}" + for n, batch in enumerate(data_loader): # get lr before training starts lr_scheduler.get_lr() @@ -441,3 +452,94 @@ def test_mom(self, min_mom, max_mom, decay_rate, step_size): # Verify decay phase if decay_rate > 0: _verify_continuous_increase(step_moms[(step_size * 2):]) + + +class TestWarmupCosineLR(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize("total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_ratio", + [ + (100, 10, 0.1, 0.2), + (200, 20, 0.1, 0.2), + (500, 30, 0.0, 0.2), + (600, 300, 0.1, 0.0), + (600, 550, 0.0, 0.0), + ]) # yapf: disable + def test_lr(self, total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_ratio): + opt_lr = 0.0015 + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": opt_lr + }, + }, + "scheduler": { + "type": WARMUP_COSINE_LR, + "params": { + TOTAL_NUM_STEPS: total_num_steps, + WARMUP_MIN_RATIO: warmup_min_ratio, + WARMUP_NUM_STEPS: warmup_num_steps, + COS_MIN_RATIO: cos_min_ratio, + } + }, + "gradient_clipping": 1.0 + } + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=max(50, total_num_steps * 3), + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float) + + step_lrs = [] + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + step_lrs.extend(lr_scheduler.get_lr()) + + # Verify starting lr + assert abs(step_lrs[0] - opt_lr * warmup_min_ratio) < 1e-7 + + # Verify peak lr + assert abs(step_lrs[warmup_num_steps - 1] - opt_lr) < 1e-7 + + # Verify end lr + assert abs(step_lrs[total_num_steps - 1] - opt_lr * cos_min_ratio) < 1e-7 + + # Verify increasing phase + _verify_continuous_increase(step_lrs[:warmup_num_steps]) + + # Verify decreasing phase + _verify_continuous_decrease(step_lrs[warmup_num_steps:total_num_steps]) + + +def test_warmup_cosine_lr_initializes_all_param_groups(): + dense = torch.nn.Parameter(torch.zeros(1)) + expert = torch.nn.Parameter(torch.zeros(1)) + optimizer = torch.optim.Adam([{"params": [dense], "lr": 0.0015}, {"params": [expert], "lr": 0.003}]) + + scheduler = WarmupCosineLR(optimizer=optimizer, total_num_steps=100, warmup_num_steps=10, warmup_min_ratio=0.0) + + assert scheduler.get_lr_ratio() == 0.0 + assert scheduler.get_lr() == [0.0, 0.0] + assert scheduler.get_last_lr() == [0.0, 0.0] + assert [group["lr"] for group in optimizer.param_groups] == [0.0, 0.0] + + scheduler.step(1) + + expected_ratio = math.log(2) / math.log(10) + expected_lrs = [0.0015 * expected_ratio, 0.003 * expected_ratio] + + assert scheduler.get_lr_ratio() == pytest.approx(expected_ratio) + assert scheduler.get_lr() == pytest.approx(expected_lrs) + assert scheduler.get_last_lr() == pytest.approx(expected_lrs) + assert [group["lr"] for group in optimizer.param_groups] == pytest.approx(expected_lrs) diff --git a/tests/unit/runtime/test_multi_output_model.py b/tests/unit/runtime/test_multi_output_model.py index d9aba419b158..11a030bcbcc2 100644 --- a/tests/unit/runtime/test_multi_output_model.py +++ b/tests/unit/runtime/test_multi_output_model.py @@ -5,8 +5,10 @@ import torch import deepspeed +from deepspeed.accelerator import get_accelerator from pytest import approx -from unit.common import DistributedTest +from unit.util import torch_assert_close +from unit.common import DistributedTest, preferred_dtype from unit.multi_output_model import MultiOutputModel, multi_output_dataloader @@ -28,10 +30,11 @@ def test(self, tmpdir): "lr": 0.00015 } }, - "fp16": { - "enabled": True - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 weight_value = 0.1 @@ -47,21 +50,21 @@ def test(self, tmpdir): targets=[1, 2]) for n, batch in enumerate(data_loader): assert len(batch) % 2 == 0, \ - f"multi_output_dataloader failed to return even number of data samples (input+target)" + "multi_output_dataloader failed to return even number of data samples (input+target)" midpoint = len(batch) // 2 inputs, targets = batch[:midpoint], batch[midpoint:] loss_tuple = model(inputs, targets) - expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device) + expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device) for loss in loss_tuple: assert loss.shape == torch.Size([]) assert loss.item() == approx(expected_loss.item()) summed_loss = sum(loss_tuple) scaled_loss = model.backward(summed_loss) - expected_scaled_loss = summed_loss.float() / grad_accumulation_steps - assert scaled_loss.item() == approx(expected_scaled_loss.item()) + expected_scaled_loss = summed_loss / grad_accumulation_steps + torch_assert_close(scaled_loss, expected_scaled_loss) model.step() @@ -84,10 +87,11 @@ def test(self, tmpdir): "lr": 0.00015 } }, - "fp16": { - "enabled": True - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 weight_value = 0.1 @@ -104,14 +108,14 @@ def test(self, tmpdir): targets=[1, 2, 3]) for n, batch in enumerate(data_loader): assert len(batch) % 2 == 0, \ - f"multi_output_dataloader failed to return even number of data samples (input+target)" + "multi_output_dataloader failed to return even number of data samples (input+target)" midpoint = len(batch) // 2 inputs, targets = batch[:midpoint], batch[midpoint:] loss_tuple = model(inputs, targets) assert len(loss_tuple) == 3 - expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device) + expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device) for loss in loss_tuple: assert loss.shape == torch.Size([]) @@ -119,7 +123,7 @@ def test(self, tmpdir): summed_loss = sum(loss_tuple) scaled_loss = model.backward(summed_loss) - expected_scaled_loss = summed_loss.float() / grad_accumulation_steps - assert scaled_loss.item() == approx(expected_scaled_loss.item()) + expected_scaled_loss = summed_loss / grad_accumulation_steps + torch_assert_close(scaled_loss, expected_scaled_loss) model.step() diff --git a/tests/unit/runtime/test_multiple_models.py b/tests/unit/runtime/test_multiple_models.py new file mode 100644 index 000000000000..9bce009d7a87 --- /dev/null +++ b/tests/unit/runtime/test_multiple_models.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import deepspeed +import deepspeed.comm as dist +import torch +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + + +def create_model(config_dict): + hidden_dim = 64 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + return model + + +def train_shared_loss(num_models, config_dict, dtype): + hidden_dim = 64 + + models = [create_model(config_dict) for _ in range(num_models)] + data_loader = random_dataloader(model=models[0], + total_samples=4, + hidden_dim=hidden_dim, + device=models[0].device, + dtype=dtype) + dist.barrier() + for _, batch in enumerate(data_loader): + losses = [m.module(batch[0], batch[1]) for m in models] + loss = sum(l / (i + 1) for i, l in enumerate(losses)) + loss.backward() + + for m in models: + m._backward_epilogue() + + for m in models: + m.step() + + for m in models: + m.optimizer.zero_grad() + + for m in models: + m.destroy() + + +def train_independent_loss(num_models, config_dict, dtype): + hidden_dim = 64 + + models = [create_model(config_dict) for _ in range(num_models)] + data_loader = random_dataloader(model=models[0], + total_samples=4, + hidden_dim=hidden_dim, + device=models[0].device, + dtype=dtype) + dist.barrier() + for _, batch in enumerate(data_loader): + losses = [m.module(batch[0], batch[1]) for m in models] + for m, loss in zip(models, losses): + m.backward(loss) + m.step() + + for m in models: + m.destroy() + + +@pytest.mark.parametrize('num_models', [1, 2, 3]) +class TestMultipleModels(DistributedTest): + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('shared_loss', [False, True]) + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('fp32_grad_accum', [False, True]) + @pytest.mark.parametrize('contiguous_gradients', [False, True]) + @pytest.mark.parametrize('overlap_comm', [False, True]) + def test_zero_optimizer(self, num_models, shared_loss, zero_stage, fp32_grad_accum, contiguous_gradients, + overlap_comm): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": zero_stage, + "contiguous_gradients": contiguous_gradients, + "overlap_comm": overlap_comm, + }, + "fp16": { + "initial_scale_power": 8, + "enabled": True + }, + } + if fp32_grad_accum: + config_dict["data_types"] = {"grad_accum_dtype": "fp32"} + + if shared_loss: + train_shared_loss(num_models=num_models, config_dict=config_dict, dtype=torch.float16) + else: + train_independent_loss(num_models=num_models, config_dict=config_dict, dtype=torch.float16) + + # TODO: Combination of shared_loss==True and bf16.immediate_grad_update==False is currently broken + @pytest.mark.parametrize('shared_loss', [False, True]) + def test_bf16_optimizer(self, num_models, shared_loss): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 1, + }, + "bf16": { + "enabled": True, + "immediate_grad_update": True, + }, + "data_types": { + "grad_accum_dtype": "fp32" + } + } + + if shared_loss: + train_shared_loss(num_models=num_models, config_dict=config_dict, dtype=torch.bfloat16) + else: + train_independent_loss(num_models=num_models, config_dict=config_dict, dtype=torch.bfloat16) diff --git a/tests/unit/runtime/test_mup_optimizers.py b/tests/unit/runtime/test_mup_optimizers.py new file mode 100644 index 000000000000..a004397b47ac --- /dev/null +++ b/tests/unit/runtime/test_mup_optimizers.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +import pytest + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader +from mup.shape import set_base_shapes +from deepspeed.accelerator import get_accelerator + + +@pytest.mark.parametrize("optimizer, expected_opt_class", [("MuAdam", torch.optim.Adam), + ("MuAdamW", torch.optim.AdamW), ("MuSGD", torch.optim.SGD)]) # yapf: disable +@pytest.mark.parametrize("zero_offload", [True, False]) # yapf: disable +class TestMuPOptimizers(DistributedTest): + world_size = 1 + reuse_dist_env = True + + def test(self, optimizer, expected_opt_class, zero_offload): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "zero_allow_untested_optimizer": True, + "optimizer": { + "type": optimizer, + "params": { + "lr": 0.00015, + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "cpu_offload": zero_offload + } + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + hidden_dim = 10 + model = SimpleModel(hidden_dim) + set_base_shapes(model, None) + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + ds_optimizer = model.optimizer.optimizer + assert isinstance(ds_optimizer, expected_opt_class) diff --git a/tests/unit/runtime/test_no_sync_ctxt.py b/tests/unit/runtime/test_no_sync_ctxt.py new file mode 100644 index 000000000000..8c6497013809 --- /dev/null +++ b/tests/unit/runtime/test_no_sync_ctxt.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +from contextlib import nullcontext +import torch + +from unit.simple_model import SimpleModel, random_dataloader +from unit.common import DistributedTest + +import deepspeed +import deepspeed.comm as dist +from deepspeed.utils import safe_get_full_grad + + +class TestNoSyncCtxt(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) + def test_zero_stage(self, zero_stage, dtype): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + } + + invalid_cfg = zero_stage > 1 + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + hidden_dim = 64 + total_samples = 32 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + with pytest.raises(AssertionError) if invalid_cfg else nullcontext() as assertinfo: + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + if invalid_cfg: + assert ("no_sync context manager is incompatible" in str(assertinfo)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("zero_stage", [0, 1]) + def test_engine_step(self, zero_stage, dtype): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + } + + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + hidden_dim = 64 + total_samples = 32 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + with pytest.raises(AssertionError) as assertinfo: + model.step() + assert ("It is illegal to call Engine.step() inside no_sync context manager" in str(assertinfo)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("zero_stage", [0, 1]) + def test_multiple_ctxts(self, zero_stage, dtype): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + } + + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + hidden_dim = 64 + total_samples = 32 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + param_list = list(model.parameters()) + first_losses = [] + first_grad_norms = [] + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + first_losses.append(loss.item()) + model.backward(loss) + grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list]) + first_grad_norms.append(grad_norm.item()) + + second_losses = [] + second_grad_norms = [] + + model.zero_grad() + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + second_losses.append(loss.item()) + model.backward(loss) + grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list]) + second_grad_norms.append(grad_norm.item()) + + assert len(first_losses) == len(second_losses) + for x, y in zip(first_losses, second_losses): + assert x == y + + assert len(first_grad_norms) == len(second_grad_norms) + for x, y in zip(first_grad_norms, second_grad_norms): + assert x == y + + def test_reentry(self): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": 1, + }, + } + + hidden_dim = 64 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + dist.barrier() + + with model.no_sync(): + with pytest.raises(AssertionError) as assertinfo: + with model.no_sync(): + pass + assert ("no_sync context manager reentry is unsupported" in str(assertinfo)) diff --git a/tests/unit/runtime/test_pld.py b/tests/unit/runtime/test_pld.py index 1f602db73b2f..a776ba882852 100644 --- a/tests/unit/runtime/test_pld.py +++ b/tests/unit/runtime/test_pld.py @@ -10,6 +10,7 @@ from unit.common import DistributedTest from unit.simple_model import SimpleModel, PLD_SimpleModel, random_dataloader +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) @@ -39,15 +40,16 @@ def test_pld_model(self, theta): "lr": 0.0001 } }, - "fp16": { - "enabled": True - }, "progressive_layer_drop": { "enabled": True, "theta": theta, "gamma": gamma } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 model = PLD_SimpleModel(hidden_dim, empty_grad=False) @@ -80,15 +82,16 @@ def test_non_pld_model(self): "lr": 0.0001 } }, - "fp16": { - "enabled": True - }, "progressive_layer_drop": { "enabled": True, "theta": theta, "gamma": gamma } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim, empty_grad=False) diff --git a/tests/unit/runtime/test_precision_config_loss_scale.py b/tests/unit/runtime/test_precision_config_loss_scale.py new file mode 100644 index 000000000000..cc8b7d064c62 --- /dev/null +++ b/tests/unit/runtime/test_precision_config_loss_scale.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import math + +import pytest +from pydantic import ValidationError + +from deepspeed.runtime.precision_config import DeepSpeedFP16Config + + +@pytest.mark.parametrize("loss_scale", [-1, float("inf"), float("nan"), True]) +def test_fp16_loss_scale_rejects_invalid_values(loss_scale): + with pytest.raises(ValidationError): + DeepSpeedFP16Config(loss_scale=loss_scale) + + +@pytest.mark.parametrize("loss_scale", [0, 1, 2.0, "3"]) +def test_fp16_loss_scale_accepts_valid_values(loss_scale): + cfg = DeepSpeedFP16Config(loss_scale=loss_scale) + assert math.isfinite(cfg.loss_scale) + assert cfg.loss_scale >= 0 + + +@pytest.mark.parametrize("loss_scale", [[], {}]) +def test_fp16_loss_scale_invalid_type_has_clear_error(loss_scale): + with pytest.raises(ValidationError) as excinfo: + DeepSpeedFP16Config(loss_scale=loss_scale) + assert "must be a number" in str(excinfo.value) diff --git a/tests/unit/runtime/test_runtime_utils.py b/tests/unit/runtime/test_runtime_utils.py index 5d8478b249be..8aebbe500eb7 100644 --- a/tests/unit/runtime/test_runtime_utils.py +++ b/tests/unit/runtime/test_runtime_utils.py @@ -7,6 +7,7 @@ from torch._utils import _flatten_dense_tensors import deepspeed.comm as dist import pytest +from typing import Dict import deepspeed.runtime.utils as ds_utils import deepspeed.utils.groups as groups @@ -26,10 +27,10 @@ def test_call_to_str(): assert c2s('hello', 1138, val=3) == 'hello(1138, val=3)' -class TestClibGradNorm(DistributedTest): +class TestClipGradNorm(DistributedTest): world_size = 2 - def test(self): + def test_gather(self): param1 = torch.nn.Parameter(torch.Tensor([0])) param1.grad = torch.Tensor([1]) param2 = torch.nn.Parameter(torch.Tensor([0])) @@ -50,6 +51,27 @@ def test(self): assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1" + def test_clipped_val(self): + max_norm = 0.1 + + def test_params(): + param1 = torch.nn.Parameter(torch.Tensor([0])) + param1.grad = torch.Tensor([1]) + param2 = torch.nn.Parameter(torch.Tensor([0])) + param2.grad = torch.Tensor([1]) + return [param1, param2] + + # This assumes gradients are same on all the ranks and doesn't consider multiple ranks + params_expected = test_params() + torch.nn.utils.clip_grad_norm_(params_expected, max_norm) + + params_actual = test_params() + ds_utils.clip_grad_norm_(params_actual, max_norm=max_norm) + + # This can be allclose + assert torch.equal(params_expected[0].grad, params_actual[0].grad) + assert torch.equal(params_expected[1].grad, params_actual[1].grad) + @pytest.mark.parametrize("check_using_norm", [(False), (True)]) class TestCheckOverflow(DistributedTest): @@ -77,3 +99,30 @@ def test(self, check_using_norm): overflow_checker = ds_utils.CheckOverflow([parameters]) overflow = overflow_checker.check() assert overflow + + +@pytest.mark.skipif(not hasattr(torch.autograd.graph, "_get_grad_fn_or_grad_acc"), + reason="requires torch.autograd.graph._get_grad_fn_or_grad_acc") +def test_count_used_parameters_enables_grad_for_grad_acc_lookup(monkeypatch): + """count_used_parameters_in_backward should enable grad for grad-acc lookup.""" + param = torch.nn.Parameter(torch.tensor([1.0], requires_grad=True)) + seen: Dict[str, int] = {"lookup_calls": 0} + original_getter = torch.autograd.graph._get_grad_fn_or_grad_acc + + def _require_grad_enabled(t): + seen["lookup_calls"] += 1 + if not torch.is_grad_enabled(): + raise RuntimeError("grad mode must be enabled for grad-acc lookup") + return original_getter(t) + + monkeypatch.setattr(torch.autograd.graph, "_get_grad_fn_or_grad_acc", _require_grad_enabled) + + def _hook(grad): + seen["count"] = ds_utils.count_used_parameters_in_backward([param]) + return grad + + param.register_hook(_hook) + loss = (param * 2.0).sum() + loss.backward() + assert seen["lookup_calls"] > 0 + assert "count" in seen diff --git a/tests/unit/runtime/test_tp_plan_extraction.py b/tests/unit/runtime/test_tp_plan_extraction.py new file mode 100644 index 000000000000..7d0c8e81164c --- /dev/null +++ b/tests/unit/runtime/test_tp_plan_extraction.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.runtime.tensor_parallel.config import _get_hf_tp_plan + + +class TestTPPlanExtraction: + + def test_extract_tp_plan_from_mock_model(self): + + class MockHFModel: + + def __init__(self): + self._tp_plan = {"layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise"} + + model = MockHFModel() + tp_plan = _get_hf_tp_plan(model) + + assert tp_plan is not None + assert "layers.*.self_attn.q_proj" in tp_plan + assert tp_plan["layers.*.self_attn.q_proj"] == "colwise" + + def test_extract_tp_plan_from_model_with_config(self): + + class MockHFConfig: + base_model_tp_plan = {"layers.*.self_attn.q_proj": "colwise"} + + class MockHFModel: + + def __init__(self, config): + self.config = config + + config = MockHFConfig() + model = MockHFModel(config) + tp_plan = _get_hf_tp_plan(model) + + assert tp_plan is not None + assert "layers.*.self_attn.q_proj" in tp_plan + + def test_no_tp_plan_model(self): + model = torch.nn.Linear(10, 10) + tp_plan = _get_hf_tp_plan(model) + + assert tp_plan is None + + def test_empty_tp_plan(self): + + class MockHFModel: + + def __init__(self): + self._tp_plan = {} + + model = MockHFModel() + tp_plan = _get_hf_tp_plan(model) + + # Empty _tp_plan is falsy, so falls through to config then None + assert tp_plan is None + + def test_none_tp_plan_falls_back_to_config(self): + + class MockHFConfig: + base_model_tp_plan = {"layers.*.self_attn.q_proj": "colwise"} + + class MockHFModel: + + def __init__(self, config): + self.config = config + self._tp_plan = None + + config = MockHFConfig() + model = MockHFModel(config) + tp_plan = _get_hf_tp_plan(model) + + assert tp_plan is not None + assert "layers.*.self_attn.q_proj" in tp_plan + + def test_none_tp_plan(self): + + class MockHFModel: + + def __init__(self): + pass + + model = MockHFModel() + tp_plan = _get_hf_tp_plan(model) + + assert tp_plan is None + + def test_priority_config_over_model(self): + + class MockHFConfig: + base_model_tp_plan = {"config_plan": "colwise"} + + class MockHFModel: + + def __init__(self, config): + self.config = config + self._tp_plan = {"model_plan": "colwise"} + + config = MockHFConfig() + model = MockHFModel(config) + tp_plan = _get_hf_tp_plan(model) + + assert tp_plan is not None + assert "config_plan" in tp_plan + assert "model_plan" not in tp_plan diff --git a/tests/unit/runtime/utils/test_partition.py b/tests/unit/runtime/utils/test_partition.py index e7085ee2c4bd..8f7768d0d730 100644 --- a/tests/unit/runtime/utils/test_partition.py +++ b/tests/unit/runtime/utils/test_partition.py @@ -22,7 +22,6 @@ class TestPartitionedTensor(DistributedTest): def test(self): world = dist.get_world_size() - rank = dist.get_rank() group = dist.new_group(ranks=list(range(world))) @@ -40,12 +39,32 @@ def test(self): assert torch.equal(full, reconstructed) +class TestPartitionedTensorUnEven(DistributedTest): + world_size = 4 + + def test(self): + world = dist.get_world_size() + + group = dist.new_group(ranks=list(range(world))) + + rows = world * 4 - 1 + cols = world + 1 + + full = torch.rand(rows, cols).to(get_accelerator().device_name()) + dist.broadcast(full, src=0, group=group) + part = PartitionedTensor(full, group=group) + + assert len(part.local_size()) == 1 + + reconstructed = part.full() + assert torch.equal(full, reconstructed) + + class TestPartitionedTensorMeta(DistributedTest): world_size = 4 def test(self): world = dist.get_world_size() - rank = dist.get_rank() group = dist.new_group(ranks=list(range(world))) diff --git a/tests/unit/runtime/zenflow/test_zf.py b/tests/unit/runtime/zenflow/test_zf.py new file mode 100644 index 000000000000..7adcdb784972 --- /dev/null +++ b/tests/unit/runtime/zenflow/test_zf.py @@ -0,0 +1,111 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader +import deepspeed + + +class BaseZenFlowTest: + hidden_dim = 10 + batch_size = 4 + grad_acc_steps = 1 + + def get_config_dict(self, stage, offload_selective_optimizer, select_strategy, select_interval, update_interval, + full_warm_up_rounds): + config = { + "train_batch_size": self.batch_size, + "gradient_accumulation_steps": self.grad_acc_steps, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": stage, + "offload_optimizer": { + "device": "cpu" + }, + "overlap_comm": True, + "zenflow": { + "topk_ratio": 0.2, + "select_strategy": select_strategy, + "select_interval": select_interval, + "update_interval": update_interval, + "overlap_step": False, + "offload": offload_selective_optimizer, + "auto_ratio": 0.99, + "full_warm_up_rounds": full_warm_up_rounds, + } + }, + "zero_allow_untested_optimizer": True, + } + + if get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} + return config + + def run_training_distributed(self, config_dict): + + if get_accelerator().device_name() == "cpu": + return + + model = SimpleModel(self.hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + train_dataloader = random_dataloader(model=model, + total_samples=20, + hidden_dim=self.hidden_dim, + device=model.device) + + dist.barrier() + + for step, batch in enumerate(train_dataloader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + model.destroy() + + +@pytest.mark.parametrize("stage", [1, 2, 3]) +@pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) +@pytest.mark.parametrize("offload_selective_optimizer", [True, False]) +@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [ + ("auto", "auto", "auto"), + ("step", 10, 3), + ("epoch", 1, 4), +]) +class TestZenFlowSingleGPU(DistributedTest, BaseZenFlowTest): + world_size = 1 + + def test_zenflow_single_gpu(self, stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds): + tester = BaseZenFlowTest() + config_dict = tester.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds) + tester.run_training_distributed(config_dict) + + +@pytest.mark.parametrize("stage", [1, 2, 3]) +@pytest.mark.parametrize("full_warm_up_rounds", [0, 3]) +@pytest.mark.parametrize("offload_selective_optimizer", [True, False]) +@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [ + ("auto", "auto", "auto"), + ("step", 10, 3), + ("epoch", 1, 4), +]) +class TestZenFlowDistributed(DistributedTest, BaseZenFlowTest): + world_size = 2 + + def test_zenflow_distributed(self, stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds): + config_dict = self.get_config_dict(stage, offload_selective_optimizer, select_strategy, select_interval, + update_interval, full_warm_up_rounds) + self.run_training_distributed(config_dict) diff --git a/tests/unit/runtime/zenflow/test_zf_config.py b/tests/unit/runtime/zenflow/test_zf_config.py new file mode 100644 index 000000000000..647b7f82f2e9 --- /dev/null +++ b/tests/unit/runtime/zenflow/test_zf_config.py @@ -0,0 +1,86 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +from pydantic import ValidationError + +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig, ZeroStageEnum +from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig +from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig + + +def test_stage_enum_accepts_int_and_enum(): + """`stage` can be passed as either an int or the ZeroStageEnum.""" + c1 = DeepSpeedZeroConfig(stage=2) + assert c1.stage == ZeroStageEnum.gradients + c2 = DeepSpeedZeroConfig(stage=ZeroStageEnum.weights) + assert c2.stage == ZeroStageEnum.weights + + +def test_offload_optimizer_config_from_dict(): + """A dict for offload_optimizer should be coerced into DeepSpeedZeroOffloadOptimizerConfig.""" + cfg = DeepSpeedZeroConfig(offload_optimizer={"device": "cpu", "pin_memory": True}) + assert isinstance(cfg.offload_optimizer, DeepSpeedZeroOffloadOptimizerConfig) + assert cfg.offload_optimizer.device == "cpu" + assert cfg.offload_optimizer.pin_memory is True + + +def test_invalid_offload_optimizer_type_raises(): + """Passing a non-dict to offload_optimizer must error out.""" + with pytest.raises(ValidationError): + DeepSpeedZeroConfig(offload_optimizer="not a dict") + + +def test_zenflow_config_from_dict(): + """A dict for zenflow should be coerced into ZenFlowConfig.""" + zenflow_payload = { + "topk_ratio": 0.25, + "select_strategy": "auto", + "select_interval": 4, + "update_interval": 8, + "full_warm_up_rounds": 1, + "overlap_step": True + } + cfg = DeepSpeedZeroConfig(zenflow=zenflow_payload) + assert isinstance(cfg.zenflow, ZenFlowConfig) + assert cfg.zenflow.topk_ratio == 0.25 + assert cfg.zenflow.select_strategy == "auto" + assert cfg.zenflow.select_interval == 4 + assert cfg.zenflow.update_interval == 8 + assert cfg.zenflow.full_warm_up_rounds == 1 + assert cfg.zenflow.overlap_step is True + + +def test_invalid_zenflow_type_raises(): + """Passing a non-dict to zenflow must error out.""" + with pytest.raises(ValidationError): + DeepSpeedZeroConfig(zenflow=123) + + +def test_offload_and_zenflow_combined(): + """ + offload_optimizer and zenflow can be used together under stage 2 + without validation errors. + """ + payload = { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True + }, + "zenflow": { + "topk_ratio": 0.3, + "select_strategy": "epoch", + "select_interval": 3, + "update_interval": 6, + "full_warm_up_rounds": 0, + "overlap_step": False + } + } + cfg = DeepSpeedZeroConfig(**payload) + assert isinstance(cfg.offload_optimizer, DeepSpeedZeroOffloadOptimizerConfig) + assert cfg.offload_optimizer.device == "cpu" + assert isinstance(cfg.zenflow, ZenFlowConfig) + assert cfg.zenflow.select_strategy == "epoch" diff --git a/tests/unit/runtime/zero/test_ignore_unused_parameters.py b/tests/unit/runtime/zero/test_ignore_unused_parameters.py index aade488fde42..56b7ad2ac2d0 100644 --- a/tests/unit/runtime/zero/test_ignore_unused_parameters.py +++ b/tests/unit/runtime/zero/test_ignore_unused_parameters.py @@ -9,6 +9,7 @@ from deepspeed.ops.op_builder import CPUAdamBuilder import deepspeed +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize('ignore_unused_parameters', [False, True]) @@ -36,11 +37,11 @@ def test(self, ignore_unused_parameters): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 4 model = UnusedParametersModel(hidden_dim=hidden_dim) diff --git a/tests/unit/runtime/zero/test_nvme_checkpointing.py b/tests/unit/runtime/zero/test_nvme_checkpointing.py new file mode 100644 index 000000000000..5b0c9d2a0d34 --- /dev/null +++ b/tests/unit/runtime/zero/test_nvme_checkpointing.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import pytest +import deepspeed.comm as dist +import torch + +from unit.common import DistributedTest +from unit.simple_model import random_dataloader, SimpleModel + +import deepspeed +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.partition_parameters import Init +from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator + + +@pytest.mark.sequential +class TestNVMeCheckpointing(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('param_offload_device, optim_offload_device', + [(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme), + (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme), + (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none), + (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu), + (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)]) + def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device): + zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint") + + first_stage_steps, second_stage_steps = 2, 2 + + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + torch.manual_seed(123) + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + }, + "zero_optimization": { + "stage": 3, + "offload_param": { + "device": param_offload_device, + "nvme_path": str(zero_dir) + }, + "offload_optimizer": { + "device": optim_offload_device, + "nvme_path": str(zero_dir) + }, + "sub_group_size": 100, + "stage3_max_live_parameters": 100, + "stage3_param_persistence_threshold": 0, + }, + "aio": { + "block_size": 1048576 # Minimum AIO bytes, anything smaller than this will not be offloaded + } + } + + hidden_dim, nlayers = 2048, 2 + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=nlayers, empty_grad=False) + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + model.empty_partition_cache() + + assert first_stage_steps > 0 + + data_loader = random_dataloader(model=model, + total_samples=first_stage_steps, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16) + dist.barrier() + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + dist.barrier() + model.save_checkpoint(ckpt_dir) + + if second_stage_steps > 0: + second_stage_batches = list( + random_dataloader(model=model, + total_samples=second_stage_steps, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16)) + dist.barrier() + for n, batch in enumerate(second_stage_batches): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + dist.barrier() + + final_batch = next( + iter( + random_dataloader(model=model, + total_samples=1, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float16))) + dist.barrier() + loss_before = float(model(final_batch[0], final_batch[1])) + + # Needed in ZeRO 3. Not doing so can give memory leak + model.destroy() + + # TODO: This should be on the engine? There needs to be a better way. + Init.param_id = 0 + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=nlayers, empty_grad=False) + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + model.load_checkpoint(ckpt_dir) + + if second_stage_steps > 0: + dist.barrier() + for n, batch in enumerate(second_stage_batches): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + dist.barrier() + + dist.barrier() + loss_after = float(model(final_batch[0], final_batch[1])) + + assert loss_before == loss_after diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py new file mode 100644 index 000000000000..8a7bde215301 --- /dev/null +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +from deepspeed.runtime.zero import unwrap_model_for_generation +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel + +config = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 1, + "offload_param": { + "device": "cpu", + "pin_memory": True + } + } +} + +if get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} +elif get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} + + +class TestUnwrapModel(DistributedTest): + # gather across more than 1 gpu + world_size = 2 + + def test(self): + + def hooks_exist(engine): + if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"): + optimizer_offload = engine.optimizer.parameter_offload + elif engine.optimizer is not None: + optimizer_offload = engine.optimizer + + hooks = 0 + for hook in optimizer_offload.forward_hooks: + hooks += 1 + if hooks > 0: + return True + return False + + model = SimpleModel(hidden_dim=100) + engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config) + + with unwrap_model_for_generation(engine): + # assert no hooks + assert not hooks_exist(engine) + # assert parameters gathered + assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor" + + # assert hooks + assert hooks_exist(engine) diff --git a/tests/unit/runtime/zero/test_zero_config.py b/tests/unit/runtime/zero/test_zero_config.py index a5bd96c411e0..8b20eca8c7d2 100644 --- a/tests/unit/runtime/zero/test_zero_config.py +++ b/tests/unit/runtime/zero/test_zero_config.py @@ -31,6 +31,12 @@ def test_zero_config_aliasfields(): assert config.gather_16bit_weights_on_model_save == True +def test_zero_config_pipeline_loading_checkpoint(): + for stage in [0, 1, 2]: + config = DeepSpeedZeroConfig(**{"stage": stage}) + assert config.pipeline_loading_checkpoint == False + + def test_zero_config_overlapcomm(): for stage in [0, 1, 2]: config = DeepSpeedZeroConfig(**{"stage": stage}) @@ -42,12 +48,12 @@ def test_zero_config_overlapcomm(): def test_zero_config_offload_configs(): config = DeepSpeedZeroConfig() - assert config.offload_param == None - assert config.offload_optimizer == None + assert config.offload_param is None + assert config.offload_optimizer is None config = DeepSpeedZeroConfig(**{"offload_param": None, "offload_optimizer": None}) - assert config.offload_param == None - assert config.offload_optimizer == None + assert config.offload_param is None + assert config.offload_optimizer is None config = DeepSpeedZeroConfig(**{"offload_param": {}, "offload_optimizer": {}}) assert isinstance(config.offload_param, DeepSpeedZeroOffloadParamConfig) diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index aabe7f0b7f15..189502445bf3 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -6,11 +6,13 @@ from types import SimpleNamespace import torch +import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype, reduce_boolean_flags from unit.simple_model import SimpleModel from utils import setup_serial_env @@ -47,16 +49,17 @@ def forward(self, x): "lr": 0.00015 } }, - "fp16": { - "enabled": True, - "loss_scale": 138. - }, "zero_optimization": { "stage": 3, "stage3_param_persistence_threshold": 1, } } +if get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} +elif get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} + class TestZeroGatheredParametersFree(DistributedTest): world_size = 1 @@ -81,6 +84,66 @@ def __init__(self, hidden_dim): assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized" +class TestMiCSGatheredParametersFree(DistributedTest): + world_size = 1 + + def test(self): + config_dict = {"train_batch_size": 1, "zero_optimization": {"stage": 3, "mics_shard_size": 1}} + hidden_dim = 10 + + class MyModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(MyModel, self).__init__() + self.l1 = torch.nn.Linear(hidden_dim, hidden_dim) + + with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict): + model = MyModel(hidden_dim) + + with deepspeed.zero.GatheredParameters(list(model.parameters())): + assert model.l1.weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor" + + # on exit from `GatheredParameters` the gathered params should be freed and not leak memory + assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized" + + +class TestGatheredParametersAllRanksErrorOnModification(DistributedTest): + world_size = 2 + + def test(self): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "enable_sanity_checks": True + } + } + hidden_dim = 10 + + class MyModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(MyModel, self).__init__() + self.l1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.l2 = torch.nn.Linear(hidden_dim, hidden_dim) + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = MyModel(hidden_dim) + + error_local = False + try: + with deepspeed.zero.GatheredParameters([model.l1.weight, model.l2.weight], modifier_rank=None): + with torch.no_grad(): + model.l1.weight.add_(0.0) + except RuntimeError as exc: + if "in-place modification" in str(exc): + error_local = True + + error_global = reduce_boolean_flags(error_local, all) + if not error_global: + raise AssertionError("Expected in-place modification error on all ranks.") + + class TestSerialContext(DistributedTest): world_size = 1 init_distributed = False @@ -101,6 +164,8 @@ def test_scattered_init_dist(self): assert dist.is_initialized() def test_scatter_halftype(self): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") setup_serial_env() with deepspeed.zero.Init(): @@ -190,9 +255,9 @@ def test_throughput_calculation(self): engine.tput_timer.stop(global_step=global_step) duration = engine.tput_timer.end_time - engine.tput_timer.start_time # step elapsed time is reset after gradient accumulation steps - assert engine.tput_timer.step_elapsed_time == ( - 0 if engine.tput_timer.global_step_count != engine.tput_timer.start_step else current_duration + - duration) + assert engine.tput_timer.step_elapsed_time == (0 if engine.tput_timer.global_step_count + != engine.tput_timer.start_step else current_duration + + duration) assert engine.tput_timer.total_elapsed_time == total_duration + duration def test_ext_param_getattr(self): @@ -225,7 +290,7 @@ def forward(self, input): with deepspeed.zero.GatheredParameters(net.linear1.weight): assert net.linear1.weight.numel() == net.dim**2 - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) engine.backward(loss) engine.step() diff --git a/tests/unit/runtime/zero/test_zero_context_ancestry.py b/tests/unit/runtime/zero/test_zero_context_ancestry.py index 21955f5df152..77a8744ab5bc 100644 --- a/tests/unit/runtime/zero/test_zero_context_ancestry.py +++ b/tests/unit/runtime/zero/test_zero_context_ancestry.py @@ -32,7 +32,7 @@ # test that sub-classes get params that aren't prematurely partitioned and thus requiring gathering -# fixed by https://github.com/microsoft/DeepSpeed/pull/1202 +# fixed by https://github.com/deepspeedai/DeepSpeed/pull/1202 class GrandPa(torch.nn.Module): def __init__(self, *args): diff --git a/tests/unit/runtime/zero/test_zero_context_return.py b/tests/unit/runtime/zero/test_zero_context_return.py index 874a8ea3b676..82aafc2cddeb 100644 --- a/tests/unit/runtime/zero/test_zero_context_return.py +++ b/tests/unit/runtime/zero/test_zero_context_return.py @@ -8,9 +8,10 @@ import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed.accelerator import get_accelerator from utils import setup_serial_env -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype class DanglingBias(torch.nn.Linear): @@ -119,16 +120,17 @@ def forward(self, input): "lr": 0.00015 } }, - "fp16": { - "enabled": True, - "loss_scale": 138. - }, "zero_optimization": { "stage": 3, "stage3_param_persistence_threshold": 1, } } +if get_accelerator().is_bf16_supported(): + config["bf16"] = {"enabled": True} +elif get_accelerator().is_fp16_supported(): + config["fp16"] = {"enabled": True, "loss_scale": 138.} + class TestReturnParam(DistributedTest): world_size = 1 @@ -142,7 +144,7 @@ def test_ext_param_return(self): engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config) for _ in range(5): - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) engine.backward(loss) engine.step() @@ -158,7 +160,7 @@ def test_ext_param_returnobj(self): engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config) for _ in range(5): - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) assert len(net._external_params) == 1 assert len(net.dangler._external_params) == 0 @@ -176,7 +178,7 @@ def test_stage_3_output_type(self, output_type): engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config) for _ in range(1): - input = torch.rand(net.dim).to(engine.device).half() + input = torch.rand(net.dim).to(engine.device).to(preferred_dtype()) loss = engine(input) if loss is not None: if isinstance(loss, dict): diff --git a/tests/unit/runtime/zero/test_zero_dynamic_class.py b/tests/unit/runtime/zero/test_zero_dynamic_class.py new file mode 100644 index 000000000000..e235206d4dc4 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_dynamic_class.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from unit.common import DistributedTest + +import deepspeed + + +class TestNewClassDeclaredNestingInit(DistributedTest): + world_size = 1 + + def test_new_class_declared_nesting_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = MyModel() + + # ensure that zero3 processed the parameter + assert hasattr(model.fc.weight, "ds_id") + deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config) + + +class TestNewClassDeclaredInsideNestingInit(DistributedTest): + world_size = 1 + + def test_new_class_declared_inside_nesting_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(1, 1) + + model = MyModel() + + # ensure that zero3 processed the parameter + assert hasattr(model.fc.weight, "ds_id") + deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config) diff --git a/tests/unit/runtime/zero/test_zero_grad_clip.py b/tests/unit/runtime/zero/test_zero_grad_clip.py new file mode 100644 index 000000000000..9ce50c2ca9ce --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_grad_clip.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import pytest +import deepspeed +from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 +from deepspeed.utils import safe_get_local_grad, safe_set_local_grad +from deepspeed.accelerator import get_accelerator +from unit.simple_model import SimpleModel +import os + + +def get_config(precision, clip_value, offload_device="cpu"): + config = { + "train_batch_size": 8, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": offload_device + }, + "contiguous_gradients": True, + "overlap_comm": False, + }, + "gradient_clipping": 1.0, + } + + if precision == "fp16": + config["fp16"] = { + "enabled": True, + "loss_scale": 1024, + "initial_scale_power": 10, + } + elif precision == "bf16": + config["bf16"] = { + "enabled": True, + } + + return config + + +@pytest.mark.parametrize("precision,clip_value,offload_device", [ + ("fp16", 0.5, "cpu"), + ("bf16", 0.05, "cpu"), + ("fp16", 0.5, "none"), + ("bf16", 0.05, "none"), +]) +class TestZeroGradClip(): + world_size = 1 + + def test_grad_clip_and_norm_update(self, precision, clip_value, offload_device): + """Test custom gradient clipping with configurations and to check if the norm_groups are updated correctly""" + config_dict = get_config(precision, clip_value, offload_device) + + model = SimpleModel(hidden_dim=10) + + # Set up distributed environment variables + os.environ['LOCAL_RANK'] = '0' + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '1' + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' + + try: + model_engine, optimizer, _, _ = deepspeed.initialize(args=None, + model=model, + config=config_dict, + model_parameters=model.parameters(), + dist_init_required=True) + except Exception as e: + pytest.skip("Could not initialize deepspeed") + + assert isinstance(optimizer, DeepSpeedZeroOptimizer_Stage3) + + torch.manual_seed(1670) + inputs = torch.randn(8, 10, device=model_engine.device) + targets = torch.randn(8, 10, device=model_engine.device) + + if model_engine.fp16_enabled() and get_accelerator().is_fp16_supported(): + inputs = inputs.half() + targets = targets.half() + elif model_engine.bfloat16_enabled() and get_accelerator().is_bf16_supported(): + inputs = inputs.bfloat16() + targets = targets.bfloat16() + else: + pytest.skip("Unsupported precision") + + loss = model_engine(inputs, targets) + model_engine.backward(loss) + + pre_clip_norm_groups = optimizer._get_norm_groups() + pre_clip_global_norm = torch.linalg.vector_norm(torch.stack(pre_clip_norm_groups)) + + modified_count = 0 + + for param in model_engine.parameters(): + if not hasattr(param, 'ds_id'): + continue + + grad = safe_get_local_grad(param) + if grad is not None: + pre_clip_norm = grad.norm().item() + clamped_grad = torch.clamp(grad, -clip_value, clip_value) + post_clip_norm = clamped_grad.norm().item() + + if pre_clip_norm > clip_value: + # Checks if the post-clip norm is less than the pre-clip norm + assert post_clip_norm < pre_clip_norm, f"Post-clip norm should be < pre-clip norm for param {param.ds_id}" + + safe_set_local_grad(param, clamped_grad) + modified_count += 1 + + # Get post-clip state + post_clip_norm_groups = optimizer._get_norm_groups() + post_clip_global_norm = torch.linalg.vector_norm(torch.stack(post_clip_norm_groups)) + + assert modified_count > 0, "No parameters were modified during clipping" + assert post_clip_global_norm.item() < pre_clip_global_norm.item( + ), f"Post-clip norm {post_clip_global_norm.item():.6f} should be < pre-clip norm {pre_clip_global_norm.item():.6f}" + + model_engine.step() + final_norm = optimizer._global_grad_norm + if pre_clip_global_norm.item() > clip_value: + assert post_clip_global_norm.item() < pre_clip_global_norm.item( + ), "Global norm should be reduced after clipping when pre-clip norm > clip_value" diff --git a/tests/unit/runtime/zero/test_zero_late_module_attach.py b/tests/unit/runtime/zero/test_zero_late_module_attach.py new file mode 100644 index 000000000000..641500a36232 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_late_module_attach.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression tests for issue #6961. + +ZeRO-3 forward used to crash with ``AttributeError: 'dict' object has no +attribute '_in_forward'`` when a submodule's ``_parameters`` was a plain +``dict`` instead of a ``ZeROOrderedDict``. PyTorch 2.5+ defaults +``nn.Module._parameters`` to ``dict`` (pytorch/pytorch#129164), and any +module not converted at ``DeepSpeedZeRoOffload`` init time hits the crash. +The tests force the plain-dict condition explicitly so they exercise the +fix on every supported torch version, not only torch 2.5+. +""" + +import torch + +import deepspeed +from deepspeed.runtime.zero.parameter_offload import (ZeROOrderedDict, ensure_zero_ordered_dict) + +from unit.common import DistributedTest, preferred_dtype + + +class _Tiny(torch.nn.Module): + + def __init__(self, hidden_dim=16): + super().__init__() + self.fc = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + def forward(self, x): + return self.fc(x) + + +def _zero3_config(dtype): + return { + "train_batch_size": 1, + "fp16": { + "enabled": dtype is torch.float16 + }, + "bf16": { + "enabled": dtype is torch.bfloat16 + }, + "zero_optimization": { + "stage": 3 + }, + } + + +class TestZero3LateModuleAttach(DistributedTest): + world_size = 1 + + def test_forward_after_late_submodule_attach(self): + """Attaching a fresh ``nn.Linear`` after ``initialize`` must not crash.""" + hidden = 16 + dtype = preferred_dtype() + model = _Tiny(hidden) + engine, *_ = deepspeed.initialize(model=model, + config=_zero3_config(dtype), + model_parameters=list(model.parameters())) + + late = torch.nn.Linear(hidden, hidden, bias=False).to(device=engine.device, dtype=dtype) + # Force the post-pytorch/pytorch#129164 condition deterministically so + # the test exercises the fix regardless of the installed torch version. + late._parameters = dict(late._parameters) + engine.module.late = late + + x = torch.randn(2, hidden, dtype=dtype, device=engine.device) + engine(x) + + # Prologue must have lazily converted the late submodule. + assert isinstance(engine.module.late._parameters, ZeROOrderedDict) + + def test_idempotent_on_already_injected_modules(self): + """Repeated forwards must not re-wrap an already-converted ``_parameters``.""" + hidden = 16 + dtype = preferred_dtype() + model = _Tiny(hidden) + engine, *_ = deepspeed.initialize(model=model, + config=_zero3_config(dtype), + model_parameters=list(model.parameters())) + + first_pdict = engine.module.fc._parameters + assert isinstance(first_pdict, ZeROOrderedDict) + + x = torch.randn(2, hidden, dtype=dtype, device=engine.device) + engine(x) + engine(x) + + assert engine.module.fc._parameters is first_pdict + + +class TestEnsureZeroOrderedDict: + """Direct unit tests for the helper. No distributed harness needed.""" + + def test_skips_already_converted(self): + m = torch.nn.Linear(4, 4, bias=False) + m._parameters = ZeROOrderedDict(parent_module=m) + before = m._parameters + ensure_zero_ordered_dict(m) + assert m._parameters is before + + def test_wraps_plain_dict(self): + m = torch.nn.Linear(4, 4, bias=False) + m._parameters = dict(m._parameters) + ensure_zero_ordered_dict(m) + assert isinstance(m._parameters, ZeROOrderedDict) + assert "weight" in m._parameters + assert m._original_parameters is not m._parameters + + def test_preserves_existing_original_parameters(self): + """Subsequent wraps must not clobber the first-saved original. + + ``_inject_parameters`` at engine init records the true torch-native + container in ``_original_parameters``; the deepcompile path in + ``init_z3.py`` reads it back to un-inject. If the helper later runs + after some intermediate replacement of ``_parameters``, it must not + overwrite that saved reference. + """ + m = torch.nn.Linear(4, 4, bias=False) + sentinel = m._parameters + m._original_parameters = sentinel + m._parameters = dict(sentinel) # different object, same contents + ensure_zero_ordered_dict(m) + assert m._original_parameters is sentinel + + def test_noop_when_parameters_missing(self): + """Helper must not raise when ``_parameters`` is missing or None.""" + + class Bare: + pass + + m = Bare() + ensure_zero_ordered_dict(m) # no-op, no exception + m._parameters = None + ensure_zero_ordered_dict(m) # no-op, no exception + assert m._parameters is None diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py new file mode 100644 index 000000000000..88898403ec43 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -0,0 +1,544 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import deepspeed.comm as dist +import torch + +from unit.common import DistributedTest, preferred_dtype +from unit.simple_model import random_dataloader + +import deepspeed +from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, \ + set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig +from deepspeed.runtime.zero.leaf_module_config import (DEFAULT_LEAF_MODULE_CLASSES, DEFAULT_LEAF_MODULE_NAMES, + DEFAULT_LEAF_MODULE_NAME_SUFFIXES) +from deepspeed.accelerator import get_accelerator +from torch import nn +import time + + +class ChooseModuleByCounter(torch.nn.Module): + + def __init__(self, hidden_dim): + super(ChooseModuleByCounter, self).__init__() + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(hidden_dim, hidden_dim, bias=False), + torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]) + self.act = torch.nn.ReLU() + self.cel = torch.nn.CrossEntropyLoss() + self.counter = 0 + + def forward(self, x, y): + # This fails without setting this module as a leaf module. + # See the comment in `set_z3_leaf_modules()`. + x = self.linears[self.counter % len(self.linears)](x) + x = self.act(x) + loss = self.cel(x, y) + self.counter += 1 + return x, loss + + +class ChooseModuleByRankModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(ChooseModuleByRankModel, self).__init__() + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(hidden_dim, hidden_dim, bias=False), + torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]) + self.act = torch.nn.ReLU() + self.cel = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + # Each rank runs only one of the linear layers + x = self.linears[dist.get_rank() % len(self.linears)](x) + x = self.act(x) + loss = self.cel(x, y) + return x, loss + + +class MultiOutputMoEBlock(nn.Module): + """A simplified MoE block that returns multiple tensors. + + This model mimics Qwen3 MoE which returns (hidden_states, router_logits). + When used with ZeRO3 leaf modules and autograd multithreading enabled, + this pattern previously caused race conditions in fetch_sub_module + because backward hooks could be triggered concurrently from multiple threads. + + See: https://github.com/deepspeedai/DeepSpeed/issues/7824 + """ + + def __init__(self, hidden_dim, num_experts=4): + super(MultiOutputMoEBlock, self).__init__() + self.num_experts = num_experts + self.gate = nn.Linear(hidden_dim, num_experts, bias=False) + self.experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts)]) + self.act = nn.ReLU() + self.cel = nn.CrossEntropyLoss() + + def forward(self, x, y): + # Compute router logits - this tensor will have gradients flowing through it + router_logits = self.gate(x) + + # Process through experts + for expert in self.experts: + x = expert(x) + x = self.act(x) + loss = self.cel(x, y) + + # Return multiple tensors - this triggers concurrent backward hooks + # when autograd multithreading is enabled + return x, loss, router_logits + + +class MLPBlock(nn.Module): + + def __init__(self, hidden_dim): + super(MLPBlock, self).__init__() + self.gate_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.act_fn = nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class FineGrainedBlock(nn.Module): + + def __init__(self, hidden_dim, num_block): + super(FineGrainedBlock, self).__init__() + self.num_block = num_block + self.mlp_layers = torch.nn.ModuleList([MLPBlock(hidden_dim=hidden_dim) for _ in range(self.num_block)]) + + def forward(self, x): + for i in range(self.num_block): + x = self.mlp_layers[i](x) + return x + + +class BaseLeafModule(nn.Module): + + def __init__(self): + super(BaseLeafModule, self).__init__() + + +class SubLeafModule(BaseLeafModule): + + def __init__(self, hidden_dim): + super(SubLeafModule, self).__init__() + self.proj = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + return self.proj(x) + + +class WrapperLeafModule(nn.Module): + + def __init__(self, hidden_dim): + super(WrapperLeafModule, self).__init__() + self.child = SubLeafModule(hidden_dim) + + def forward(self, x): + return self.child(x) + + +def test_set_leaf_modules_with_fully_qualified_name(): + hidden_dim = 16 + model = WrapperLeafModule(hidden_dim) + fq_name = f"{SubLeafModule.__module__}.{SubLeafModule.__qualname__}" + + matched = set_z3_leaf_modules(model, [fq_name]) + + assert len(matched) == 1 + assert matched[0] is model.child + assert z3_leaf_module(model.child) + assert not z3_leaf_module(model) + + +def test_set_leaf_modules_no_raise_when_missing(): + hidden_dim = 16 + model = WrapperLeafModule(hidden_dim) + + matched = set_z3_leaf_modules(model, ["NonExistentClass"], raise_if_not_found=False) + + assert matched == [] + assert not z3_leaf_module(model.child) + + +def test_set_leaf_modules_by_name(): + hidden_dim = 16 + model = WrapperLeafModule(hidden_dim) + + matched, missing = set_z3_leaf_modules_by_name(model, ["child"]) + + assert matched == [model.child] + assert missing == [] + assert z3_leaf_module(model.child) + + +def test_set_leaf_modules_by_name_missing(): + hidden_dim = 16 + model = WrapperLeafModule(hidden_dim) + + matched, missing = set_z3_leaf_modules_by_name(model, ["missing"], raise_if_not_found=False) + + assert matched == [] + assert missing == ["missing"] + + +def test_set_leaf_modules_by_suffix(): + hidden_dim = 16 + model = WrapperLeafModule(hidden_dim) + + matched, missing = set_z3_leaf_modules_by_suffix(model, ["child"]) + + assert missing == [] + assert matched == [model.child] + assert z3_leaf_module(model.child) + + +def test_set_leaf_modules_by_suffix_missing(): + hidden_dim = 16 + model = WrapperLeafModule(hidden_dim) + + matched, missing = set_z3_leaf_modules_by_suffix(model, ["missing"], raise_if_not_found=False) + + assert matched == [] + assert missing == ["missing"] + + +def test_zero_leaf_module_default_config(): + config = DeepSpeedZeroConfig() + assert config.leaf_module.classes == DEFAULT_LEAF_MODULE_CLASSES + assert config.leaf_module.names == DEFAULT_LEAF_MODULE_NAMES + assert config.leaf_module.name_suffixes == DEFAULT_LEAF_MODULE_NAME_SUFFIXES + + +def test_zero_leaf_module_custom_config(): + payload = { + "leaf_module": { + "classes": ["custom.module.CustomClass"], + "names": ["transformer.layer"], + "name_suffixes": ["experts"] + } + } + + config = DeepSpeedZeroConfig(**payload) + + assert config.leaf_module.classes == ["custom.module.CustomClass"] + assert config.leaf_module.names == ["transformer.layer"] + assert config.leaf_module.name_suffixes == ["experts"] + + +def test_zero_leaf_module_string_coercion(): + payload = {"leaf_module": {"classes": "my.Class", "names": "submodule", "name_suffixes": "tail"}} + + config = DeepSpeedZeroConfig(**payload) + + assert config.leaf_module.classes == ["my.Class"] + assert config.leaf_module.names == ["submodule"] + assert config.leaf_module.name_suffixes == ["tail"] + + +@pytest.mark.skip(reason="Requires Hugging Face transformers; run manually when validating defaults.") +def test_default_leaf_module_classes_exist(): + import importlib + + from deepspeed.runtime.zero.leaf_module_config import DEFAULT_LEAF_MODULE_CLASSES + + for cls_path in DEFAULT_LEAF_MODULE_CLASSES: + module_name, _, class_name = cls_path.rpartition('.') + module = importlib.import_module(module_name) + assert hasattr(module, class_name), f"Expected {class_name} in {module_name}" + + +class modelWithFineGrainedBlock(nn.Module): + + def __init__(self, hidden_dim, num_block): + super(modelWithFineGrainedBlock, self).__init__() + self.coarse_grained_layer1 = nn.Linear(hidden_dim, 8 * hidden_dim) + self.coarse_grained_layer2 = nn.Linear(8 * hidden_dim, hidden_dim) + self.fine_grained_layer = FineGrainedBlock(hidden_dim, num_block) + self.cel = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + x = self.coarse_grained_layer1(x) + x = self.coarse_grained_layer2(x) + x = self.fine_grained_layer(x) + loss = self.cel(x, y) + return x, loss + + +def run_model(model, config_dict, hidden_dim, dtype, requires_grad): + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + for batch in data_loader: + batch[0].requires_grad = requires_grad + loss = model(batch[0], batch[1]) + loss = loss[1] + model.backward(loss) + model.step() + + # Needed in ZeRO 3. Not doing so can give memory leak + model.destroy() + + +class TestSetZ3LeafModule(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + reuse_dist_env = True + + def _create_zero_config(self, hidden_dim, leaf_module=None): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + "stage3_prefetch_bucket_size": hidden_dim**2, + "stage3_param_persistence_threshold": 0, + "stage3_max_reuse_distance": 0, + } + } + if leaf_module is not None: + config_dict["zero_optimization"]["leaf_module"] = leaf_module + + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + return config_dict + + def _test_set_z3_leaf_modules(self, cls, requires_grad): + hidden_dim = 128 + config_dict = self._create_zero_config(hidden_dim) + + model = cls(hidden_dim) + + assert not z3_leaf_module(model) + set_z3_leaf_modules(model, [cls]) + assert z3_leaf_module(model) + + run_model(model, config_dict, hidden_dim, preferred_dtype(), requires_grad) + + def test_choose_module_by_counter(self): + self._test_set_z3_leaf_modules(ChooseModuleByCounter, True) + + def test_choose_module_by_rank(self): + self._test_set_z3_leaf_modules(ChooseModuleByRankModel, True) + + def test_multi_output_leaf_module_thread_safety(self): + """Test that leaf modules returning multiple tensors work correctly with autograd multithreading. + + This tests the fix for https://github.com/deepspeedai/DeepSpeed/issues/7824 + where MoE models (like Qwen3) returning multiple tensors caused race conditions + in fetch_sub_module when autograd executed backward hooks from multiple threads. + """ + # Ensure autograd multithreading is enabled (this is the default, but be explicit) + torch.autograd.set_multithreading_enabled(True) + + hidden_dim = 128 + config_dict = self._create_zero_config(hidden_dim) + + model = MultiOutputMoEBlock(hidden_dim, num_experts=4) + + assert not z3_leaf_module(model) + set_z3_leaf_modules(model, [MultiOutputMoEBlock]) + assert z3_leaf_module(model) + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + dist.barrier() + + # Run multiple iterations to increase chance of hitting race conditions + for batch in data_loader: + batch[0].requires_grad = True + # Model returns (output, loss, router_logits) + output, loss, router_logits = model(batch[0], batch[1]) + # Include router_logits in the loss to ensure multiple backward paths + total_loss = loss + 0.01 * router_logits.mean() + model.backward(total_loss) + model.step() + + model.destroy() + + def test_multi_output_non_leaf_module_thread_safety(self): + """Ensure non-leaf modules returning multiple tensors remain thread-safe. + + This covers the multi-output autograd multithreading case without marking the + module as a ZeRO leaf module. + """ + torch.autograd.set_multithreading_enabled(True) + + hidden_dim = 128 + config_dict = self._create_zero_config(hidden_dim) + + model = MultiOutputMoEBlock(hidden_dim, num_experts=4) + assert not z3_leaf_module(model) + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + dist.barrier() + + for batch in data_loader: + batch[0].requires_grad = True + output, loss, router_logits = model(batch[0], batch[1]) + total_loss = loss + 0.01 * router_logits.mean() + model.backward(total_loss) + model.step() + + model.destroy() + + def test_no_grad_input_error(self): + try: + self._test_set_z3_leaf_modules(ChooseModuleByCounter, False) + raise AssertionError( + "Expected RuntimeError: inputs with requires_grad=False is not supported for a leaf module") + except RuntimeError as e: + pass + + def test_set_unset_leaf_modules(self): + hidden_dim = 128 + model = ChooseModuleByCounter(hidden_dim) + assert len(set_z3_leaf_modules(model, [torch.nn.ModuleList])) == 1, \ + "Expected only one module to be set as a leaf module" + assert len(get_z3_leaf_modules(model)) == 1, "Expected there is only one leaf module" + + assert len(unset_z3_leaf_modules(model, [torch.nn.ModuleList])) == 1, \ + "Expected only one module to be unset as a leaf module" + assert len(get_z3_leaf_modules(model)) == 0, "Expected there is no leaf module" + + def test_set_leaf_modules_with_subclass(self): + hidden_dim = 32 + model = WrapperLeafModule(hidden_dim) + + leaf_modules = set_z3_leaf_modules(model, [BaseLeafModule]) + + assert len(leaf_modules) == 1, "Expected the subclass instance to be marked as leaf" + assert leaf_modules[0] is model.child, "Expected the subclass instance to be returned" + assert z3_leaf_module(model.child), "Expected subclass instance flagged as leaf" + assert not z3_leaf_module(model), "Expected wrapper module to remain non-leaf" + + def test_set_no_match_class(self): + hidden_dim = 128 + model = ChooseModuleByCounter(hidden_dim) + try: + set_z3_leaf_modules(model, [torch.nn.Conv2d]) + raise AssertionError("Expected error that no module is set as a leaf module") + except ValueError as e: + pass + + def test_leaf_module_enabled_via_config(self): + hidden_dim = 128 + leaf_class_fqn = f"{ChooseModuleByCounter.__module__}.{ChooseModuleByCounter.__qualname__}" + config_dict = self._create_zero_config(hidden_dim, + leaf_module={ + "classes": [leaf_class_fqn], + "name_suffixes": ["linears"] + }) + + model = ChooseModuleByCounter(hidden_dim) + assert not z3_leaf_module(model) + + run_model(model, config_dict, hidden_dim, preferred_dtype(), True) + + assert z3_leaf_module(model) + modules_by_name = dict(model.named_modules()) + assert "linears" in modules_by_name + assert z3_leaf_module(modules_by_name["linears"]) + + +@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000]) +class TestZ3LeafOptimization(DistributedTest): + world_size = 2 + reuse_dist_env = True + + def test_finegrained_optimization(self, module_granularity_threshold: int): + hidden_dim = 128 + num_block = 16 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + "stage3_prefetch_bucket_size": hidden_dim**2, + "stage3_param_persistence_threshold": 0, + "stage3_max_reuse_distance": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + def bench_loss_and_time(config): + warm_up_step = 10 + model = modelWithFineGrainedBlock(hidden_dim, num_block) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + dist.barrier() + loss_list = [] + + for i, batch in enumerate(data_loader): + if i == warm_up_step: + dist.barrier() + get_accelerator().synchronize() + start_time = time.time() + batch[0].requires_grad = True + loss = model(batch[0], batch[1]) + loss = loss[1] + loss_list.append(loss) + model.backward(loss) + model.step() + get_accelerator().synchronize() + end_time = time.time() + duration = end_time - start_time + model.destroy() + return loss_list, duration + + baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) + + config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = module_granularity_threshold + loss, duration = bench_loss_and_time(config_dict) + + if dist.get_rank() == 0: + print("baseline exec time:", baseline_exec_time) + print( + f"finegrained optimziation exec time: {duration},granularity threshold:{module_granularity_threshold} " + ) + assert baseline_loss_list == loss, f"incorrect loss value with threshold:{module_granularity_threshold}" diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py new file mode 100644 index 000000000000..15d82fd8be00 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from unit.common import DistributedTest + +from transformers import VisionEncoderDecoderModel +from transformers.integrations.deepspeed import HfDeepSpeedConfig + +import deepspeed + + +class TestNestingInit(DistributedTest): + world_size = 1 + + def test_nesting_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = torch.nn.Linear(4, 4) + + # ensure that zero3 processed the parameter + assert hasattr(model.weight, "ds_id") + + deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config) + + +class TestShutdownInNestingInit(DistributedTest): + world_size = 1 + + def test_shutdown_in_nesting_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model1 = torch.nn.Linear(4, 4) + + assert hasattr(model1.weight, "ds_id") + deepspeed_engine1, *_ = deepspeed.initialize(model=model1, config_params=ds_config) + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model2 = torch.nn.Linear(4, 4) + + # ensure that zero3 processed the parameter + assert hasattr(model2.weight, "ds_id") + deepspeed_engine2, *_ = deepspeed.initialize(model=model2, config_params=ds_config) + + +class TestNestedParallelInit(DistributedTest): + world_size = 1 + + # Testing a model with composed and nested zero.Inits, with 3 zero.Init contexts, 1 parent and 2 children. + # The skeleton of the model is like so + # + # class VisionEncoderDecoderModel(...):: + # def __init__(self): + # encoder = AutoModel.from_config(config.encoder) + # decoder = AutoModelForCausalLM.from_config(config.decoder) + # + # And the user calls like below: + # VisionEncoderDecoderModel.from_pretrained(...) + # which calls this constructor inside zero.Init + + def test_nested_parallel_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + dschf = HfDeepSpeedConfig(ds_config) # keep this object alive + model = VisionEncoderDecoderModel.from_pretrained( + "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2") + assert all([hasattr(p, 'ds_id') for p in model.parameters()]) diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py new file mode 100644 index 000000000000..32e7ccc4f9ae --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_offloadpp.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import pytest +import deepspeed.comm as dist +from unit.common import DistributedTest +from unit.simple_model import random_dataloader + +import deepspeed +import torch +from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig + +import torch.nn as nn + + +class NNModel(nn.Module): + + def __init__(self, h_dim=1024, n_layers=2): + super(NNModel, self).__init__() + self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)]) + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + return self.cross_entropy_loss(x, y) + + +def test_zero_partial_offload_config(): + config = DeepSpeedZeroOffloadOptimizerConfig(**{"ratio": 0.3}) + assert config.ratio == 0.3 + + +#Large sweep along hidden dim, num_layers of different sizes +@pytest.mark.parametrize("h_dim", [1024]) +@pytest.mark.parametrize("n_layers", [4, 8]) +class TestZeroPartialOffloadConfigSweep(DistributedTest): + world_size = 4 + + def test(self, h_dim: int, n_layers: int) -> None: + + config_dict = { + "train_batch_size": 256, + "steps_per_print": 1, + "gradient_clipping": 1.0, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 3, + "sub_group_size": 8, + "reduce_bucket_size": 20, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + "ratio": 0.3 + } + } + } + + model = NNModel(h_dim, n_layers) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=h_dim, + device=model.device, + dtype=torch.float16) + dist.barrier() + + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index 459d41f98eea..90e8e968abdf 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -6,24 +6,43 @@ import pytest import deepspeed.comm as dist import torch +import math from unit.common import DistributedTest -from unit.simple_model import random_dataloader +from unit.simple_model import random_dataloader, SimpleModel from unit.util import bf16_required_version_check import deepspeed from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state +from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_grad, safe_set_full_optimizer_state +from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state +from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state +from deepspeed.utils import safe_update_full_grad_vectorized from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.swap_tensor import MIN_SWAPPABLE_BYTES +WEIGHT_KEY = 'weight' +FIRST_ORDER_KEY = 'exp_avg' +SECOND_ORDER_KEY = 'exp_avg_sq' +GRADIENT_KEY = 'gradient' -def validate_full_tensors(model): + +def validate_tensor(model, api_type, opt_states): + assert api_type in ["full", "local"] for _, lp in model.named_parameters(): - hp = safe_get_full_fp32_param(lp) - exp_avg = safe_get_full_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_full_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_full_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] + param_list = [] + if opt_states: + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg')) + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg_sq') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg_sq')) + else: + param_list.append(safe_get_full_fp32_param(lp) if api_type == "full" else safe_get_local_fp32_param(lp)) + param_list.append(safe_get_full_grad(lp) if api_type == "full" else safe_get_local_grad(lp)) if lp.requires_grad: assert all([p is not None for p in param_list]) else: @@ -48,12 +67,10 @@ def forward(self, x, y): for l in self.linears: x = l(x) x = self.act(x) - loss = self.cel(x, y) - val = (x, loss) - return val + return self.cel(x, y) -def run_fragmented_model(model, config_dict, hidden_dim, dtype): +def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_after_bwd, validate_after_step): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -63,26 +80,38 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype): dist.barrier() for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) - loss = loss[1] model.backward(loss) - validate_full_tensors(model) + validate_after_bwd(model) model.step() + validate_after_step(model) + + # Needed in ZeRO 3. Not doing so can give memory leak + model.destroy() @pytest.mark.parametrize('frozen_weights', [True, False]) -class TestTensorFragment(DistributedTest): +class TestTensorFragmentGet(DistributedTest): # Need multiple gpus to test possible hanging world_size = 2 + reuse_dist_env = True + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) + @pytest.mark.parametrize('api_type', ['local', 'full']) @pytest.mark.parametrize('zero_stage', [1, 2, 3]) @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) - def test_zero_fragments(self, tmpdir, zero_stage, offload_device, frozen_weights): + def test_zero_fragments(self, tmpdir, dtype, api_type, zero_stage, offload_device, frozen_weights): + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"{get_accelerator()._name} does not support {dtype} data type") + if offload_device == OffloadDeviceEnum.nvme: if zero_stage != 3: pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: pytest.skip('Skip tests since async-io is not compatible') + if api_type == "local" and zero_stage != 3: + pytest.skip(f"Local APIs only for zero stage 3 but current stage is {zero_stage}") + config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -92,15 +121,16 @@ def test_zero_fragments(self, tmpdir, zero_stage, offload_device, frozen_weights "lr": 1e-6 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 2 - }, "zero_optimization": { "stage": zero_stage, } } + if dtype == torch.half: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 2} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + if offload_device == OffloadDeviceEnum.cpu: config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} elif offload_device == OffloadDeviceEnum.nvme: @@ -109,20 +139,25 @@ def test_zero_fragments(self, tmpdir, zero_stage, offload_device, frozen_weights "nvme_path": str(tmpdir) } - hidden_dim = 128 + hidden_dim = MIN_SWAPPABLE_BYTES if zero_stage == 3: with deepspeed.zero.Init(config_dict_or_path=config_dict): model = MyModel(hidden_dim, frozen_weights) else: model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.float16) + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) + + run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_after_bwd, validate_after_step) - def test_bf16_fragments(self, frozen_weights): + def test_bf16_optimizer_fragments(self, frozen_weights): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet.") if frozen_weights: pytest.skip("TODO: Frozen weights not currently supported by BF16 Optimizer") - if not bf16_required_version_check(accelerator_check=False): + if not bf16_required_version_check(): pytest.skip( " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) @@ -138,11 +173,288 @@ def test_bf16_fragments(self, frozen_weights): "bf16": { "enabled": True }, + # Use fp32 gradient accumulation to ensure BF16_Optimizer is used + # (bf16 model + bf16 grad_accum uses FP16_Optimizer which doesn't support tensor fragment APIs) + "data_types": { + "grad_accum_dtype": "fp32" + }, "zero_optimization": { - "stage": 0, + "stage": 1, } } hidden_dim = 128 model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16) + + api_type = "full" + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) + + run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_after_bwd, validate_after_step) + + +def create_random_values(model, key_list, group, grad_dtype): + param_values = {} + for n, lp in model.named_parameters(): + param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape + param_values[n] = {} + for key in key_list: + dtype = grad_dtype if key == GRADIENT_KEY else torch.float32 + rand_value = torch.rand(param_shape, dtype=dtype, device=model.device) + dist.broadcast(rand_value, src=0, group=group) + param_values[n][key] = rand_value + return param_values + + +def set_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + for key, value_tensor in value_dict[n].items(): + if key == GRADIENT_KEY: + safe_set_full_grad(lp, value_tensor) + elif key == WEIGHT_KEY: + safe_set_full_fp32_param(lp, value_tensor) + else: + safe_set_full_optimizer_state(lp, value_tensor, key) + + +def update_param_values_with_dict(model, value_dict): + new_grad_values = {} + for n, lp in model.named_parameters(): + if GRADIENT_KEY in value_dict[n]: + new_grad_values[id(lp)] = value_dict[n][GRADIENT_KEY] + + def update_gradient_callback(old_value, param): + return new_grad_values[id(param)] + + update_param_list = [] + for n, lp in model.named_parameters(): + for key, value_tensor in value_dict[n].items(): + if key == GRADIENT_KEY: + update_param_list.append(lp) + + if len(update_param_list) > 0: + safe_update_full_grad_vectorized(update_param_list, update_gradient_callback) + + +def validate_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + for key, expected_tensor in value_dict[n].items(): + if key == GRADIENT_KEY: + actual_tensor = safe_get_full_grad(lp) + elif key == WEIGHT_KEY: + actual_tensor = safe_get_full_fp32_param(lp) + else: + actual_tensor = safe_get_full_optimizer_state(lp, key) + + assert torch.equal(expected_tensor, actual_tensor) + + +def create_random_values_for_local(model, key_list, group, grad_dtype): + param_values = {} + for n, lp in model.named_parameters(): + param_shape = lp.ds_tensor.shape + param_values[n] = {} + for key in key_list: + dtype = grad_dtype if key == GRADIENT_KEY else torch.float32 + rand_value = torch.rand(param_shape, dtype=dtype, device=model.device) + param_values[n][key] = rand_value + return param_values + + +def set_local_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + + for key, value_tensor in value_dict[n].items(): + if key == GRADIENT_KEY: + safe_set_local_grad(lp, value_tensor) + elif key == WEIGHT_KEY: + safe_set_local_fp32_param(lp, value_tensor) + else: + safe_set_local_optimizer_state(lp, value_tensor, key) + + +def validate_local_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + for key, expected_tensor in value_dict[n].items(): + if key == GRADIENT_KEY: + actual_tensor = safe_get_local_grad(lp) + elif key == WEIGHT_KEY: + actual_tensor = safe_get_local_fp32_param(lp) + else: + actual_tensor = safe_get_local_optimizer_state(lp, key) + + assert torch.equal(expected_tensor, actual_tensor) + + +helper_funcs_mapping = { + "full": { + "create_random_values": create_random_values, + "set_param_values_with_dict": set_param_values_with_dict, + "update_param_values_with_dict": update_param_values_with_dict, + "validate_param_values_with_dict": validate_param_values_with_dict, + }, + "local": { + "create_random_values": create_random_values_for_local, + "set_param_values_with_dict": set_local_param_values_with_dict, + "validate_param_values_with_dict": validate_local_param_values_with_dict + } +} + + +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) +class TestTensorFragmentSet(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('api_type', ['local', 'full']) + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) + def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtype): + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"{get_accelerator()._name} does not support {dtype} data type") + + if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if api_type == "local" and zero_stage != 3: + pytest.skip(f"Local APIs only for zero stage 3 but current stage is {zero_stage}") + + if offload_device == OffloadDeviceEnum.nvme: + if zero_stage != 3: + pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_device == OffloadDeviceEnum.cpu: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} + elif offload_device == OffloadDeviceEnum.nvme: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": offload_device, + "nvme_path": str(tmpdir) + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + hidden_dim = int(math.sqrt(MIN_SWAPPABLE_BYTES)) + if zero_stage == 3: + config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim) + else: + model = SimpleModel(hidden_dim) + + world = dist.get_world_size() + group = dist.new_group(ranks=list(range(world))) + + dist.barrier() + + def after_bwd_validate_func(model): + state_keys = [WEIGHT_KEY, GRADIENT_KEY] + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"](model, state_keys, group, grad_dtype=dtype) + helper_funcs["set_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) + + def after_step_validate_func(model): + state_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"](model, state_keys, group, grad_dtype=dtype) + helper_funcs["set_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) + + run_fragmented_model(model, config_dict, hidden_dim, dtype, after_bwd_validate_func, after_step_validate_func) + + +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) +class TestTensorFragmentUpdate(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('torch_adam', [False, True]) + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) + def test_zero_fragments(self, tmpdir, torch_adam, zero_stage, offload_device, dtype): + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"{get_accelerator()._name} does not support {dtype} data type") + + if offload_device == OffloadDeviceEnum.nvme: + if zero_stage != 3: + pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6, + "torch_adam": torch_adam + } + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_device == OffloadDeviceEnum.cpu: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} + elif offload_device == OffloadDeviceEnum.nvme: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": offload_device, + "nvme_path": str(tmpdir) + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + hidden_dim = int(math.sqrt(MIN_SWAPPABLE_BYTES)) + if zero_stage == 3: + config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim) + else: + model = SimpleModel(hidden_dim) + + world = dist.get_world_size() + group = dist.new_group(ranks=list(range(world))) + + dist.barrier() + + api_type = "full" + + def after_bwd_validate_func(model): + state_keys = [GRADIENT_KEY] + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"](model, state_keys, group, grad_dtype=dtype) + helper_funcs["update_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) + + def after_step_validate_func(model): + pass + + run_fragmented_model(model, config_dict, hidden_dim, dtype, after_bwd_validate_func, after_step_validate_func) diff --git a/tests/unit/runtime/zero/test_zeropp.py b/tests/unit/runtime/zero/test_zeropp.py new file mode 100644 index 000000000000..e8846d521b88 --- /dev/null +++ b/tests/unit/runtime/zero/test_zeropp.py @@ -0,0 +1,288 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import pytest +import deepspeed.comm as dist +from torch.nn import Module + +from unit.common import DistributedTest +from unit.simple_model import random_dataloader + +import deepspeed + +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig + +import torch.nn as nn +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch.utils.data import DataLoader + +import numpy as np + + +class NNModel(nn.Module): + + def __init__(self, h_dim=1024, n_layers=2): + super(NNModel, self).__init__() + self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)]) + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + return self.cross_entropy_loss(x, y) + + +def test_zero_hpz_partition_size_config(): + config = DeepSpeedZeroConfig(**{"zero_hpz_partition_size": 4}) + assert config.zero_hpz_partition_size == 4 + + +def _assert_no_secondary_tensor_group(model: Module) -> None: + for _, param in model.named_parameters(): + assert param.ds_secondary_tensor is None + assert param.ds_zero_param_process_group is None + + +def _check_secondary_tensor_existence(model: Module) -> None: + for _, param in model.named_parameters(): + if param.ds_secondary_tensor is not None: + return True + return False + + +def _assert_secondary_tensor_size(model: Module) -> None: + for name, param in model.named_parameters(): + assert param.ds_secondary_tensor is not None, f"param {param.ds_id}:{name} does not have secondary tensor" + assert param.ds_secondary_tensor.size()[0] % param.ds_tensor.size()[0] == 0 + + +#Large sweep along hidden dim, num_layers, and zpg of different sizes +#Assert when zpg=1 that secondary group and tensors are invalid +@pytest.mark.sequential +@pytest.mark.parametrize("h_dim", [1024]) +@pytest.mark.parametrize("n_layers", [9]) +@pytest.mark.parametrize("zpg", [1, 2, 4]) +class TestZeroPPConfigSweep(DistributedTest): + world_size = 4 + + def test(self, h_dim: int, n_layers: int, zpg: int) -> None: + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "zero_hpz_partition_size": zpg, + "zero_quantized_weights": True, + "zero_quantized_gradients": True, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + model = NNModel(h_dim, n_layers) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=h_dim, + device=model.device, + dtype=torch.float16) + dist.barrier() + if zpg == 1: + _assert_no_secondary_tensor_group(model) + + for n, batch in enumerate(data_loader): + if n == 0 and zpg != 1: + _assert_secondary_tensor_size(model) + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + def test_eval(self, h_dim: int, n_layers: int, zpg: int) -> None: + # in this test case, we are testing that hpz should be enabled when eval mode is on + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "zero_hpz_partition_size": zpg, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + model = NNModel(h_dim, n_layers) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=h_dim, + device=model.device, + dtype=torch.float16) + dist.barrier() + if zpg == 1: + _assert_no_secondary_tensor_group(model) + + for n, batch in enumerate(data_loader): + if zpg != 1: + # here we check that the hpz is enabled when the previous iteration does not update the model + _assert_secondary_tensor_size(model) + with torch.no_grad(): + loss = model(batch[0], batch[1]) + + def test_gradient_accumulation(self, h_dim: int, n_layers: int, zpg: int) -> None: + # in this test case, we are testing that hpz should be enabled for the intermediate gradient accumulation steps + # In this test, we should disable loss_scale + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 3, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "zero_hpz_partition_size": zpg, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0., + } + } + + model = NNModel(h_dim, n_layers) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=h_dim, + device=model.device, + dtype=torch.float16) + dist.barrier() + if zpg == 1: + _assert_no_secondary_tensor_group(model) + + for n, batch in enumerate(data_loader): + if n == 0 and zpg != 1: + _assert_secondary_tensor_size(model) + # here we cannot assert that secondary tensor does not exist because the gradient is likely overflowed as we use random data + if n > 0 and n % 3 != 0 and zpg != 1: + # if the previous iteration does not update the model, then the hpz should be enabled + assert _check_secondary_tensor_existence(model), f"n={n}" + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + +@pytest.mark.nightly +@pytest.mark.parametrize("model_name", ["gpt2"]) +class TestZeroPPConvergence(DistributedTest): + world_size = 4 + + def load_and_prepare_data(self, model_name): + """Load model, tokenizer and dataset, and prepare data loader.""" + from datasets import load_dataset + + # Load model and tokenizer + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + + # Load and tokenize dataset + dataset = load_dataset("wikitext", 'wikitext-103-raw-v1', split='train[:1%]').filter(lambda x: x["text"]) + + def tokenize_function(examples): + # Tokenize and ensure 'labels' are the same as 'input_ids' + tokenized_output = tokenizer(examples["text"], padding="max_length", truncation=True, return_tensors='pt') + tokenized_output["labels"] = tokenized_output["input_ids"].clone() + return tokenized_output + + tokenized_dataset = dataset.map(tokenize_function, batched=True) + tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels']) + + # Create data loader + data_loader = DataLoader(tokenized_dataset, batch_size=1, shuffle=False) + return model, data_loader + + def get_loss(self, model, data_loader, config_dict, step=500): + """Train the model and calculate average loss.""" + # Initialize DeepSpeed + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + dist.barrier() + model.train() + + # Training loop + losses = [] + for n, batch in enumerate(data_loader): + if n >= step: + break + batch = {k: v.to(model.device) for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + model.backward(loss) + model.step() + losses.append(loss.item()) + + return np.nanmean(losses[-100:]) + + def get_config_dict(self, use_quantized_weights=False, use_hpz=False): + """Generate the configuration dictionary for DeepSpeed.""" + config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-5 + } + }, + "fp16": { + "enabled": True + } + } + if use_quantized_weights: + config["zero_optimization"]["zero_quantized_weights"] = True + if use_hpz: + config["zero_optimization"]["zero_hpz_partition_size"] = self.world_size // 2 + return config + + def test(self, model_name): + torch.manual_seed(0) + model, data_loader = self.load_and_prepare_data(model_name) + zeropp_loss = self.get_loss(model, data_loader, self.get_config_dict(use_quantized_weights=True, use_hpz=True)) + model, data_loader = self.load_and_prepare_data(model_name) + baseline_loss = self.get_loss(model, data_loader, self.get_config_dict()) + + # Output and assert + print(f"zeropp_loss={zeropp_loss}, baseline_loss={baseline_loss}") + assert zeropp_loss < baseline_loss * 1.1, f"zeropp_loss={zeropp_loss}, baseline_loss={baseline_loss}" diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py new file mode 100644 index 000000000000..0c289353a29d --- /dev/null +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Numerical equivalence tests for AutoSP multimodal sequence parallelism. + +Each test verifies that running the SP-wrapped path across N ranks produces +the same result as the equivalent single-device (non-SP) computation. + +These tests require 2 GPUs. +Run with: + + NCCL_P2P_DISABLE=1 python -m pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytest + +import deepspeed.comm as dist +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, LlavaFusionAdapter, Qwen2VLFusionAdapter +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + +# --------------------------------------------------------------------------- +# Shared identity attention — deterministic, easy to verify +# --------------------------------------------------------------------------- + +_IMAGE_TOKEN_ID = -200 + + +class _IdentityAttn(nn.Module): + """Returns hidden_states unchanged so that gather-compute-scatter is a no-op.""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# UlyssesSPViTAttention equivalence +# --------------------------------------------------------------------------- + + +class TestViTSPEquivalence(DistributedTest): + """SP-wrapped ViT attention with an identity inner module must reproduce + the unsharded output on every rank.""" + + world_size = 2 + + @pytest.mark.parametrize("has_cls_token", [True, False]) + @pytest.mark.parametrize("num_patches", [8, 12]) + def test_output_equals_single_device(self, has_cls_token, num_patches): + """Each rank's local output slice must match the corresponding slice of + the single-device output.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + bs, hidden = 2, 32 + + # --- Single-device reference --- + # Build the full input (all ranks see the same RNG seed so the tensor + # is identical everywhere). + torch.manual_seed(42) + if has_cls_token: + full_input = torch.randn(bs, 1 + num_patches, hidden).to(get_accelerator().device_name()) + else: + full_input = torch.randn(bs, num_patches, hidden).to(get_accelerator().device_name()) + + identity = _IdentityAttn().to(get_accelerator().device_name()) + # Single-device path is just identity — output == input. + ref_out = identity(full_input) + + # --- SP path --- + local_patches = num_patches // self.world_size + if has_cls_token: + cls = full_input[:, :1, :] + patch_slice = full_input[:, 1 + rank * local_patches:1 + (rank + 1) * local_patches, :] + local_input = torch.cat([cls, patch_slice], dim=1) + else: + local_input = full_input[:, rank * local_patches:(rank + 1) * local_patches, :] + + wrapper = UlyssesSPViTAttention(_IdentityAttn().to(get_accelerator().device_name()), + sp_group, + has_cls_token=has_cls_token).to(get_accelerator().device_name()) + sp_out = wrapper(local_input) + + # --- Compare --- + # sp_out is the local slice; reconstruct what slice of ref_out it maps to. + if has_cls_token: + ref_slice = torch.cat( + [ref_out[:, :1, :], ref_out[:, 1 + rank * local_patches:1 + (rank + 1) * local_patches, :]], dim=1) + else: + ref_slice = ref_out[:, rank * local_patches:(rank + 1) * local_patches, :] + + assert torch.allclose(sp_out, ref_slice, + atol=1e-5), (f"rank={rank} sp_out differs from reference: " + f"max_diff={( sp_out - ref_slice).abs().max().item():.2e}") + + @pytest.mark.parametrize("has_cls_token", [True, False]) + def test_noneven_patches(self, has_cls_token): + """When num_patches % world_size != 0, the wrapper must still produce + correct per-rank output. With 5 patches and world_size=2, rank 0 + holds 3 patches and rank 1 holds 2 patches.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + bs, hidden = 2, 16 + num_patches = 5 # not divisible by world_size=2 + + torch.manual_seed(77) + if has_cls_token: + full_input = torch.randn(bs, 1 + num_patches, hidden).to(get_accelerator().device_name()) + else: + full_input = torch.randn(bs, num_patches, hidden).to(get_accelerator().device_name()) + + # Distribute: first (num_patches % world_size) ranks carry one extra patch. + extra = num_patches % self.world_size # = 1 + base = num_patches // self.world_size # = 2 + local_v = base + (1 if rank < extra else 0) + patch_start = rank * base + min(rank, extra) + + if has_cls_token: + cls = full_input[:, :1, :] + patch_slice = full_input[:, 1 + patch_start:1 + patch_start + local_v, :] + local_input = torch.cat([cls, patch_slice], dim=1) + else: + local_input = full_input[:, patch_start:patch_start + local_v, :] + + wrapper = UlyssesSPViTAttention(_IdentityAttn().to(get_accelerator().device_name()), + sp_group, + has_cls_token=has_cls_token) + sp_out = wrapper(local_input) + + # Reference: identity wrapper — each rank's output must equal its input slice. + if has_cls_token: + ref_slice = torch.cat([full_input[:, :1, :], full_input[:, 1 + patch_start:1 + patch_start + local_v, :]], + dim=1) + else: + ref_slice = full_input[:, patch_start:patch_start + local_v, :] + + assert torch.allclose(sp_out, ref_slice, + atol=1e-5), (f"rank={rank} non-even patches: sp_out differs from reference: " + f"max_diff={(sp_out - ref_slice).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# LlavaFusionAdapter equivalence +# --------------------------------------------------------------------------- + + +class TestLlavaFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in LlavaFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full + fused sequence that single-device splicing would produce.""" + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank): + """Build deterministic visual and text tensors identical on every rank.""" + torch.manual_seed(0) + # Each rank holds a contiguous slice of the visual tokens. + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 1] = _IMAGE_TOKEN_ID # one image placeholder at position 1 + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 4, 6, 8 + full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank) + + # --- SP path: each rank gets one shard --- + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) # [bs, local_fused, hidden] + + # Gather all shards onto every rank so we can compare globally. + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) # [bs, padded_fused, hidden] + + # --- Single-device reference --- + # Simulate what a non-SP LlavaFusionAdapter would produce: project the + # full visual tensor (identity here) and splice once. + ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) + # Call _splice_visual_into_text directly so we bypass the SP scatter. + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + # Pad reference to the same padded length. + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} reassembled SP output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_image_token_passthrough(self): + """When there are no image placeholders the SP fused output must equal + the sharded text after padding/scatter (all-text path).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 8, 4 + torch.manual_seed(1) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) # no image placeholder + + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + # Gather shards and strip the padding slice from visual gather. + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Expected: when there is no image token, the visual tokens are ignored. + # So the fused output should just be the text tokens. + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} no-image path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# InternVLFusionAdapter equivalence +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_TOKEN_ID = 92546 + + +class TestInternVLFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in InternVLFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full fused + sequence that single-device splicing would produce. + + InternVL replaces IMG_CONTEXT tokens 1-to-1 with visual tokens, so the + sequence length is preserved. + """ + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank, num_ctx_tokens): + """Build deterministic inputs with a run of IMG_CONTEXT tokens in the middle.""" + torch.manual_seed(2) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + # Place IMG_CONTEXT tokens starting at position 2. + ids[:, 2:2 + num_ctx_tokens] = _INTERNVL_CONTEXT_TOKEN_ID + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 8, 4 + full_visual, local_visual, text, ids = self._build_inputs(bs, + local_v, + text_len, + hidden, + rank, + num_ctx_tokens=local_v * self.world_size) + + # SP path. + adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference. + ref_adapter = InternVLFusionAdapter(nn.Identity(), sp_group, image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to( + get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL reassembled output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_context_token_passthrough(self): + """When there are no IMG_CONTEXT tokens the fused output must equal the text.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 6, 4 + torch.manual_seed(3) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + + adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL no-context path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# Qwen2VLFusionAdapter equivalence +# --------------------------------------------------------------------------- + +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + + +class TestQwen2VLFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in Qwen2VLFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full fused + sequence that single-device splicing would produce. + + Qwen2-VL replaces inner placeholder tokens (between vision_start/end pairs) + 1-to-1 with visual tokens, so the sequence length is preserved. + """ + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank, num_inner): + """Build inputs with a single vision_start/end block containing num_inner placeholders.""" + torch.manual_seed(4) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + # [t0, , pad×num_inner, , ...] + ids[:, 1] = _QWEN2VL_START_ID + ids[:, 2 + num_inner] = _QWEN2VL_END_ID + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 10, 4 + num_inner = local_v * self.world_size # inner placeholder count equals total visual tokens + full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank, num_inner) + + # SP path. + adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference. + ref_adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL reassembled output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_vision_token_passthrough(self): + """When there are no vision_start/end tokens the fused output must equal the text.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 8, 4 + torch.manual_seed(5) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + + adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL no-vision path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") diff --git a/tests/unit/sequence_parallelism/test_autosp_integration.py b/tests/unit/sequence_parallelism/test_autosp_integration.py new file mode 100644 index 000000000000..4efcdb07c302 --- /dev/null +++ b/tests/unit/sequence_parallelism/test_autosp_integration.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +End-to-end integration tests for AutoSP multimodal sequence parallelism. + +Each test builds a minimal mock model whose attention-layer class names match +the autosp_detector registry, then verifies two things: + +1. auto_wrap_model_for_sp correctly identifies and wraps ViT attention modules + (with the correct has_cls_token value from the registry) and emits warnings + for HF-style LLM attention without wrapping them. +2. The full pipeline (SP-wrapped ViT -> fusion adapter) produces fused output + numerically equivalent to the single-device splice reference. + +These tests require 2 GPUs. +Run with: + + NCCL_P2P_DISABLE=1 python -m pytest tests/unit/sequence_parallelism/test_autosp_integration.py -v +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + +# --------------------------------------------------------------------------- +# Token IDs +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_ID = 92546 +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + +# --------------------------------------------------------------------------- +# Mock attention classes +# +# Class names must match exactly the entries in autosp_detector._VIT_ATTN_CLASSNAMES +# and _LLM_ATTN_CLASSNAMES so that auto_wrap_model_for_sp detects them. +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + """Mock ViT attention for InternVL (registered in _VIT_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class InternLM2Attention(nn.Module): + """Mock LLM attention for InternVL (registered in _LLM_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2VLVisionAttention(nn.Module): + """Mock ViT attention for Qwen2-VL (registered in _VIT_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2Attention(nn.Module): + """Mock LLM attention for Qwen2-VL (registered in _LLM_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# Model skeleton helpers +# --------------------------------------------------------------------------- + + +class _AttnLayer(nn.Module): + """Generic transformer block that holds an attention submodule. + + auto_wrap_model_for_sp scans named_modules() and replaces ``self.attn`` + when its class name is in the detector's registry. + """ + + def __init__(self, attn: nn.Module) -> None: + super().__init__() + self.attn = attn + + def forward(self, x, **kwargs): + return self.attn(x, **kwargs) + + +class _MinimalInternVLModel(nn.Module): + """Minimal InternVL-like skeleton for integration testing. + + Module paths recognised by autosp_detector: + - ``vision_encoder.0.attn`` -> InternVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``language_model.0.attn`` -> InternLM2Attention (_LLM_ATTN_CLASSNAMES) + - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS + + ``forward`` exercises only the ViT + fusion path; ``language_model`` is + present to verify that auto_wrap does NOT wrap HF-style LLM attention. + """ + + def __init__(self) -> None: + super().__init__() + self.vision_encoder = nn.Sequential(_AttnLayer(InternVisionAttention())) + self.mm_projector = nn.Identity() + self.language_model = nn.Sequential(_AttnLayer(InternLM2Attention())) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.vision_encoder(local_patches) + return self.fusion(local_visual, text_embeds, input_ids) + + +class _MinimalQwen2VLModel(nn.Module): + """Minimal Qwen2-VL-like skeleton for integration testing. + + Module paths recognised by autosp_detector: + - ``visual.0.attn`` -> Qwen2VLVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``model.0.attn`` -> Qwen2Attention (_LLM_ATTN_CLASSNAMES) + - ``multi_modal_projector`` -> keyword in _VISION_PROJ_KEYWORDS + """ + + def __init__(self) -> None: + super().__init__() + self.visual = nn.Sequential(_AttnLayer(Qwen2VLVisionAttention())) + self.multi_modal_projector = nn.Identity() + self.model = nn.Sequential(_AttnLayer(Qwen2Attention())) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.visual(local_patches) + return self.fusion(local_visual, text_embeds, input_ids) + + +# --------------------------------------------------------------------------- +# InternVL integration tests +# --------------------------------------------------------------------------- + + +class TestInternVLIntegration(DistributedTest): + """Integration tests for the InternVL multimodal SP pipeline.""" + + world_size = 2 + + def test_auto_wrap_detects_and_wraps_modules(self): + """auto_wrap_model_for_sp must replace InternVisionAttention with + UlyssesSPViTAttention (has_cls_token=False) and must NOT wrap + InternLM2Attention (HF-style, incompatible with DistributedAttention).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + model = _MinimalInternVLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + + assert isinstance( + model.vision_encoder[0].attn, + UlyssesSPViTAttention), ("Expected vision_encoder[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert not model.vision_encoder[0].attn.has_cls_token, ( + "InternVisionAttention has no CLS token; has_cls_token must be False") + assert isinstance(model.language_model[0].attn, + InternLM2Attention), ("HF-style LLM attention must NOT be wrapped by auto_wrap") + + def test_full_pipeline_visual_to_fused(self): + """SP-wrapped ViT -> InternVLFusionAdapter must produce fused output + numerically equivalent to the single-device splice reference.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 4, 10, 8 + num_ctx = local_v * self.world_size + + torch.manual_seed(20) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID + + local_patches = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + + model = _MinimalInternVLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) + + local_out = model(local_patches, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference: splice without SP scatter. + ref_adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL full pipeline output differs from reference: " + f"max_diff={(full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# Qwen2-VL integration tests +# --------------------------------------------------------------------------- + + +class TestQwen2VLIntegration(DistributedTest): + """Integration tests for the Qwen2-VL multimodal SP pipeline.""" + + world_size = 2 + + def test_auto_wrap_detects_and_wraps_modules(self): + """auto_wrap_model_for_sp must replace Qwen2VLVisionAttention with + UlyssesSPViTAttention (has_cls_token=False) and must NOT wrap + Qwen2Attention (HF-style, incompatible with DistributedAttention).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + + assert isinstance( + model.visual[0].attn, + UlyssesSPViTAttention), ("Expected visual[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert not model.visual[0].attn.has_cls_token, ( + "Qwen2VLVisionAttention has no CLS token; has_cls_token must be False") + assert isinstance(model.model[0].attn, + Qwen2Attention), ("HF-style LLM attention must NOT be wrapped by auto_wrap") + + def test_full_pipeline_visual_to_fused(self): + """SP-wrapped ViT -> Qwen2VLFusionAdapter must produce fused output + numerically equivalent to the single-device splice reference.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 10, 8 + num_inner = local_v * self.world_size + + torch.manual_seed(21) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 1] = _QWEN2VL_START_ID + ids[:, 2 + num_inner] = _QWEN2VL_END_ID + + local_patches = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + + model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + + local_out = model(local_patches, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL full pipeline output differs from reference: " + f"max_diff={(full_sp_out - ref_fused).abs().max().item():.2e}") diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py new file mode 100644 index 000000000000..7c1371864073 --- /dev/null +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import torch.nn.functional as F +import deepspeed.comm as dist +from deepspeed import initialize +from transformers import AutoModel +from unit.common import DistributedTest +from deepspeed.sequence.layer import _SeqAllToAll +from deepspeed.sequence.fpdt_layer import _FPDTGPUOffloadingAttentionImpl_, FPDT_InputConstruct +from unit.util import skip_on_arch +from unit.simple_model import * +from deepspeed.utils import groups +from deepspeed.module_inject.tp_shard import get_shard_size_list +#Use mesh device to create data and sequence parallel group + + +class TestUlyssesUtils(DistributedTest): + world_size = 4 + + def test_mesh_device_creation(self) -> None: + skip_on_arch(min_arch=8) + model = AutoModel.from_pretrained('bert-base-uncased') + sp_size = 2 + dp_size = 2 + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": 8, + "data_parallel_size": dp_size, + "sequence_parallel_size": sp_size + }, + ) + assert ds_engine.seq_parallel_group is not None + assert ds_engine.data_parallel_group is not None + assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size + assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size + assert dist.get_world_size() == sp_size * dp_size + + +#Sweep b,s,h,d to test all2all consistency +@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension +@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension +@pytest.mark.parametrize("num_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [16, 32]) +class TestUlyssesAll2All(DistributedTest): + world_size = 4 + + def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None: + skip_on_arch(min_arch=8) + model = AutoModel.from_pretrained('bert-base-uncased') + ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2)) + #4D tensor : b,s,h,d or s,b,h,d + input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device) + scatter_idx = 2 + batch_dim_idx = 0 + outputs = [] + seq_dims = [0] #seq first API + #TODO: Add support for batch first (that seq_dims=[0,1]) after PR for bs>1 issue with batch first is fixed + ## See discussion in : https://github.com/deepspeedai/DeepSpeed/issues/5808 + for seq_dim in seq_dims: + gather_idx = seq_dim + #first all2all: sequence parallel to head parallel + s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx, + batch_dim_idx) + + #No op + # second all2all: head parallel to sequence parallel + h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx, + batch_dim_idx) + print( + f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}' + ) + outputs.append(h2s_tensor) + + # Check outputs are the same as input + for i in range(1, len(outputs)): + assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}" + + +@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension +@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension +@pytest.mark.parametrize("num_heads", [3, 7]) +@pytest.mark.parametrize("head_dim", [16]) +class TestUlyssesAll2All_odd(DistributedTest): + world_size = 4 + + def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None: + + data_parallel_size = 2 + seq_parallel_size = self.world_size // data_parallel_size + skip_on_arch(min_arch=8) + + def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): + d0 += offset_d0 + d1 += offset_d1 + h += offset_h + return d0 * 10 + h + d1 * 0.1 + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + ds_engine, _, _, _ = initialize(model=model, + config_params={"train_batch_size": 8}, + mesh_param=(data_parallel_size, seq_parallel_size)) + + scatter_idx = 2 + outputs = [] + inputs = [] + batch_dims = [0, 1] + seq_dims = [1, 0] + + for idx, seq_dim in enumerate(seq_dims): + gather_idx = seq_dim + batch_dim_idx = batch_dims[idx] + + #4D tensor : b,s,h,d or s,b,h,d + #create a hash tensor from pos_id, head_id, and batch_id + d0_indices = torch.arange(d0).reshape(-1, 1, 1, 1) + d1_indices = torch.arange(d1).reshape(1, -1, 1, 1) + h_indices = torch.arange(num_heads).reshape(1, 1, -1, 1) + input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device) + if batch_dim_idx == 1: #seq_len_dim : 0(d0) + input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, + d0 * groups._get_sequence_parallel_rank(), 0) + elif batch_dim_idx == 0: #seq_len_dim : 1(d1) + input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, + d1 * groups._get_sequence_parallel_rank()) + inputs.append(input_tensor) + + ### first all2all: sequence parallel to head parallel + s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx, + batch_dim_idx) + + # s2h_tensor check for the first all2all: compare with the expected ground truth + d0_indices = torch.arange(s2h_tensor.shape[0]).reshape(-1, 1, 1, 1) + d1_indices = torch.arange(s2h_tensor.shape[1]).reshape(1, -1, 1, 1) + h_indices = torch.arange(s2h_tensor.shape[2]).reshape(1, 1, -1, 1) + shard_list = get_shard_size_list(num_heads, groups._get_sequence_parallel_world_size()) + head_offset = sum(shard_list[:groups._get_sequence_parallel_rank()]) + s2h_truth = torch.zeros_like(s2h_tensor) + s2h_truth[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, 0, head_offset) + + assert torch.allclose(s2h_truth, + s2h_tensor), f"s2h_tensor differs from the expected for sequence dim: {seq_dim}" + #No op + ### second all2all: head parallel to sequence parallel + h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx, + batch_dim_idx) + print( + f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}' + ) + outputs.append(h2s_tensor) + + # Check outputs for the second all2all + for i in range(0, len(outputs)): + assert torch.allclose(inputs[i], + outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}" + + +@pytest.mark.parametrize("d0", [4, 1]) #batch dimension +@pytest.mark.parametrize("d1", [2048, 8192]) #sequence dimension +@pytest.mark.parametrize("chunk_size", [128, 256]) #size of chunk +@pytest.mark.parametrize("num_heads", [8, 4]) +@pytest.mark.parametrize("head_dim", [32]) +class TestFPDTAttention(DistributedTest): + + def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, + num_heads: int) -> None: + skip_on_arch(min_arch=8) + world_size = 2 + + try: + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + except ImportError: + _flash_attn_forward = None + _flash_attn_backward = None + + if _flash_attn_forward is None or _flash_attn_backward is None: + pytest.skip("Flash Attention is not available.") + + model = AutoModel.from_pretrained('bert-base-uncased') + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": 8, + "data_parallel_size": 1, + "sequence_parallel_size": world_size + }, + ) + #3D tensor : l, b, d + dim = head_dim * num_heads + + seed = 42 + torch.manual_seed(seed) + get_accelerator().manual_seed_all(seed) + + input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) # l, b, d + spg = ds_engine.seq_parallel_group + + dist.broadcast(input_tensor, src=0, group=spg) + + class args: + + def __init__(self): + self.ds_sequence_parallel_fpdt_chunk_size = chunk_size + + fpdt_input_tensor = FPDT_InputConstruct(input_tensor.permute(1, 0, 2), None, None, None, None, args(), + world_size, dist.get_rank()).generate()[0].permute(1, 0, 2) + + if dist.get_rank() == 0: + qkv_linear_weight = torch.nn.Parameter( + torch.empty(dim + 2 * dim, dim, device=dist.get_rank(), dtype=torch.half)) + torch.nn.init.normal_(qkv_linear_weight, mean=0.0, std=0.02) + + qkv_linear_bias = torch.nn.Parameter(torch.empty(dim + 2 * dim, device=dist.get_rank(), dtype=torch.half)) + torch.nn.init.normal_(qkv_linear_bias, mean=0.0, std=0.02) + else: + qkv_linear_weight = torch.nn.Parameter( + torch.empty(dim + 2 * dim, dim, device=dist.get_rank(), dtype=torch.half)) + qkv_linear_bias = torch.nn.Parameter(torch.empty(dim + 2 * dim, device=dist.get_rank(), dtype=torch.half)) + + dist.broadcast(qkv_linear_weight, src=0, group=spg) + dist.broadcast(qkv_linear_bias, src=0, group=spg) + + num_chunks_attn = fpdt_input_tensor.shape[0] * dist.get_world_size(spg) // chunk_size + fpdt_output = _FPDTGPUOffloadingAttentionImpl_.apply(fpdt_input_tensor, None, None, None, spg, 2, 0, dim, dim, + head_dim, dim, qkv_linear_weight, qkv_linear_bias, 0, + num_chunks_attn, True) + + # baseline + qkv = torch.matmul(input_tensor, qkv_linear_weight.t()) + qkv_linear_bias + q = qkv[:, :, :dim].contiguous().reshape(qkv.shape[0], qkv.shape[1], -1, head_dim).permute(1, 2, 0, + 3).contiguous() + k = qkv[:, :, dim:dim * 2].contiguous().reshape(qkv.shape[0], qkv.shape[1], -1, + head_dim).permute(1, 2, 0, 3).contiguous() + v = qkv[:, :, dim * 2:dim * 3].contiguous().reshape(qkv.shape[0], qkv.shape[1], -1, + head_dim).permute(1, 2, 0, + 3).contiguous() # b, nhead, l, d + + scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dim, dtype=torch.half)) + + causal_mask = torch.triu(torch.ones(d1, d1, device=ds_engine.device), diagonal=1).bool() + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(causal_mask, float('-inf')) + attn_weights = F.softmax(scores, dim=-1) + output = torch.matmul(attn_weights, v).permute(0, 2, 1, 3) + + baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, + dist.get_rank()).generate()[0] # b, l, n, d + + assert torch.allclose( + fpdt_output, baseline_output_shuffled, rtol=0.01, atol=0.1 + ), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" + + +@pytest.mark.parametrize("sp_size", [2]) +class TestUlyssesLossBackward(DistributedTest): + world_size = 4 + + def test_sp_loss_backward_stability(self, sp_size: int) -> None: + """ + Regression test for Issue #7672. + Verifies that using all_reduce for loss aggregation is stable + when sequence_parallel_size < world_size, preventing IndexError. + """ + skip_on_arch(min_arch=8) + + # Setup + dp_size = self.world_size // sp_size + model = SimpleModel(4) + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": 8, + "data_parallel_size": dp_size, + "sequence_parallel_size": sp_size + }, + ) + + sp_group = ds_engine.seq_parallel_group + + # Simulate Loss on each rank + rank = dist.get_rank() + local_loss = torch.tensor(float(rank + 1), device=ds_engine.device, requires_grad=True) + local_weight = torch.tensor(1.0, device=ds_engine.device) + + # Numerator: Weighted Loss summation + weighted_loss = local_loss * local_weight + dist.all_reduce(weighted_loss, op=dist.ReduceOp.SUM, group=sp_group) + + # B. Denominator: Sum of total weights + total_weight = local_weight.clone() + dist.all_reduce(total_weight, op=dist.ReduceOp.SUM, group=sp_group) + + # C. Calculate the final loss + dist_loss = weighted_loss / total_weight + + # Backward Pass verification + try: + dist_loss.backward() + except IndexError as e: + pytest.fail(f"Backward crashed with IndexError: {e}") + + # Verify Gradients + # Loss = (L1*1 + L2*1) / 2 = 0.5*L1 + 0.5*L2 + expected_grad = 0.5 + assert torch.allclose(local_loss.grad, torch.tensor(expected_grad, device=ds_engine.device)), \ + f"Gradient mismatch! Expected {expected_grad}, got {local_loss.grad}" diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py index ac68e8d1c21a..a5538a8c6e68 100644 --- a/tests/unit/simple_model.py +++ b/tests/unit/simple_model.py @@ -7,12 +7,14 @@ import json import argparse import torch +from collections import OrderedDict from deepspeed.pipe import PipelineModule, LayerSpec from deepspeed.moe.layer import MoE from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist +from .common import preferred_dtype class SimpleModel(torch.nn.Module): @@ -34,6 +36,36 @@ def forward(self, x, y): return self.cross_entropy_loss(x, y) +class SimpleFrozenModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False): + super(SimpleFrozenModel, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(2)]) + if empty_grad: + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + self.empty_grad = empty_grad + # Freeze first layer + self.linears[0].weight.requires_grad = False + self.linears[0].bias.requires_grad = False + + def custom_state_dict(self, *args, **kwargs): + state_dict = super(SimpleFrozenModel, self).state_dict(*args, **kwargs) + custom = OrderedDict() + for k, v in state_dict.items(): + if 'linears.0.weight' not in k: + custom[k] = v + return custom + + def forward(self, x, y): + if len(self.linears) == 1: + x = self.linears[0](x) + else: + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return self.cross_entropy_loss(x, y) + + class Curriculum_SimpleModel(SimpleModel): def __init__(self, hidden_dim, empty_grad=False): @@ -47,29 +79,37 @@ def forward(self, x, y, **kwargs): class SimpleMoEModel(torch.nn.Module): - def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False): + def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False, use_rts=True): super(SimpleMoEModel, self).__init__() - self.linear = torch.nn.Linear(hidden_dim, hidden_dim) - expert = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim)) # using two MoE layers to check implications of sharing a single storage - self.linear2 = MoE(hidden_size=hidden_dim, - expert=expert, - ep_size=ep_size, - use_residual=use_residual, - num_experts=num_experts, - k=1) - self.linear3 = MoE(hidden_size=hidden_dim, - expert=expert, - ep_size=ep_size, - use_residual=use_residual, - num_experts=num_experts, - k=1) + self.moe_1 = MoE(hidden_size=hidden_dim, + expert=expert, + ep_size=ep_size, + use_residual=use_residual, + num_experts=num_experts, + k=1, + use_rts=use_rts) + # interleaving MoE modules with dense to create an opportunity + # for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.moe_2 = MoE(hidden_size=hidden_dim, + expert=expert, + ep_size=ep_size, + use_residual=use_residual, + num_experts=num_experts, + k=1, + use_rts=use_rts) + self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) self.cross_entropy_loss = torch.nn.CrossEntropyLoss() def forward(self, x, y): - hidden_dim = self.linear(x) - output, _, _ = self.linear2(hidden_dim) - output, _, _ = self.linear3(output) + hidden_dim = self.linear1(x) + output, _, _ = self.moe_1(hidden_dim) + output = self.linear2(output) + output, _, _ = self.moe_2(output) + output = self.linear3(output) hidden_dim = hidden_dim + output sentence_embed = hidden_dim.mean(1) return self.cross_entropy_loss(sentence_embed, y) @@ -225,21 +265,21 @@ def forward(self, x, y, **kwargs): return hidden_dim -def random_dataset(total_samples, hidden_dim, device, dtype=torch.half): +def random_dataset(total_samples, hidden_dim, device, dtype=preferred_dtype()): train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) train_dataset = torch.utils.data.TensorDataset(train_data, train_label) return train_dataset -def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half): +def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype()): batch_size = model.train_micro_batch_size_per_gpu() train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) return train_loader -def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=torch.half): +def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype()): batch_size = model.train_micro_batch_size_per_gpu() train_data = torch.randn(total_samples, seq_len, hidden_dim, device=device, dtype=dtype) train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) diff --git a/tests/unit/ulysses_alst/test_tiled_compute.py b/tests/unit/ulysses_alst/test_tiled_compute.py new file mode 100644 index 000000000000..e0146e9ec4d9 --- /dev/null +++ b/tests/unit/ulysses_alst/test_tiled_compute.py @@ -0,0 +1,357 @@ +# Copyright (c) The DeepSpeed Contributors +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Arctic Long Sequence Training (ALST) Tiled compute component tests +""" + +from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP, sequence_tiled_compute, TiledFusedLogitsLoss +from deepspeed.utils import safe_get_full_grad +from torch.nn import Linear, Module +from unit.common import DistributedTest, preferred_dtype +from unit.util import torch_assert_equal, torch_assert_close, CaptureStderr +import deepspeed +import pytest +import torch + + +def get_grad(param, zero_stage): + return safe_get_full_grad(param) + # z1 now has contiguous_gradients enabled by default so `param.grad is None` even under z1 + # if zero_stage == 1: + # return param.grad + # else: + # return safe_get_full_grad(param) + + +class SimpleMLP(Module): + + def __init__(self, hidden_dim): + super().__init__() + self.up_proj = Linear(hidden_dim, hidden_dim * 2, bias=False) + self.down_proj = Linear(hidden_dim * 2, hidden_dim, bias=False) + self.act = torch.nn.ReLU() + + def forward(self, x): + return self.down_proj(self.act(self.up_proj(x))) + + +# save the original implementation to pass through to the tiled computation wrapper +mlp_forward_orig = SimpleMLP.forward + + +class MyModel(Module): + + def __init__(self, hidden_dim, vocab_size): + super().__init__() + self.vocab_size = vocab_size + # Critical - need to use a stack of at least 2 mlps to validate that the backward of the last mlp sends the correct gradients to the previous mlp in the stack + self.mlp1 = SimpleMLP(hidden_dim) + self.mlp2 = SimpleMLP(hidden_dim) + self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + x = self.mlp1(x) + x = self.mlp2(x) + logits = self.lm_head(x) + return self.cross_entropy_loss(logits.view(-1, self.vocab_size), y.view(-1)) + + +def mlp_forward_tiled_mlp(self, x): + # this tests TiledMLP + compute_params = [self.down_proj.weight, self.up_proj.weight] + num_shards = 4 + + return TiledMLP.apply( + mlp_forward_orig, + self, + x, + num_shards, + compute_params, + ) + + +def mlp_forward_sequence_tiled_compute(self, x): + # this tests: sequence_tiled_compute + SequenceTiledCompute - same as TiledMLP but a-non-MLP + # specific generic implementation of tiled compute + + kwargs_to_shard = dict(x=x) + kwargs_to_pass = dict(self=self) + grad_requiring_tensor_key = "x" + compute_params = [self.down_proj.weight, self.up_proj.weight] + seqlen = x.shape[1] + num_shards = 4 + + return sequence_tiled_compute( + mlp_forward_orig, + seqlen, + num_shards, + kwargs_to_shard, + kwargs_to_pass, + grad_requiring_tensor_key, + compute_params, + output_unshard_dimension=1, # x + output_reduction=None, + ) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("zero_stage", [2, 3]) +class TestTiledCompute(DistributedTest): + world_size = 1 + + def test_tiled_mlp(self, zero_stage, batch_size): + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + dtype = preferred_dtype() + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + + # for debug + # torch.set_printoptions(precision=8, sci_mode=True) + + vocab_size = 10 + seed = 42 + hidden_dim = 128 + bs = batch_size + seqlen = 125 # use a non 2**n length to test varlen shards (last short) + torch.manual_seed(seed) + x = torch.rand((bs, seqlen, hidden_dim), dtype=dtype, requires_grad=True) + y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(vocab_size) + + # A. Baseline: model with normal MLP + torch.manual_seed(seed) + model_a = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype) + model_a, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_a, + model_parameters=model_a.parameters()) + + x = x.to(model_a.device) + y = y.to(model_a.device) + + x_a = x.clone().detach().requires_grad_(True) + y_a = y.clone().detach() + + loss_a = model_a(x_a, y_a) + model_a.backward(loss_a) + param_grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage) + param_grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage) + x_grad_a = x_a.grad + assert param_grad_a1 is not None + assert param_grad_a2 is not None + assert x_grad_a is not None + + # B. model with tiled MLP using TiledMLP + torch.manual_seed(seed) + SimpleMLP.forward = mlp_forward_tiled_mlp + model_b = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype) + model_b, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_b, + model_parameters=model_b.parameters()) + + x_b = x.clone().detach().requires_grad_(True) + y_b = y.clone().detach() + loss_b = model_b(x_b, y_b) + + with CaptureStderr() as cs: + model_b.backward(loss_b) + # see the explanation inside TiledMLP.backward + assert "grad and param do not obey the gradient layout contract" not in cs.err, f"stride issue: {cs.err}" + + param_grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage) + param_grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage) + x_grad_b = x_b.grad + assert param_grad_b1 is not None + assert param_grad_b2 is not None + assert x_grad_b is not None + + # print(f"{loss_a=}") + # print(f"{loss_b=}") + # print(f"{param_grad_a1=}") + # print(f"{param_grad_b1=}") + # print(f"{param_grad_a2=}") + # print(f"{param_grad_b2=}") + torch_assert_equal(loss_a, loss_b) + + # Gradient will not be exactly the same, especially under half-precision. And bf16 is + # particularly lossy so need to lower tolerance a bit more than the default. Switch to + # dtype torch.float or even torch.double to see that the diff is tiny - so the math is + # correct, but accumulation error adds up. Alternatively making hidden_dim bigger makes the + # divergence much smaller as well. + torch_assert_close(param_grad_a1, param_grad_b1) #, rtol=1e-03, atol=1e-04) + torch_assert_close(param_grad_a2, param_grad_b2) #, rtol=1e-03, atol=1e-04) + torch_assert_close(x_grad_a, x_grad_b) + + # C. model with tiled MLP using the generic version of the same via sequence_tiled_compute + SequenceTiledCompute + torch.manual_seed(seed) + SimpleMLP.forward = mlp_forward_sequence_tiled_compute + model_c = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype) + model_c, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_c, + model_parameters=model_c.parameters()) + + x_c = x.clone().detach().requires_grad_(True) + y_c = y.clone().detach() + loss_c = model_c(x_c, y_c) + with CaptureStderr() as cs: + model_c.backward(loss_c) + + assert "grad and param do not obey the gradient layout contract" not in cs.err, f"stride issue: {cs.err}" + + param_grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage) + param_grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage) + x_grad_c = x_c.grad + assert param_grad_c1 is not None + assert param_grad_c2 is not None + assert x_grad_c is not None + + # print(f"{loss_a=}") + # print(f"{loss_c=}") + # print(f"{param_grad_a1=}") + # print(f"{param_grad_c1=}") + # see notes for B + torch_assert_equal(loss_a, loss_c) + torch_assert_close(param_grad_a1, param_grad_c1) #, rtol=1e-03, atol=1e-04) + torch_assert_close(param_grad_a2, param_grad_c2) #, rtol=1e-03, atol=1e-04) + torch_assert_close(x_grad_a, x_grad_c) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("zero_stage", [2, 3]) +class TestTiledFusedLogitsLoss(DistributedTest): + world_size = 1 + + def test_tiled_fused_logits_loss(self, zero_stage, batch_size): + + def tiled_forward(self, x, y): + x = self.mlp1(x) + x = self.mlp2(x) + + def loss_fn(self, x, y): + logits = self.lm_head(x) + return self.cross_entropy_loss(logits.view(-1, self.vocab_size), y.view(-1)) + + mask = None + shards = 2 + compute_params = [self.lm_head.weight] + output_reduction = "mean" + loss = TiledFusedLogitsLoss.apply( + loss_fn, + self, + x, + y, + mask, + shards, + compute_params, + output_reduction, + ) + return loss + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + dtype = preferred_dtype() + #dtype = torch.float + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + + # for debug + # torch.set_printoptions(precision=8, sci_mode=True) + + vocab_size = 100 + seed = 42 + hidden_dim = 64 + bs = batch_size + seqlen = 425 # use a non 2**n length to test varlen shards (last short) + torch.manual_seed(seed) + x = torch.rand((bs, seqlen, hidden_dim), dtype=dtype, requires_grad=True) + y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(vocab_size) + + # A. Baseline: model with normal loss + torch.manual_seed(seed) + model_a = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype) + model_a, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_a, + model_parameters=model_a.parameters()) + + x = x.to(model_a.device) + y = y.to(model_a.device) + + x_a = x.clone().detach().requires_grad_(True) + y_a = y.clone().detach() + + loss_a = model_a(x_a, y_a) + model_a.backward(loss_a) + param_grad_a = get_grad(model_a.module.lm_head.weight, zero_stage) + x_grad_a = x_a.grad + assert param_grad_a is not None + assert x_grad_a is not None + + # B. model with fused tiled logits loss + torch.manual_seed(seed) + MyModel.forward_orig = MyModel.forward + MyModel.forward = tiled_forward + model_b = MyModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(dtype) + model_b, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_b, + model_parameters=model_b.parameters()) + + x_b = x.clone().detach().requires_grad_(True) + y_b = y.clone().detach() + loss_b = model_b(x_b, y_b) + + with CaptureStderr() as cs: + model_b.backward(loss_b) + # see the explanation inside TiledMLP.backward + assert "grad and param do not obey the gradient layout contract" not in cs.err, f"stride issue: {cs.err}" + + param_grad_b = get_grad(model_b.module.lm_head.weight, zero_stage) + x_grad_b = x_b.grad + assert param_grad_b is not None + assert x_grad_b is not None + + # print(f"{loss_a=}") + # print(f"{loss_b=}") + # print(f"{x_grad_a=}") + # print(f"{x_grad_b=}") + # print(f"{param_grad_a=}") + # print(f"{param_grad_b=}") + # usually this is an exact match, but on cpu CI this fails. + torch_assert_close(loss_a, loss_b) + + # Gradient will not be exactly the same, especially under half-precision. And bf16 is + # particularly lossy so need to lower tolerance a bit more than the default. Switch to + # dtype torch.float or even torch.double to see that the diff is tiny - so the math is + # correct, but accumulation error adds up. Alternatively making hidden_dim bigger makes the + # divergence much smaller as well. + torch_assert_close(x_grad_a, x_grad_b) + torch_assert_close(param_grad_a, param_grad_b) #, rtol=1e-03, atol=1e-04) + + # restore + MyModel.forward = MyModel.forward_orig diff --git a/tests/unit/ulysses_alst/test_ulysses_sp_hf.py b/tests/unit/ulysses_alst/test_ulysses_sp_hf.py new file mode 100644 index 000000000000..550233d9239e --- /dev/null +++ b/tests/unit/ulysses_alst/test_ulysses_sp_hf.py @@ -0,0 +1,565 @@ +# Copyright (c) The DeepSpeed Contributors +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +UlyssesPlus: UlyssesSPHF tests +""" + +from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter +from deepspeed.runtime.utils import move_to_device +from deepspeed.utils import groups +from deepspeed.utils import safe_get_full_grad +from torch import tensor +from transformers import AutoModelForCausalLM +from unit.common import DistributedTest, preferred_dtype +from unit.util import torch_assert_equal, torch_assert_close, torch_assert_dicts_of_tensors_equal +import deepspeed +import deepspeed.comm as dist +import pytest +import torch + + +def get_grad(param, zero_stage): + return safe_get_full_grad(param) + # z1 now has contiguous_gradients enabled by default so `param.grad is None` even under z1 + # if zero_stage == 1: + # return param.grad + # else: + # return safe_get_full_grad(param) + + +@pytest.mark.parametrize("zero_stage", [2, 3]) +class TestUlyssesSPHF(DistributedTest): + world_size = 2 + + def test_ulysses_sp_hf(self, zero_stage): + core_attn_implementation = "sdpa" + model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + #model_name_or_path = 'Felladrin/Llama-160M-Chat-v1' + #model_name_or_path = 'Felladrin/Llama-160M-Chat-v1' + seq_length = 64 + sequence_parallel_size = self.world_size + micro_batch_size = 1 + + rank = dist.get_rank() + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "sequence_parallel_size": sequence_parallel_size, + } + + dtype = preferred_dtype() + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + + # Part 1. Baseline: Setup + def collate_fn(batch): + input_ids, position_ids = batch[0] + #print(f"{batch=}") + return dict(input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), + labels=input_ids.unsqueeze(0)) + + input_ids = tensor([[1, 10, 10, 10, 2, 2], [1, 20, 20, 20, 2, 2]], ) + position_ids = tensor([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]) + ds = torch.utils.data.TensorDataset(input_ids, position_ids) + + # 1. Baseline: DataLoader calibration + dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + batch_a = next(iter(dl_a)) + #print(f"{rank=} {batch_a=}") + expected_batch_a = { + 'input_ids': tensor([[1, 10, 10, 10, 2, 2]]), + 'position_ids': tensor([[0, 1, 2, 3, 4, 5]]), + 'labels': tensor([[1, 10, 10, 10, 2, 2]]) + } + torch_assert_dicts_of_tensors_equal(batch_a, expected_batch_a) + + # 2. Baseline: Attention + model_a = AutoModelForCausalLM.from_pretrained(model_name_or_path) + model_a, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_a, + model_parameters=model_a.parameters(), + mpu=None) + batch_a = move_to_device(batch_a, model_a.device) + loss_a = model_a(**batch_a).loss + model_a.backward(loss_a) + #print(f"{loss_a=}") + + grad_a = get_grad(model_a.module.model.layers[0].self_attn.q_proj.weight, zero_stage) + assert grad_a is not None + #print(f"{grad_a}") + + # Part 2. Ulysses: Setup + mpu = UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=model_name_or_path, + core_attn_implementation=core_attn_implementation, + sequence_parallel_size=sequence_parallel_size, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + seq_length_is_variable=True, + ) + + model_b = AutoModelForCausalLM.from_pretrained(model_name_or_path, + attn_implementation=core_attn_implementation) + model_b, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_b, + model_parameters=model_b.parameters(), + mpu=mpu) + + # 3. Ulysses: UlyssesSPDataLoaderAdapter test + sp_group = groups._get_sequence_parallel_group() + sp_world_size = groups._get_sequence_parallel_world_size() + sp_rank = groups._get_sequence_parallel_rank() + dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + dl_b = UlyssesSPDataLoaderAdapter( + dl_a, + sp_rank=sp_rank, + sp_group=sp_group, + sp_world_size=sp_world_size, + device=model_b.device, + ) + batch_b = next(iter(dl_b)) + + expected_batch_b = [ + { + 'input_ids': tensor([[1, 10, 10]]), + 'position_ids': tensor([[0, 1, 2]]), + 'shift_labels': tensor([[10, 10, 10]]), + }, + { + 'input_ids': tensor([[10, 2, 2]]), + 'position_ids': tensor([[3, 4, 5]]), + 'shift_labels': tensor([[2, 2, -100]]), + }, + ] + + # here we expect each sample to be sharded in half, rank0 getting the first half and rank1 the other half + #print(f"{sp_rank=} {batch_b=}") + torch_assert_dicts_of_tensors_equal(batch_b, expected_batch_b[sp_rank]) + + # 4. UlyssesSPAttentionHF test + batch_b = move_to_device(batch_b, model_b.device) + outputs = model_b(**batch_b) + # HF doesn't calculate loss with shift_labels yet and requires us to do it manually (liger does that) + shift_labels = batch_b["shift_labels"] + loss_b = model_b.module.loss_function( + logits=outputs.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=model_b.module.config.vocab_size, + ) + # print(f"{sp_rank=} {loss_b=}") + + # differentiable weighted per-shard-loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss_b, group=sp_group) + good_tokens = sum((shift_labels != -100).view(-1)) + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size)) + total_good_tokens = sum(good_tokens_per_rank) + loss_b = total_loss / total_good_tokens + # print(f"{sp_rank=} gathered {loss_b=}") + model_b.backward(loss_b) + + grad_b = get_grad(model_b.module.model.layers[0].self_attn.q_proj.weight, zero_stage) + assert grad_b is not None + #print(f"{grad_b}") + + # compare loss of A (non-Ulysses Attention) and B (Ulyssses Attention) + torch_assert_equal(loss_a, loss_b) + + # - we are feeding the exact same sample to each rank of A + # - for B we feed half the sample to each rank, but in total it's the same sample as each rank of A sees + # thus we expect very similar grads (but not exact) + if zero_stage in [1, 2]: + # possibly some issue with z1/z2 as it requires higher tolerance than z3? + torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03) + else: + torch_assert_close(grad_a, grad_b) + + +class TestUlyssesSPHFPEFT(DistributedTest): + world_size = 2 + + def test_ulysses_sp_hf_with_peft_model(self): + """Test that UlyssesSPAttentionHF.register_with_transformers works with PEFT models. + + PEFT models don't inherit from transformers.PreTrainedModel but have a config attribute. + This test verifies the duck-typing check for the config attribute works correctly. + """ + model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + seq_length = 64 + sequence_parallel_size = self.world_size + micro_batch_size = 1 + + # Create a mock PEFT model object that has config but doesn't inherit from PreTrainedModel + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(model_name_or_path) + + class MockPEFTModel: + """Mock PEFT model that simulates PeftModel behavior""" + + def __init__(self, config): + self.config = config + + mock_peft_model = MockPEFTModel(hf_config) + + # Test that register_with_transformers works with PEFT-like model object + # This should not crash and should use the config attribute via duck-typing + mpu = UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=mock_peft_model, + core_attn_implementation="sdpa", + sequence_parallel_size=sequence_parallel_size, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + seq_length_is_variable=True, + ) + + # Verify mpu is created successfully + assert mpu is not None + + # Verify that the sequence parallel groups are initialized + sp_group = mpu.get_sequence_parallel_group() + assert sp_group is not None + sp_world_size = mpu.get_sequence_parallel_world_size() + assert sp_world_size == sequence_parallel_size + + +class TestUlyssesSPHFDisableInEval(DistributedTest): + world_size = 2 + + def test_disable_in_eval(self): + """Test that disable_in_eval parameter controls SP behavior during evaluation. + + When disable_in_eval=True, SP operations should be bypassed during eval mode, + allowing the user to pass full (non-sharded) sequences directly. + This should produce the same output as a model without SP registered. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + seq_length = 64 + sequence_parallel_size = self.world_size + micro_batch_size = 1 + + dtype = preferred_dtype() + rank = dist.get_rank() + + # Full sequence input (not sharded) - this is what users would pass during eval + # when they want to bypass SP and process sequences independently per rank + input_ids = tensor([[1, 10, 10, 10, 2, 2]], device=f"cuda:{rank}") + position_ids = tensor([[0, 1, 2, 3, 4, 5]], device=f"cuda:{rank}") + + # 1. Baseline: model without SP, processing full sequence + model_baseline = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype) + model_baseline = model_baseline.to(f"cuda:{rank}") + model_baseline.eval() + + # Save original attention function for comparison + original_sdpa = ALL_ATTENTION_FUNCTIONS["sdpa"] + + with torch.no_grad(): + outputs_baseline = model_baseline(input_ids=input_ids, position_ids=position_ids) + logits_baseline = outputs_baseline.logits.clone() + + del model_baseline + + # 2. Model with SP registered but disable_in_eval=True + # In eval mode, SP is bypassed so full sequence can be passed directly + mpu = UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=model_name_or_path, + core_attn_implementation="sdpa", + sequence_parallel_size=sequence_parallel_size, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + seq_length_is_variable=True, + disable_in_eval=True, + ) + + # Verify that register_with_transformers actually registered the wrapper + assert mpu is not None, "mpu should not be None when sequence_parallel_size > 1" + assert ALL_ATTENTION_FUNCTIONS["sdpa"] is not original_sdpa, \ + "register_with_transformers should have replaced the attention function" + + model_sp = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype) + model_sp = model_sp.to(f"cuda:{rank}") + model_sp.eval() + + with torch.no_grad(): + outputs_sp = model_sp(input_ids=input_ids, position_ids=position_ids) + logits_sp = outputs_sp.logits.clone() + + # Verify: with disable_in_eval=True, full sequence input should produce + # the same output as baseline (SP is bypassed) + torch_assert_equal(logits_baseline, logits_sp) + + +class TestUlyssesSPHFHubKernel(DistributedTest): + world_size = 2 + + def test_register_hub_kernel_attn(self, monkeypatch): + """Test hub-kernel attention strings are registered before validation. + + This verifies that DeepSpeed can accept kernel-based attention implementations + by triggering transformers' lazy registration path prior to checking + ALL_ATTENTION_FUNCTIONS. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + seq_length = 64 + sequence_parallel_size = self.world_size + micro_batch_size = 1 + hub_attn_implementation = 'kernels-community/flash-attn2' + + called_with = [] + had_hub_key_before = hub_attn_implementation in ALL_ATTENTION_FUNCTIONS + original_sdpa = ALL_ATTENTION_FUNCTIONS['sdpa'] + + def _mock_lazy_import_flash_attention(implementation, attention_wrapper=None, allow_all_kernels=False): + called_with.append(implementation) + if implementation == hub_attn_implementation and implementation not in ALL_ATTENTION_FUNCTIONS: + # Mimic transformers hub-kernel registration behavior. + ALL_ATTENTION_FUNCTIONS.register(implementation, ALL_ATTENTION_FUNCTIONS['sdpa']) + return (None, None, None, None), None + + monkeypatch.setattr( + 'transformers.modeling_flash_attention_utils.lazy_import_flash_attention', + _mock_lazy_import_flash_attention, + ) + + try: + mpu = UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=model_name_or_path, + core_attn_implementation=hub_attn_implementation, + sequence_parallel_size=sequence_parallel_size, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + seq_length_is_variable=True, + ) + assert ALL_ATTENTION_FUNCTIONS['sdpa'] is original_sdpa + assert ALL_ATTENTION_FUNCTIONS[hub_attn_implementation] is not original_sdpa + finally: + if not had_hub_key_before and hub_attn_implementation in ALL_ATTENTION_FUNCTIONS: + ALL_ATTENTION_FUNCTIONS.pop(hub_attn_implementation, None) + + assert mpu is not None + assert called_with == [hub_attn_implementation] + + +class TestUlyssesSPHFAttnImplMismatch(DistributedTest): + world_size = 2 + + def test_register_with_mismatched_attn_impl_raises(self): + from transformers import AutoConfig + + model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + seq_length = 64 + sequence_parallel_size = self.world_size + micro_batch_size = 1 + + hf_config = AutoConfig.from_pretrained(model_name_or_path) + hf_config._attn_implementation = "sdpa" + + class MockModel: + """Mock model wrapper exposing a transformers config attribute.""" + + def __init__(self, config): + self.config = config + + with pytest.raises(ValueError, match='does not match model config attn_implementation'): + UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=MockModel(hf_config), + core_attn_implementation='flash_attention_2', + sequence_parallel_size=sequence_parallel_size, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + seq_length_is_variable=True, + ) + + +@pytest.mark.parametrize("zero_stage", [2, 3]) +class TestUlyssesSPHFFlexAttention(DistributedTest): + """Separate class for flex_attention tests — requires non_daemonic_procs + because torch.compile (used by flex_attention) creates unpicklable objects + that break the default multiprocessing.Pool exception handling.""" + world_size = 2 + non_daemonic_procs = True + + def test_ulysses_sp_hf_flex_attention(self, zero_stage): + core_attn_implementation = "flex_attention" + # flex_attention's compiled kernel requires head_dim >= 16. + # tiny-random-LlamaForCausalLM has head_dim=4, so we create a tiny model with head_dim=16. + from transformers import LlamaConfig + model_config = LlamaConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=32, + max_position_embeddings=64, + ) # head_dim = 32/2 = 16 + seq_length = 64 + sequence_parallel_size = self.world_size + micro_batch_size = 1 + + rank = dist.get_rank() + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "sequence_parallel_size": sequence_parallel_size, + } + + dtype = preferred_dtype() + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + + # Part 1. Baseline: Setup + def collate_fn(batch): + input_ids, position_ids = batch[0] + #print(f"{batch=}") + return dict(input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), + labels=input_ids.unsqueeze(0)) + + input_ids = tensor([[1, 10, 10, 10, 2, 2], [1, 20, 20, 20, 2, 2]]) + position_ids = tensor([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]) + ds = torch.utils.data.TensorDataset(input_ids, position_ids) + + # 1. Baseline: DataLoader calibration + dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + batch_a = next(iter(dl_a)) + #print(f"{rank=} {batch_a=}") + expected_batch_a = { + 'input_ids': tensor([[1, 10, 10, 10, 2, 2]]), + 'position_ids': tensor([[0, 1, 2, 3, 4, 5]]), + 'labels': tensor([[1, 10, 10, 10, 2, 2]]) + } + torch_assert_dicts_of_tensors_equal(batch_a, expected_batch_a) + + # 2. Baseline: Attention + torch.manual_seed(42) + model_a = AutoModelForCausalLM.from_config(model_config, attn_implementation=core_attn_implementation) + model_a, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_a, + model_parameters=model_a.parameters(), + mpu=None) + batch_a = move_to_device(batch_a, model_a.device) + loss_a = model_a(**batch_a).loss + model_a.backward(loss_a) + #print(f"{loss_a=}") + + grad_a = get_grad(model_a.module.model.layers[0].self_attn.q_proj.weight, zero_stage) + assert grad_a is not None + #print(f"{grad_a}") + + # Part 2. Ulysses: Setup + mpu = UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=model_a.module, + core_attn_implementation=core_attn_implementation, + sequence_parallel_size=sequence_parallel_size, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + seq_length_is_variable=True, + ) + + torch.manual_seed(42) + model_b = AutoModelForCausalLM.from_config(model_config, attn_implementation=core_attn_implementation) + model_b, _, _, _ = deepspeed.initialize(config=config_dict, + model=model_b, + model_parameters=model_b.parameters(), + mpu=mpu) + + # 3. Ulysses: UlyssesSPDataLoaderAdapter test + sp_group = groups._get_sequence_parallel_group() + sp_world_size = groups._get_sequence_parallel_world_size() + sp_rank = groups._get_sequence_parallel_rank() + dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + dl_b = UlyssesSPDataLoaderAdapter( + dl_a, + sp_rank=sp_rank, + sp_group=sp_group, + sp_world_size=sp_world_size, + device=model_b.device, + ) + batch_b = next(iter(dl_b)) + + expected_batch_b = [ + { + 'input_ids': tensor([[1, 10, 10]]), + 'position_ids': tensor([[0, 1, 2]]), + 'shift_labels': tensor([[10, 10, 10]]), + }, + { + 'input_ids': tensor([[10, 2, 2]]), + 'position_ids': tensor([[3, 4, 5]]), + 'shift_labels': tensor([[2, 2, -100]]), + }, + ] + + # here we expect each sample to be sharded in half, rank0 getting the first half and rank1 the other half + #print(f"{sp_rank=} {batch_b=}") + torch_assert_dicts_of_tensors_equal(batch_b, expected_batch_b[sp_rank]) + + # 4. UlyssesSPAttentionHF test + batch_b = move_to_device(batch_b, model_b.device) + outputs = model_b(**batch_b) + # HF doesn't calculate loss with shift_labels yet and requires us to do it manually (liger does that) + shift_labels = batch_b["shift_labels"] + loss_b = model_b.module.loss_function( + logits=outputs.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=model_b.module.config.vocab_size, + ) + # print(f"{sp_rank=} {loss_b=}") + + # differentiable weighted per-shard-loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss_b, group=sp_group) + good_tokens = sum((shift_labels != -100).view(-1)) + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size)) + total_good_tokens = sum(good_tokens_per_rank) + loss_b = total_loss / total_good_tokens + # print(f"{sp_rank=} gathered {loss_b=}") + model_b.backward(loss_b) + + grad_b = get_grad(model_b.module.model.layers[0].self_attn.q_proj.weight, zero_stage) + assert grad_b is not None + #print(f"{grad_b}") + + # compare loss of A (non-Ulysses Attention) and B (Ulyssses Attention) + torch_assert_close(loss_a, loss_b, atol=1e-05, rtol=1e-05) + + # - we are feeding the exact same sample to each rank of A + # - for B we feed half the sample to each rank, but in total it's the same sample as each rank of A sees + # thus we expect very similar grads (but not exact) + if zero_stage in [1, 2]: + # possibly some issue with z1/z2 as it requires higher tolerance than z3? + torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03) + else: + torch_assert_close(grad_a, grad_b) diff --git a/tests/unit/util.py b/tests/unit/util.py index b339a08056a2..663f45067bd2 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -3,85 +3,345 @@ # DeepSpeed Team +from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported +from deepspeed.git_version_info import torch_info + +from io import StringIO +import deepspeed +import logging import pytest +import re +import sys import torch -import deepspeed -from deepspeed.git_version_info import torch_info def skip_on_arch(min_arch=7): - if deepspeed.accelerator.get_accelerator().device_name() == 'cuda': + if get_accelerator().device_name() == 'cuda': if torch.cuda.get_device_capability()[0] < min_arch: #ignore-cuda pytest.skip(f"needs higher compute capability than {min_arch}") else: - assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu' + assert is_current_accelerator_supported() return def skip_on_cuda(valid_cuda): split_version = lambda x: map(int, x.split('.')[:2]) - if deepspeed.accelerator.get_accelerator().device_name() == 'cuda': + if get_accelerator().device_name() == 'cuda': CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version']) CUDA_VERSION = (CUDA_MAJOR * 10) + CUDA_MINOR if valid_cuda.count(CUDA_VERSION) == 0: pytest.skip(f"requires cuda versions {valid_cuda}") else: - assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu' + assert is_current_accelerator_supported() return -def required_torch_version(): - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) +def bf16_required_version_check(accelerator_check=True): + split_version = lambda x: map(int, x.split('.')[:2]) - if TORCH_MAJOR >= 1 and TORCH_MINOR >= 8: - return True + # torch_info may have stale/zero values if installed without --no-build-isolation + # In that case, fall back to runtime detection + if torch_info['version'] == '0.0': + # Use runtime torch version + TORCH_MAJOR, TORCH_MINOR = split_version(torch.__version__) else: - return False + TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version']) + if torch_info['nccl_version'] == '0.0': + # Use runtime NCCL version if available + if torch.cuda.is_available(): #ignore-cuda + try: + nccl_ver = torch.cuda.nccl.version() #ignore-cuda + NCCL_MAJOR, NCCL_MINOR = nccl_ver[0], nccl_ver[1] + except (AttributeError, RuntimeError): + NCCL_MAJOR, NCCL_MINOR = 0, 0 + else: + # No CUDA means no NCCL; NPU/HPU/XPU have separate checks below + NCCL_MAJOR, NCCL_MINOR = 0, 0 + else: + NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version']) -def bf16_required_version_check(accelerator_check=True): - split_version = lambda x: map(int, x.split('.')[:2]) - TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version']) - NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version']) - CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version']) + if torch_info['cuda_version'] == '0.0': + # Use runtime CUDA version + if torch.cuda.is_available(): #ignore-cuda + cuda_ver = torch.version.cuda + if cuda_ver: + CUDA_MAJOR, CUDA_MINOR = split_version(cuda_ver) + else: + CUDA_MAJOR, CUDA_MINOR = 0, 0 + else: + CUDA_MAJOR, CUDA_MINOR = 0, 0 + else: + CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version']) # Sometimes bf16 tests are runnable even if not natively supported by accelerator if accelerator_check: - accelerator_pass = torch_info['bf16_support'] + accelerator_pass = get_accelerator().is_bf16_supported() else: accelerator_pass = True - if (TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and ( - NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and accelerator_pass: + torch_version_available = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + cuda_version_available = CUDA_MAJOR >= 11 + nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10) + npu_available = get_accelerator().device_name() == 'npu' + hpu_available = get_accelerator().device_name() == 'hpu' + xpu_available = get_accelerator().device_name() == 'xpu' + + if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass: + return True + elif npu_available: + return True + elif hpu_available: + return True + elif xpu_available: return True else: return False -def required_minimum_torch_version(major_version, minor_version): - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - - if TORCH_MAJOR < major_version: +def required_amp_check(): + from importlib.util import find_spec + if find_spec('apex') is None: return False + else: + return True - return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version +class no_child_process_in_deepspeed_io: -def required_maximum_torch_version(major_version, minor_version): - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + def __enter__(self): + # deepspeed_io defaults to creating a dataloader that uses a + # multiprocessing pool. Our tests use pools and we cannot nest pools in + # python. Therefore we're injecting this kwarg to ensure that no pools + # are used in the dataloader. + self.old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io - if TORCH_MAJOR > major_version: - return False + def new_method(*args, **kwargs): + kwargs["num_local_io_workers"] = 0 + return self.old_method(*args, **kwargs) - return TORCH_MAJOR < major_version or TORCH_MINOR <= minor_version + deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method + def __exit__(self, *_): + deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method -def required_amp_check(): - from importlib.util import find_spec - if find_spec('apex') is None: - return False - else: - return True + +def torch_assert_equal(actual, expected, **kwargs) -> None: + """ + Compare two tensors or non-tensor numbers for their equality. + Add msg=blah to add an additional comment to when assert fails. + """ + torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs) + + +def torch_assert_close(actual, expected, **kwargs) -> None: + """ + Compare two tensors or non-tensor numbers for their closeness. + + Add msg=blah to add an additional comment to when assert fails. + + For default values of `rtol` and `atol` which are dtype dependent, see the table at https://docs.pytorch.org/docs/stable/testing.html#torch.testing.assert_close + For example for bf16 it is `rtol=1.6e-2` and `atol=1e-5`. + + The check doesn't assert when `|a - b| <= (atol + rtol * |b|)` + """ + torch.testing.assert_close(actual, expected, **kwargs) + + +def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs): + """ + Compare two dicts of tensors or non-tensor numbers for their equality. + Add msg=blah to add an additional comment to when assert fails. + """ + for k in actual.keys(): + torch.testing.assert_close(actual[k], expected[k], rtol=0.0, atol=0.0, **kwargs) + + +# CaptureStd, CaptureLogger context managers from https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py + + +# When any function contains print() calls that get overwritten, like progress bars, +# a special care needs to be applied, since under pytest -s captured output (capsys +# or contextlib.redirect_stdout) contains any temporary printed strings, followed by +# \r's. This helper function ensures that the buffer will contain the same output +# with and without -s in pytest, by turning: +# foo bar\r tar mar\r final message +# into: +# final message +# it can handle a single string or a multiline buffer +def apply_print_resets(buf): + return re.sub(r"^.*\r", "", buf, 0, re.M) + + +class CaptureStd: + """ + Context manager to capture: + + - stdout: replay it, clean it up and make it available via ``obj.out`` + - stderr: replay it and make it available via ``obj.err`` + - combined: combined the chosen streams and make it available via ``obj.combined`` + + init arguments: + + - out - capture stdout:`` True``/``False``, default ``True`` + - err - capture stdout: ``True``/``False``, default ``True`` + - replay - whether to replay or not: ``True``/``False``, default ``True``. By default each + captured stream gets replayed back on context's exit, so that one can see what the test was + doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass + ``replay=False`` to disable this feature. + + Examples:: + + # to capture stdout only with auto-replay + with CaptureStdout() as cs: + print("Secret message") + assert "message" in cs.out + + # to capture stderr only with auto-replay + import sys + with CaptureStderr() as cs: + print("Warning: ", file=sys.stderr) + assert "Warning" in cs.err + + # to capture both streams with auto-replay + with CaptureStd() as cs: + print("Secret message") + print("Warning: ", file=sys.stderr) + assert "message" in cs.out + assert "Warning" in cs.err + + # to capture just one of the streams, and not the other, with auto-replay + with CaptureStd(err=False) as cs: + print("Secret message") + assert "message" in cs.out + # but best use the stream-specific subclasses + + # to capture without auto-replay + with CaptureStd(replay=False) as cs: + print("Secret message") + assert "message" in cs.out + + # sometimes it's easier to not try to figure out if it's stdout or stderr, and yet at + # other times the software may send the same output to stderr or stdout depending on + # environment, so to make the test robust a combined entry of both streams is available + + """ + + def __init__(self, out=True, err=True, replay=True): + self.replay = replay + + if out: + self.out_buf = StringIO() + self.out = "error: CaptureStd context is unfinished yet, called too early" + else: + self.out_buf = None + self.out = "not capturing stdout" + + if err: + self.err_buf = StringIO() + self.err = "error: CaptureStd context is unfinished yet, called too early" + else: + self.err_buf = None + self.err = "not capturing stderr" + + self.combined = "error: CaptureStd context is unfinished yet, called too early" + + def __enter__(self): + if self.out_buf is not None: + self.out_old = sys.stdout + sys.stdout = self.out_buf + + if self.err_buf is not None: + self.err_old = sys.stderr + sys.stderr = self.err_buf + + self.combined = "" + + return self + + def __exit__(self, *exc): + if self.out_buf is not None: + sys.stdout = self.out_old + captured = self.out_buf.getvalue() + if self.replay: + sys.stdout.write(captured) + self.out = apply_print_resets(captured) + self.combined += self.out + + if self.err_buf is not None: + sys.stderr = self.err_old + captured = self.err_buf.getvalue() + if self.replay: + sys.stderr.write(captured) + self.err = captured + self.combined += self.err + + def __repr__(self): + msg = "" + if self.out_buf: + msg += f"stdout: {self.out}\n" + if self.err_buf: + msg += f"stderr: {self.err}\n" + return msg + + +# in tests it's the best to capture only the stream that's wanted, otherwise +# it's easy to miss things, so unless you need to capture both streams, use the +# subclasses below (less typing). Or alternatively, configure `CaptureStd` to +# disable the stream you don't need to test. + + +class CaptureStdout(CaptureStd): + """Same as CaptureStd but captures only stdout""" + + def __init__(self, replay=True): + super().__init__(err=False, replay=replay) + + +class CaptureStderr(CaptureStd): + """Same as CaptureStd but captures only stderr""" + + def __init__(self, replay=True): + super().__init__(out=False, replay=replay) + + +class CaptureLogger: + """ + Context manager to capture `logging` streams + + Args: + + - logger: 'logging` logger object + + Results: + The captured output is available via `self.out` + + Example:: + + >>> from transformers import logging + >>> from transformers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg+"\n" + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" diff --git a/tests/unit/utils/test_byte_cast.py b/tests/unit/utils/test_byte_cast.py new file mode 100644 index 000000000000..99e66eb4920b --- /dev/null +++ b/tests/unit/utils/test_byte_cast.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.accelerator import get_accelerator +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[UtilsBuilder.NAME]: + pytest.skip(f'Skip tests since {UtilsBuilder.NAME} is not compatible', allow_module_level=True) + + +def _validate_tensor_cast_properties(typed_tensor, byte_tensor): + assert byte_tensor.dtype == torch.uint8 + assert byte_tensor.numel() == typed_tensor.numel() * typed_tensor.element_size() + assert byte_tensor.data_ptr() == typed_tensor.data_ptr() + + +def _byte_cast_single_tensor(typed_tensor): + util_ops = UtilsBuilder().load() + byte_tensor = util_ops.cast_to_byte_tensor(typed_tensor) + + _validate_tensor_cast_properties(typed_tensor=typed_tensor, byte_tensor=byte_tensor) + + +def _byte_cast_multiple_tensors(typed_tensor_list): + util_ops = UtilsBuilder().load() + byte_tensor_list = util_ops.cast_to_byte_tensor(typed_tensor_list) + + assert len(typed_tensor_list) == len(byte_tensor_list) + + for typed_tensor, byte_tensor in zip(typed_tensor_list, byte_tensor_list): + _validate_tensor_cast_properties(typed_tensor=typed_tensor, byte_tensor=byte_tensor) + + +@pytest.mark.parametrize( + 'dtype', + [torch.float32, torch.half, torch.bfloat16, torch.float64, torch.int32, torch.short, torch.int64], +) +class TestCastSingleTensor(DistributedTest): + world_size = 1 + + def test_byte_cast_accelerator_tensor(self, dtype): + numel = 1024 + typed_tensor = torch.empty(numel, dtype=dtype).to(get_accelerator().device_name()) + _byte_cast_single_tensor(typed_tensor) + + @pytest.mark.parametrize("pinned_memory", [True, False]) + def test_byte_cast_cpu_tensor(self, dtype, pinned_memory): + numel = 1024 + typed_tensor = torch.empty(numel, dtype=dtype, device='cpu') + if pinned_memory: + typed_tensor = typed_tensor.pin_memory() + + _byte_cast_single_tensor(typed_tensor) + + +@pytest.mark.parametrize('tensor_count', [1, 8, 15]) +class TestCastTensorList(DistributedTest): + world_size = 1 + + def test_byte_cast_accelerator_tensor_list(self, tensor_count): + typed_tensor_list = [torch.empty(1024, dtype=torch.half).to(get_accelerator().device_name())] * tensor_count + _byte_cast_multiple_tensors(typed_tensor_list) + + def test_byte_cast_cpu_tensor_list(self, tensor_count): + typed_tensor_list = [torch.empty(1024, dtype=torch.half, device='cpu')] * tensor_count + _byte_cast_multiple_tensors(typed_tensor_list) diff --git a/tests/unit/utils/test_groups.py b/tests/unit/utils/test_groups.py index d8f12be4f3c6..5cd35baf3510 100644 --- a/tests/unit/utils/test_groups.py +++ b/tests/unit/utils/test_groups.py @@ -18,7 +18,7 @@ def test_get_expert_parallel_ranks(): expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] """ expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(world_size=16, - model_parallel_size_=2, + tensor_parallel_size_=2, expert_parallel_size_=4) assert expert_parallel_groups == [ [0, 2, 4, 6], diff --git a/tests/unit/utils/test_nvtx.py b/tests/unit/utils/test_nvtx.py new file mode 100644 index 000000000000..a8c6816273ed --- /dev/null +++ b/tests/unit/utils/test_nvtx.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed.utils.nvtx as ds_nvtx +import accelerator.cuda_accelerator as cuda_accelerator +from accelerator.cuda_accelerator import CUDA_Accelerator + + +def _sample_nvtx_function(): + return "ok" + + +def test_instrument_w_nvtx_uses_deepspeed_domain(monkeypatch, capsys): + calls = [] + + class FakeAccelerator: + supports_nvtx_domain = True + + def range_push(self, msg, domain=None, category=None): + calls.append(("push", msg, domain, category)) + + def range_pop(self, domain=None): + calls.append(("pop", domain)) + + monkeypatch.setattr(ds_nvtx, "enable_nvtx", True) + monkeypatch.setattr(ds_nvtx, "is_compiling", lambda: False) + monkeypatch.setattr(ds_nvtx, "get_accelerator", lambda: FakeAccelerator()) + + wrapped_fn = ds_nvtx.instrument_w_nvtx(_sample_nvtx_function) + + assert wrapped_fn() == "ok" + + with capsys.disabled(): + print(f"\nNVTX instrumentation calls: {calls}") + + assert calls == [ + ("push", "_sample_nvtx_function", ds_nvtx.DEEPSPEED_NVTX_DOMAIN, None), + ("pop", ds_nvtx.DEEPSPEED_NVTX_DOMAIN), + ] + + +def test_instrument_w_nvtx_supports_legacy_accelerator_methods(monkeypatch, capsys): + calls = [] + + class LegacyAccelerator: + + def range_push(self, msg): + calls.append(("push", msg)) + + def range_pop(self): + calls.append(("pop", )) + + monkeypatch.setattr(ds_nvtx, "enable_nvtx", True) + monkeypatch.setattr(ds_nvtx, "is_compiling", lambda: False) + monkeypatch.setattr(ds_nvtx, "get_accelerator", lambda: LegacyAccelerator()) + + wrapped_fn = ds_nvtx.instrument_w_nvtx(_sample_nvtx_function) + + assert wrapped_fn() == "ok" + + with capsys.disabled(): + print(f"\nLegacy NVTX instrumentation calls: {calls}") + + assert calls == [ + ("push", "_sample_nvtx_function"), + ("pop", ), + ] + + +def test_cuda_accelerator_uses_nvtx_domain_when_available(monkeypatch, capsys): + + class FakeDomain: + + def __init__(self): + self.calls = [] + + def push_range(self, message=None, category=None): + self.calls.append(("push", message, category)) + return "domain-push" + + def pop_range(self): + self.calls.append(("pop", )) + return "domain-pop" + + class FakeNvtx: + + def __init__(self): + self.domains = {} + + def get_domain(self, name): + self.domains.setdefault(name, FakeDomain()) + return self.domains[name] + + fake_nvtx = FakeNvtx() + accelerator = CUDA_Accelerator.__new__(CUDA_Accelerator) + accelerator._nvtx_domains = {} + monkeypatch.setattr(cuda_accelerator, "nvtx", fake_nvtx) + + assert accelerator.range_push("my_range", domain="DeepSpeed", category="zero") == "domain-push" + assert accelerator.range_pop(domain="DeepSpeed") == "domain-pop" + + with capsys.disabled(): + print(f"\nCUDA NVTX domain calls: {fake_nvtx.domains['DeepSpeed'].calls}") + + assert fake_nvtx.domains["DeepSpeed"].calls == [ + ("push", "my_range", "zero"), + ("pop", ), + ] + + +def test_cuda_accelerator_falls_back_to_torch_nvtx_without_nvtx_package(monkeypatch, capsys): + calls = [] + + class FakeTorchNvtx: + + def range_push(self, msg): + calls.append(("push", msg)) + return "torch-push" + + def range_pop(self): + calls.append(("pop", )) + return "torch-pop" + + accelerator = CUDA_Accelerator.__new__(CUDA_Accelerator) + monkeypatch.setattr(cuda_accelerator, "nvtx", None) + monkeypatch.setattr(cuda_accelerator.torch.cuda, "nvtx", FakeTorchNvtx()) #ignore-cuda + + assert accelerator.range_push("my_range", domain="DeepSpeed", category="zero") == "torch-push" + assert accelerator.range_pop(domain="DeepSpeed") == "torch-pop" + + with capsys.disabled(): + print(f"\nCUDA torch.nvtx fallback calls: {calls}") + + assert calls == [ + ("push", "my_range"), + ("pop", ), + ] diff --git a/tests/unit/utils/test_partition_balanced.py b/tests/unit/utils/test_partition_balanced.py new file mode 100644 index 000000000000..e7285e478c53 --- /dev/null +++ b/tests/unit/utils/test_partition_balanced.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime import utils as ds_utils + + +def check_partition(weights, num_parts, target_diff): + result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts) + + parts_sum = [] + for b, e in zip(result[:-1], result[1:]): + parts_sum.append(sum(weights[b:e])) + + assert max(parts_sum) - min( + parts_sum + ) == target_diff, f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}" + + +def test_partition_balanced(): + check_partition([1, 2, 1], 4, target_diff=2) + check_partition([1, 1, 1, 1], 4, target_diff=0) + check_partition([1, 1, 1, 1, 1], 4, target_diff=1) + check_partition([1, 1, 1, 1, 0, 1], 4, target_diff=1) diff --git a/tests/unit/v1/compile/test_compile_autosp.py b/tests/unit/v1/compile/test_compile_autosp.py new file mode 100644 index 000000000000..fcdb7fdefc88 --- /dev/null +++ b/tests/unit/v1/compile/test_compile_autosp.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import operator +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F +from torch.fx import Graph, GraphModule + +from deepspeed.utils.torch import required_torch_version +from deepspeed.accelerator import get_accelerator +from deepspeed.compile import constants + +from unit.v1.compile.util import compare_sp_loss, create_gm_nodes, find_sym_seq_node +from unit.common import DistributedTest +from unit.util import bf16_required_version_check, skip_on_arch + +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.9), + reason="AutoSP tests require PyTorch >= 2.9") + +# Fixed sp_size injected into mocks. +_SP_SIZE = 2 + + +class TestAutoSPCompile(DistributedTest): + world_size = 4 + non_daemonic_procs = True + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float32]) + @pytest.mark.parametrize('zero_stage', [0, 1]) + @pytest.mark.parametrize('sp_size', [2, 4]) + def test(self, zero_stage, dtype, sp_size): + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" + ) + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + dp_size = self.world_size // sp_size + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": dp_size, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": True, + "passes": ["autosp"] + }, + "sequence_parallel_size": sp_size, + "gradient_clipping": 1.0, + } + + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + compare_sp_loss(self, config_dict, sp_size) + + +# Plain pytest classes — no distributed runtime needed because these functions +# perform pure IR-level graph rewrites; sp_size and get_rank are mocked. + + +class TestSDPANodesCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_sdpa_nodes + + gm, _ = create_gm_nodes(seq_len=seq_len) + sdpa_nodes = get_sdpa_nodes(gm) + + assert len(sdpa_nodes) >= 1, f"Expected at least 1 SDPA node, got {len(sdpa_nodes)}" + for node in sdpa_nodes: + assert node.target == F.scaled_dot_product_attention + + +class TestInputIdCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_input_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + node = get_input_id_node(gm) + + assert node.op == "placeholder" + tensor_dict = node.meta.get("tensor_dict", {}) + assert tensor_dict.get("tag") == constants.AUTOSP_INPUT_ID_KEY + + +class TestLabelIdCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_label_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + node = get_label_id_node(gm) + + assert node.op == "placeholder" + tensor_dict = node.meta.get("tensor_dict", {}) + assert tensor_dict.get("tag") == constants.AUTOSP_LABEL_ID_KEY + + +class TestPositionIdCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_position_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + node = get_position_id_node(gm) + + assert node is not None, "position_id node not found in graph" + assert node.op == "placeholder" + tensor_dict = node.meta.get("tensor_dict", {}) + assert tensor_dict.get("tag") == constants.AUTOSP_POSITION_ID_KEY + + +class TestShardOffsetsCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.util import create_shard_offsets + + gm, _ = create_gm_nodes(seq_len=seq_len) + sym_seq_node = find_sym_seq_node(gm) + assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph" + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + start_node, end_node = create_shard_offsets(gm, sym_seq_node) + + # create_shard_offsets emits: chunk = seq // sp_size; start = rank * chunk; end = start + chunk. + # Verify the three-node chain has the right operators and wiring. + chunk_size_node = start_node.args[1] # start = rank * chunk → chunk is arg[1] + + assert chunk_size_node.target == operator.floordiv + assert chunk_size_node.args[0] is sym_seq_node + assert chunk_size_node.args[1] == _SP_SIZE + + assert start_node.target == operator.mul + assert start_node.args[0] == 0 # rank 0 baked in at transform time + assert start_node.args[1] is chunk_size_node + + assert end_node.target == operator.add + assert end_node.args[0] is start_node + assert end_node.args[1] is chunk_size_node + + +class TestSymSliceCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.util import create_symbolic_slice_indices + + gm, _ = create_gm_nodes(seq_len=seq_len) + sym_seq_node = find_sym_seq_node(gm) + assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph" + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + slice_all, slice_range = create_symbolic_slice_indices(gm, sym_seq_node) + + # slice_all = slice(None, None, None) — selects the batch dimension unchanged + assert slice_all.target == slice + assert slice_all.args == (None, None, None) + + # slice_range selects [start, end) along the sequence dim, where start and + # end come from create_shard_offsets (mul and add nodes respectively). + assert slice_range.target == slice + start_arg, end_arg, step_arg = slice_range.args + assert step_arg is None + + # start = rank * chunk → verify the full shard-offset wiring + chunk_size_node = start_arg.args[1] + assert start_arg.target == operator.mul + assert start_arg.args[0] == 0 # rank 0 baked in at transform time + assert chunk_size_node.target == operator.floordiv + assert chunk_size_node.args[0] is sym_seq_node + assert chunk_size_node.args[1] == _SP_SIZE + + # end = start + chunk + assert end_arg.target == operator.add + assert end_arg.args[0] is start_arg + assert end_arg.args[1] is chunk_size_node + + +class TestShardTensorCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.util import shard_tensor_node, get_input_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + input_ids_node = get_input_id_node(gm) + original_users = set(input_ids_node.users.keys()) + assert len(original_users) > 0, "input_ids_node must have users before sharding" + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + shard_tensor_node(gm, input_ids_node) + + getitem_nodes = [n for n in gm.graph.nodes if n.target == operator.getitem and n.args[0] is input_ids_node] + assert len(getitem_nodes) == 1, f"Expected 1 slice node after sharding, got {len(getitem_nodes)}" + sliced_node = getitem_nodes[0] + + # After sharding, the raw node must only feed the slice; all downstream + # consumers are rewired to sliced_node by replace_node_users. + assert set(input_ids_node.users.keys()) == {sliced_node} + + for user in original_users: + assert input_ids_node not in user.all_input_nodes, \ + f"User '{user.name}' still references the unsharded input_ids_node" + assert sliced_node in user.all_input_nodes, \ + f"User '{user.name}' does not reference the sliced node" + + def test_preserves_topological_order_when_sym_placeholder_follows_input(self): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.fx import find_node_by_name, get_node_shape_meta + from deepspeed.compile.util import shard_tensor_node, get_input_id_node + + # Regression test for the torch 2.9 bf16 trace where the SymInt + # placeholder can appear after input_ids. shard_tensor_node must still + # produce a lint-clean graph instead of inserting getitem before its + # symbolic slice dependencies. + gm, _ = create_gm_nodes(seq_len=64) + input_ids_node = get_input_id_node(gm) + seq_symint = get_node_shape_meta(input_ids_node).shape[1] + sym_seq_node = find_node_by_name(gm, str(seq_symint)) + assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph" + + nodes = list(gm.graph.nodes) + input_idx = nodes.index(input_ids_node) + sym_idx = nodes.index(sym_seq_node) + assert sym_idx < input_idx, "Expected source graph to place the symbolic placeholder before input_ids" + + # Reorder placeholders to mirror the torch 2.9 bf16 trace where the symbolic + # sequence placeholder can appear after input_ids. + reordered_nodes = nodes[:] + reordered_nodes.pop(input_idx) + reordered_nodes.insert(sym_idx, input_ids_node) + reordered_nodes.pop(sym_idx + 1) + reordered_nodes.insert(input_idx, sym_seq_node) + + reordered_graph = Graph() + env = {} + for node in reordered_nodes: + new_node = reordered_graph.node_copy(node, lambda n: env[n]) + new_node.meta = node.meta.copy() + env[node] = new_node + reordered_graph.lint() + + reordered_gm = GraphModule(gm, reordered_graph) + reordered_input_ids = get_input_id_node(reordered_gm) + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + shard_tensor_node(reordered_gm, reordered_input_ids) + + reordered_gm.graph.lint() diff --git a/tests/unit/v1/compile/test_compile_fx.py b/tests/unit/v1/compile/test_compile_fx.py new file mode 100644 index 000000000000..e07f173b69fc --- /dev/null +++ b/tests/unit/v1/compile/test_compile_fx.py @@ -0,0 +1,36 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +from torch.fx import Graph + +from deepspeed.compile.fx import add_end_backward, replace_reduce_outputs_with_none, get_output_node +from deepspeed.compile.util import get_deepcompile_handle, is_deepcompile_supported + + +@pytest.mark.skipif(not is_deepcompile_supported(), reason="DeepCompile requires CUDA and supported PyTorch") +def test_end_backward_depends_on_all_reduce_nodes(): + get_deepcompile_handle() + + graph = Graph() + grad = graph.placeholder("grad") + reduce_a = graph.create_node("call_function", torch.ops.dc.reduce_grad.default, (grad, 7, 11), name="reduce_a") + reduce_b = graph.create_node("call_function", torch.ops.dc.reduce_grad.default, (grad, 7, 12), name="reduce_b") + graph.output((grad, )) + + add_end_backward(graph, 7) + replace_reduce_outputs_with_none(graph) + graph.lint() + + end_backward = next(n for n in graph.nodes if n.target == torch.ops.dc.end_backward.default) + deps, graph_id = end_backward.args + output_node = get_output_node(graph) + + assert graph_id == 7 + assert list(deps) == [reduce_a, reduce_b] + assert end_backward in reduce_a.users + assert end_backward in reduce_b.users + assert output_node.args == ((grad, ), ) diff --git a/tests/unit/v1/compile/test_compile_zero.py b/tests/unit/v1/compile/test_compile_zero.py new file mode 100644 index 000000000000..16ad12d30f13 --- /dev/null +++ b/tests/unit/v1/compile/test_compile_zero.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.utils.torch import required_torch_version +from deepspeed.accelerator import get_accelerator + +from unit.v1.compile.util import compare_loss +from unit.common import DistributedTest +from unit.util import bf16_required_version_check, skip_on_arch +import deepspeed +from deepspeed.ops.aio import AsyncIOBuilder + +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), + reason="Compile tests requires Pytorch version 2.1 or above") + + +class TestZeRO(DistributedTest): + world_size = 2 + non_daemonic_procs = True + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) + def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" + ) + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + if offload_device == OffloadDeviceEnum.nvme: + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + if zero_stage != 3: + pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_device == OffloadDeviceEnum.cpu: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} + elif offload_device == OffloadDeviceEnum.nvme: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": offload_device, + "nvme_path": str(tmpdir) + } + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + compare_loss(self, config_dict, dtype) + + +class TestDeepCompile(DistributedTest): + world_size = 2 + non_daemonic_procs = True + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) + @pytest.mark.parametrize('zero_stage', [1, 3]) + @pytest.mark.parametrize('deepcompile', [True]) # deepcompile==False is included in test_compile_zero + def test(self, zero_stage, dtype, deepcompile): + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" + ) + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": deepcompile + } + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + # Need warmup steps + compare_loss(self, config_dict, dtype, iteration=10) + + @pytest.mark.parametrize('dtype', [torch.float32]) + @pytest.mark.parametrize('zero_stage', [3]) + def test_padded_shard_handling(self, zero_stage, dtype): + """Test that parameters with padding (uneven division) work correctly with DeepCompile""" + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + # Use a hidden dimension that requires padding when divided across ranks + # With world_size=2, a hidden_dim of 13 creates parameters that need padding + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": True + } + } + + # This should work correctly with our padding-aware implementation + # The test verifies that padded parameters are handled properly + compare_loss(self, config_dict, dtype, iteration=1, hidden_dim_override=13) + + @pytest.mark.parametrize('dtype', [torch.float32]) + @pytest.mark.parametrize('zero_stage', [1, 3]) + def test_free_activation_mode(self, zero_stage, dtype): + """Test that eagerly free activations work correctly and the threshold is configurable""" + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": True, + "free_activation": True, + "free_activation_threshold": 0, + } + } + + compare_loss(self, config_dict, dtype) + + @pytest.mark.parametrize('dtype', ["bfloat16", "float16"]) + @pytest.mark.parametrize('zero_stage', [3]) + def test_fusing_allgather_and_autocast(self, zero_stage, dtype): + """Test that allgather and autocast can be correctly fused with DeepCompile""" + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "torch_autocast": { + "enable": True, + "dtype": dtype, + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": True + } + } + + compare_loss(self, config_dict, torch.float32) diff --git a/tests/unit/v1/compile/test_selective_gather.py b/tests/unit/v1/compile/test_selective_gather.py new file mode 100644 index 000000000000..7abd9e930587 --- /dev/null +++ b/tests/unit/v1/compile/test_selective_gather.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from types import SimpleNamespace + +import torch + +import deepspeed.compile.passes.selective_gather as selective_gather_pass +from deepspeed.compile.profilers import ProfilingResult + + +class FakeAccelerator: + + def __init__(self, total_mem=1000, available_mem=250, device="cpu"): + self._total_mem = total_mem + self._available_mem = available_mem + self._device = device + + def total_memory(self): + return self._total_mem + + def available_memory(self): + return self._available_mem + + def current_device(self): + return self._device + + +class FakeDeepCompileHandle: + + def __init__(self): + self.persistent_ds_ids = [] + + def set_persistent(self, ds_id): + self.persistent_ds_ids.append(ds_id) + + +def _make_param(numel, ds_persist=False): + return SimpleNamespace(numel=numel, + dtype=torch.float32, + param=SimpleNamespace(ds_persist=ds_persist, ds_shape=(numel, ))) + + +def test_compute_persistence_budget_prefers_peak_resident_alloc(): + budget = selective_gather_pass._compute_persistence_budget(all_graph_mem_records=[[("fwd", 700, 0, 980)], + [("bwd", 720, 20, 800)]], + total_mem=1000, + mem_margin=0.1) + + assert budget["usable_mem"] == 900 + assert budget["peak_resident_alloc"] == 720 + assert budget["transient_peak"] == 980 + assert budget["available_mem"] == 180 + assert budget["profiled_list_count"] == 2 + + +def test_compute_persistence_budget_clamps_when_resident_alloc_exceeds_budget(): + budget = selective_gather_pass._compute_persistence_budget(all_graph_mem_records=[[("fwd", 920, 0, 980)], + [("bwd", 910, -10, 950)]], + total_mem=1000, + mem_margin=0.1) + + assert budget["usable_mem"] == 900 + assert budget["peak_resident_alloc"] == 920 + assert budget["available_mem"] == 0 + + +def test_selective_gather_sets_persistent_params_when_resident_headroom_exists(monkeypatch): + fake_handle = FakeDeepCompileHandle() + + monkeypatch.setattr(selective_gather_pass, "get_accelerator", lambda: FakeAccelerator(available_mem=220)) + monkeypatch.setattr(selective_gather_pass, "get_deepcompile_handle", lambda: fake_handle) + monkeypatch.setattr(selective_gather_pass.dist, "get_rank", lambda: 0) + monkeypatch.setattr(selective_gather_pass.dist, "all_reduce", lambda tensor, op: tensor) + + profiling_results = { + 0: + ProfilingResult(fwd_graph=SimpleNamespace(nodes=[]), + bwd_graph=SimpleNamespace(nodes=[]), + fwd_mem=[("fwd", 700, 0, 950)], + bwd_mem=[("bwd", 680, -20, 720)]) + } + param_manager = { + 0: + SimpleNamespace(params={ + "small": _make_param(25), + "large": _make_param(60), + }, + ds_ids={ + "small": 1, + "large": 2, + }) + } + gm = object() + + returned = selective_gather_pass.selective_gather(gm, + graph_id=0, + graph_order=[(0, True)], + profiling_results=profiling_results, + create_inputs_fn=None, + mem_budget=0.0, + param_manager=param_manager, + bwd=True) + + assert returned is gm + assert fake_handle.persistent_ds_ids == [1] diff --git a/tests/unit/v1/compile/util.py b/tests/unit/v1/compile/util.py new file mode 100644 index 000000000000..2e4e45bd2b31 --- /dev/null +++ b/tests/unit/v1/compile/util.py @@ -0,0 +1,296 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from copy import deepcopy +import os +import random +import numpy as np + +import torch + +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero import GatheredParameters + +from unit.simple_model import SimpleModel +from unit.common import allclose_on_all_ranks + + +def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None): + hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10 + + # the default tolerances of torch.testing.assert_close are too small + RTOL = 5e-1 + ATOL = 1e-2 + + # Use a fixed seed for determinism. We don't use the @enable_determinism decorator + # because it also sets torch.use_deterministic_algorithms(True), which seems + # incompatible with torch.compile() in test environments. + # Might be related to https://github.com/pytorch/pytorch/issues/159855 + local_rank = int(os.getenv("LOCAL_RANK", "0")) + seed = 123 + local_rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + get_accelerator().manual_seed(seed) + get_accelerator().manual_seed_all(seed) + + device = torch.device(get_accelerator().current_device_name()) + model = SimpleModel(hidden_dim) + + i = get_accelerator().current_device() + baseline_model = deepcopy(model) + baseline_config = deepcopy(config) + baseline_config["zero_optimization"]["stage"] = 0 + baseline_config["zero_optimization"]["offload_optimizer"] = {} + baseline_engine, baseline_optimizer, _, _ = deepspeed.initialize(config=baseline_config, + model=baseline_model, + model_parameters=baseline_model.parameters()) + + if config["zero_optimization"]["stage"] == 3: + with deepspeed.zero.Init(config_dict_or_path=config): + target_model = SimpleModel(hidden_dim) + with GatheredParameters(target_model.parameters(), modifier_rank=0): + for p1, p2 in zip(target_model.parameters(), model.parameters()): + p1.data.copy_(p2.data) + else: + target_model = deepcopy(model) + + target_engine, target_optimizer, _, _ = deepspeed.initialize(config=config, + model=target_model, + model_parameters=target_model.parameters()) + target_engine.compile() + + train_batch_size = config["train_micro_batch_size_per_gpu"] + + xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=dtype) for _ in range(iteration)] + ys = [torch.randn_like(x) for x in xs] + + for x, y in zip(xs, ys): + baseline_loss = baseline_engine(x, y) + target_loss = target_engine(x, y) + + allclose_on_all_ranks(baseline_loss, target_loss, "Loss values are not close.", rtol=RTOL, atol=ATOL) + + baseline_engine.backward(baseline_loss) + target_engine.backward(target_loss) + + baseline_engine.step() + target_engine.step() + + with GatheredParameters(target_engine.parameters()): + for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()): + allclose_on_all_ranks(p1, p2, "Parameters are not equal.", rtol=RTOL, atol=ATOL) + + baseline_engine.destroy() + target_engine.destroy() + + +def compare_sp_loss(self, config, sp_size, iterations=3): + """ + Compare AutoSP compiled model loss against a compiled Ulysses SP model (ground truth). + + Both engines are trained in lockstep. After all training steps the final-step + losses are compared. + """ + import torch.nn.functional as F + from transformers import AutoModelForCausalLM, AutoConfig + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from deepspeed.compile import constants as autosp_constants + from deepspeed.compile.custom_ops.sp_dp_registry import populate_registry, get_group + from deepspeed.sequence.layer import DistributedAttention + + RTOL, ATOL = 0.1, 0.01 + model_name = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + seq_length = 64 + + torch.manual_seed(42) + get_accelerator().manual_seed_all(42) + device = torch.device(get_accelerator().current_device_name()) + + model_config = AutoConfig.from_pretrained(model_name) + model_config._attn_implementation = "sdpa" + base_model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + vocab_size = model_config.vocab_size + + # Set up SP/DP process groups (shared by both Ulysses and AutoSP). + dp_size = dist.get_world_size() // sp_size + populate_registry(sp_size, dp_size) + # The DP-rank index selects which SP group the current rank belongs to. + sp_group = get_group(dist.get_rank() // sp_size) + sp_rank = dist.get_rank() % sp_size + chunk = seq_length // sp_size + + # Build a DistributedAttention wrapper that mirrors distributed_attention.py. + # Registered under a unique key so the model's "sdpa" slot stays untouched — + # AutoSP's graph pass can therefore find F.scaled_dot_product_attention nodes. + def _sdpa_inner(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None): + # DistributedAttention delivers tensors in [b, s, n, h]; SDPA wants [b, n, s, h]. + out = F.scaled_dot_product_attention(q.permute(0, 2, 1, 3), + k.permute(0, 2, 1, 3), + v.permute(0, 2, 1, 3), + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale) + return out.permute(0, 2, 1, 3) + + _dist_attn = DistributedAttention(_sdpa_inner, sp_group, scatter_idx=2, gather_idx=1) + + def _ulysses_attn_forward(module, + query_states, + key_states, + value_states, + attention_mask, + scaling=None, + dropout=0.0, + is_causal=False, + **kwargs): + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + out = _dist_attn(q, k, v, batch_dim_idx=0, dropout_p=dropout, is_causal=is_causal, scale=scaling) + return out, None + + ALL_ATTENTION_FUNCTIONS["ulyssess"] = _ulysses_attn_forward + + # Ulysses baseline: regular torch.compile, no deepcompile or autosp pass. + ulysses_config = deepcopy(config) + ulysses_config.pop("compile", None) + ulysses_model = deepcopy(base_model) + ulysses_model.config._attn_implementation = "ulyssess" + ulysses_engine, _, _, _ = deepspeed.initialize(config=ulysses_config, + model=ulysses_model, + model_parameters=ulysses_model.parameters()) + ulysses_engine.compile() + + # AutoSP model: sdpa so the autosp pass can find F.scaled_dot_product_attention. + # dynamic=True ensures all shape dimensions are treated symbolically so the autosp + # pass can correctly shard the sequence dimension for all dtypes including fp16/bf16. + autosp_model = deepcopy(base_model) + autosp_engine, _, _, _ = deepspeed.initialize(config=config, + model=autosp_model, + model_parameters=autosp_model.parameters()) + autosp_engine.compile(compile_kwargs={"dynamic": True}) + + # Train both engines in lockstep; compare the losses at the final step. + ul_loss = autosp_loss = None + for i in range(iterations): + torch.manual_seed(42 + i) + full_ids = torch.randint(0, vocab_size, (1, seq_length), device=device) + + # Ulysses: each rank processes its own shard. + shard_ids = full_ids[:, sp_rank * chunk:(sp_rank + 1) * chunk] + shard_pos = torch.arange(sp_rank * chunk, (sp_rank + 1) * chunk, device=device).unsqueeze(0) + shard_mask = torch.ones(1, chunk, device=device, dtype=torch.long) + ul_out = ulysses_engine(input_ids=shard_ids, + labels=shard_ids, + position_ids=shard_pos, + attention_mask=shard_mask) + # Average per-shard losses across SP ranks to get the full-sequence loss. + ul_loss = ul_out.loss.clone() + dist.all_reduce(ul_loss, group=sp_group) + ul_loss = ul_loss / sp_size + + # AutoSP: full sequence. dynamic=True makes all shapes symbolic, so mark_dynamic + # is not needed; only the tag attributes that the autosp pass uses are set here. + autosp_ids = full_ids.clone() + autosp_lbl = autosp_ids.clone() + autosp_pos = torch.arange(seq_length, device=device).unsqueeze(0) + autosp_msk = torch.ones(1, seq_length, device=device, dtype=torch.long) + autosp_ids.tag = autosp_constants.AUTOSP_INPUT_ID_KEY + autosp_lbl.tag = autosp_constants.AUTOSP_LABEL_ID_KEY + autosp_pos.tag = autosp_constants.AUTOSP_POSITION_ID_KEY + autosp_out = autosp_engine(input_ids=autosp_ids, + labels=autosp_lbl, + position_ids=autosp_pos, + attention_mask=autosp_msk) + autosp_loss = autosp_out.loss + + ulysses_engine.backward(ul_out.loss) + ulysses_engine.step() + autosp_engine.backward(autosp_loss) + autosp_engine.step() + + allclose_on_all_ranks(autosp_loss, ul_loss, "AutoSP and Ulysses losses are not close.", rtol=RTOL, atol=ATOL) + + ulysses_engine.destroy() + del ALL_ATTENTION_FUNCTIONS["ulyssess"] + autosp_engine.destroy() + + +def create_gm_nodes(batch_size: int = 1, seq_len: int = 16): + """ + Load a tiny LlamaForCausalLM, tag inputs with AutoSP keys, mark the sequence + dimension dynamic, and capture the torch-fx GraphModule via a custom + torch.compile backend. + + The returned gm is identical to what the autosp pass receives during training: + placeholder nodes carry tensor_dict tags and meta['val'] shapes are symbolic + (SymInt) in the sequence dimension. + + Returns: + gm – GraphModule with fully populated node metadata + inputs – (input_ids, labels, position_ids) used for tracing + """ + from transformers import AutoModelForCausalLM, AutoConfig + from deepspeed.compile import constants + + # Each call needs a clean dynamo state; without this, the recompile_limit + # (default 8) is exhausted across tests and the backend is never invoked. + torch._dynamo.reset() + + model_name = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + model_config = AutoConfig.from_pretrained(model_name) + model_config._attn_implementation = "sdpa" + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + model.eval() + + vocab_size = model_config.vocab_size + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + position_ids = torch.arange(seq_len).unsqueeze(0) + + # dynamo propagates Python tensor attributes into node.meta['tensor_dict']; + # find_node_by_tag relies on this to identify the AutoSP input nodes. + input_ids.tag = constants.AUTOSP_INPUT_ID_KEY + labels.tag = constants.AUTOSP_LABEL_ID_KEY + position_ids.tag = constants.AUTOSP_POSITION_ID_KEY + + # Marking the sequence dim dynamic causes dynamo to emit a SymInt placeholder + # node and store symbolic shapes in node.meta['val'], which shard_tensor_node + # needs to locate the sequence-length symbol in the graph. + torch._dynamo.decorators.mark_dynamic(input_ids, 1) + torch._dynamo.decorators.mark_dynamic(labels, 1) + torch._dynamo.decorators.mark_dynamic(position_ids, 1) + + captured_gm = [None] + + def _capture_backend(gm, example_inputs): + if captured_gm[0] is None: + captured_gm[0] = gm + return gm + + compiled = torch.compile(model, backend=_capture_backend, dynamic=True) + with torch.no_grad(): + compiled(input_ids=input_ids, labels=labels, position_ids=position_ids) + + assert captured_gm[0] is not None, "Capture backend was never invoked — graph capture failed" + return captured_gm[0], (input_ids, labels, position_ids) + + +def find_sym_seq_node(gm): + """ + Return the SymInt placeholder node for the sequence-length dimension of + input_ids, or None if it cannot be found. + """ + from deepspeed.compile.util import get_input_id_node + from deepspeed.compile.fx import get_node_shape_meta, find_node_by_name + + input_ids_node = get_input_id_node(gm) + val = get_node_shape_meta(input_ids_node) + seq_symint = val.shape[1] + return find_node_by_name(gm, str(seq_symint)) diff --git a/tests/unit/runtime/half_precision/test_bf16.py b/tests/unit/v1/half_precision/test_bf16.py similarity index 60% rename from tests/unit/runtime/half_precision/test_bf16.py rename to tests/unit/v1/half_precision/test_bf16.py index 916267a6ad42..661aaeb31fc0 100644 --- a/tests/unit/runtime/half_precision/test_bf16.py +++ b/tests/unit/v1/half_precision/test_bf16.py @@ -12,6 +12,9 @@ from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader from unit.util import bf16_required_version_check from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from unit.v1.zero.test_zero_user_backward import (initialize_distributed, create_ddp_model, collect_ddp_gradients, + collect_gradients_safe, compare_gradients) class TestAdamBF16ZeroOneCycleCompatibility(DistributedTest): @@ -27,6 +30,7 @@ def test(self, zero_stage=2, use_cpu_offload=False): pytest.skip("cpu-adam is not compatible") config_dict = { + "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, "optimizer": { "type": "Adam", @@ -87,7 +91,7 @@ def test(self, zero_stage=2, use_cpu_offload=False): pytest.skip("cpu-adam is not compatible") config_dict = { - "train_batch_size": 4, + "train_micro_batch_size_per_gpu": 4, "steps_per_print": 1, "fp16": { "enabled": False, @@ -180,7 +184,7 @@ def test(self, optimizer_constructor, zero_stage=2): ) config_dict = { - "train_batch_size": 2, + "train_micro_batch_size_per_gpu": 2, "steps_per_print": 1, "fp16": { "enabled": False @@ -209,7 +213,7 @@ def test(self): ) config_dict = { - "train_batch_size": 2, + "train_micro_batch_size_per_gpu": 2, "steps_per_print": 1, "optimizer": { "type": "Adam", @@ -258,7 +262,7 @@ def test(self, stage=2): ) config_dict = { - "train_batch_size": 1, + "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, "fp16": { "enabled": False @@ -286,8 +290,8 @@ def test(self, stage=2): model.step() -@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bfp16", "fp32"]) -@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16], ids=["fp16", "bfp16"]) +@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"]) +@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bf16", "default"]) class TestZeroDtypeCocktail(DistributedTest): world_size = 2 @@ -298,10 +302,14 @@ def test(self, comp_type, comm_type): " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) - type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"} + if comp_type == torch.float16 or comm_type == torch.float16: + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + type_str = {torch.float16: "fp16", torch.bfloat16: "bf16"} config_dict = { - "train_batch_size": 2, + "train_micro_batch_size_per_gpu": 2, "steps_per_print": 1, "fp16": { "enabled": comp_type == torch.float16 @@ -312,8 +320,11 @@ def test(self, comp_type, comm_type): "zero_optimization": { "stage": 2 }, - "communication_data_type": type_str[comm_type] } + if comm_type is not None: + config_dict["communication_data_type"] = type_str[comm_type] + else: + comm_type = comp_type hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -336,3 +347,162 @@ def custom_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False) model.backward(loss) model.step() dist.reduce = orig_torch_reduce + + +@pytest.mark.parametrize("bf16_optimizer_states,use_cpu_offload,zero_stage", [ + pytest.param(False, True, 1, id="zero_stage_1_cpu_offload"), + pytest.param(True, False, 1, id="zero_stage_1_bf16_opt_states_True"), + pytest.param(True, True, 1, id="zero_stage_1_bf16_opt_states_cpu_offload"), + pytest.param(False, True, 2, id="zero_stage_2_cpu_offload"), + pytest.param(True, False, 2, id="zero_stage_2_bf16_opt_states_True"), + pytest.param(True, True, 2, id="zero_stage_2_bf16_opt_states_cpu_offload"), + pytest.param(False, True, 3, id="zero_stage_3_cpu_offload"), + pytest.param(True, False, 3, id="zero_stage_3_bf16_opt_states_True"), + pytest.param(True, True, 3, id="zero_stage_3_bf16_opt_states_cpu_offload"), +]) +class TestBF16MasterWeightsGradients(DistributedTest): + world_size = 2 + + def test_gradients_match_ddp(self, bf16_optimizer_states, use_cpu_offload, zero_stage): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + hidden_dim = 6 + lr = 1e-3 + seed = 123 + + device, rank, dtype = initialize_distributed() + + model_ddp, optimizer_ddp = create_ddp_model(SimpleModel, + device, + rank, + dtype, + seed=seed, + lr=lr, + hidden_dim=hidden_dim, + nlayers=2) + + torch.manual_seed(seed) + ds_model = SimpleModel(hidden_dim, nlayers=2) + + bf16_config = { + "enabled": True, + "bf16_master_weights_and_grads": True, + } + if bf16_optimizer_states: + bf16_config["bf16_optimizer_states"] = True + + zero_config = {"stage": zero_stage} + if use_cpu_offload: + zero_config["cpu_offload"] = True + + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": lr + } + }, + "bf16": bf16_config, + "zero_optimization": zero_config + } + + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=ds_model, + model_parameters=ds_model.parameters()) + + data_loader = random_dataloader(model=engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.bfloat16) + batch = next(iter(data_loader)) + + optimizer_ddp.zero_grad() + loss_ddp = model_ddp(batch[0], batch[1]) + loss_ddp.backward() + grads_ddp = collect_ddp_gradients(model_ddp) + + loss_ds = engine(batch[0], batch[1]) + loss_ds.backward() + grads_ds = collect_gradients_safe(engine) + + compare_gradients( + grads_ddp, + grads_ds, + step_info= + f"bf16_optimizer_states={bf16_optimizer_states}, cpu_offload={use_cpu_offload}, zero_stage={zero_stage}") + + optimizer_ddp.step() + optimizer_ddp.zero_grad() + engine.step() + engine.zero_grad() + + if bf16_optimizer_states and use_cpu_offload: + # With CPU offload the Adam moments must be allocated in bf16 on the host so the + # offloaded optimizer-state footprint is smaller than with fp32 moments. + cpu_adam_state = engine.optimizer.optimizer.state + moment_tensors = [] + for param_state in cpu_adam_state.values(): + for moment_key in ("exp_avg", "exp_avg_sq"): + if moment_key in param_state: + moment_tensors.append(param_state[moment_key]) + assert moment_tensors, "expected Adam moment tensors to be allocated after a step" + for moment in moment_tensors: + assert moment.dtype == torch.bfloat16, f"expected bf16 moment, got {moment.dtype}" + assert moment.device.type == "cpu", f"expected moment on cpu, got {moment.device}" + + engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +class TestBF16OptimizerStatesOffloadValidation(DistributedTest): + world_size = 1 + + def test_user_cpu_adam_must_enable_bf16_states(self, zero_stage): + """A user-provided DeepSpeedCPUAdam must be built with fp32_optimizer_states=False + to combine bf16_optimizer_states with ZeRO-Offload, otherwise the moments would + silently stay fp32 and the memory benefit would be lost.""" + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + hidden_dim = 6 + model = SimpleModel(hidden_dim, nlayers=2) + # fp32_optimizer_states defaults to True, which keeps fp32 moments and is + # incompatible with bf16_optimizer_states under ZeRO-Offload. + optimizer = DeepSpeedCPUAdam(model.parameters()) + + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "steps_per_print": 1, + "bf16": { + "enabled": True, + "bf16_master_weights_and_grads": True, + "bf16_optimizer_states": True, + }, + # offload_optimizer is the current config key for ZeRO optimizer offload + # (TestBF16MasterWeightsGradients above still uses the legacy cpu_offload alias). + "zero_optimization": { + "stage": zero_stage, + "offload_optimizer": { + "device": "cpu" + }, + }, + } + + with pytest.raises(AssertionError, match="fp32_optimizer_states=False"): + deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer) diff --git a/tests/unit/v1/half_precision/test_with_autocast.py b/tests/unit/v1/half_precision/test_with_autocast.py new file mode 100644 index 000000000000..a9ea15d8e683 --- /dev/null +++ b/tests/unit/v1/half_precision/test_with_autocast.py @@ -0,0 +1,269 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +import pytest +from unit.common import DistributedTest, allclose_on_all_ranks +from deepspeed.ops.op_builder import CPUAdamBuilder +from unit.simple_model import SimpleModel, random_dataloader +from unit.util import bf16_required_version_check +from deepspeed.accelerator import get_accelerator +from unit.v1.zero.test_zero_user_backward import (initialize_distributed, create_ddp_model, collect_ddp_gradients, + collect_gradients_safe, compare_gradients) + + +class TestTorchAutocastWithPrecisionModes(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("precision_mode,zero_stage", [ + pytest.param("bf16_full", 1, id="z1_bf16_full_autocast"), + pytest.param("bf16_full", 2, id="z2_bf16_full_autocast"), + pytest.param("bf16_full", 3, id="z3_bf16_full_autocast"), + ]) + def test_gradients_match_ddp_with_autocast(self, precision_mode, zero_stage): + """Test BF16 with torch_autocast by comparing gradients with DDP baseline.""" + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + hidden_dim = 6 + lr = 1e-3 + seed = 123 + + device, rank, dtype = initialize_distributed() + + # Create DDP baseline with torch.autocast + model_ddp, optimizer_ddp = create_ddp_model(SimpleModel, + device, + rank, + dtype, + seed=seed, + lr=lr, + hidden_dim=hidden_dim, + nlayers=2) + + torch.manual_seed(seed) + ds_model = SimpleModel(hidden_dim, nlayers=2) + + # BF16 configuration + autocast_dtype = torch.bfloat16 + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": lr + } + }, + "torch_autocast": { + "enabled": True, + "dtype": str(autocast_dtype) + }, + "bf16": { + "enabled": True, + "bf16_master_weights_and_grads": True, + "bf16_optimizer_states": True + }, + "zero_optimization": { + "stage": zero_stage + } + } + + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=ds_model, + model_parameters=ds_model.parameters()) + + data_loader = random_dataloader(model=engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.bfloat16) + batch = next(iter(data_loader)) + + # DDP with torch.autocast + optimizer_ddp.zero_grad() + with torch.autocast(device_type=get_accelerator().device_name(), dtype=autocast_dtype, enabled=True): + loss_ddp = model_ddp(batch[0], batch[1]) + loss_ddp.backward() + grads_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed with torch_autocast config + loss_ds = engine(batch[0], batch[1]) + engine.backward(loss_ds) + grads_ds = collect_gradients_safe(engine) + + compare_gradients(grads_ddp, grads_ds, step_info=f"precision_mode={precision_mode}, zero_stage={zero_stage}") + + # Verify parameters have correct comm_dtype attribute for autocast + from deepspeed.runtime.torch_autocast import has_comm_dtype, get_comm_dtype + for name, param in engine.module.named_parameters(): + if "weight" in name: + # Linear layer weights should have comm_dtype set + assert has_comm_dtype(param), f"Parameter {name} should have comm_dtype attribute" + assert get_comm_dtype(param) == autocast_dtype, \ + f"Parameter {name} comm_dtype should be {autocast_dtype}, got {get_comm_dtype(param)}" + + optimizer_ddp.step() + engine.step() + + optimizer_ddp.zero_grad() + engine.zero_grad() + engine.destroy() + + @pytest.mark.parametrize("precision_mode,zero_stage", [ + pytest.param("fp16_master_wg", 2, id="z2_fp16_master_wg_autocast"), + pytest.param("fp16_master_wg", 3, id="z3_fp16_master_wg_autocast"), + ]) + def test_parameters_match_ddp_after_step(self, precision_mode, zero_stage): + """Test that parameters match DDP after a training step. + Note: This test is for FP16 where gradients are scaled and hard to compare. + """ + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + # FP16 mode requires CPU offload + if precision_mode == "fp16_master_wg" and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + # FP16 mode requires FP16 support + if precision_mode == "fp16_master_wg" and not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + + hidden_dim = 6 + lr = 1e-3 + seed = 123 + + device, rank, dtype = initialize_distributed() + + # For fp16 mode with autocast, use float32 model parameters + # For bf16 mode, use bfloat16 model parameters + model_dtype = torch.float32 if precision_mode == "fp16_master_wg" else dtype + + # Create DDP baseline with torch.autocast + model_ddp, optimizer_ddp = create_ddp_model(SimpleModel, + device, + rank, + model_dtype, + seed=seed, + lr=lr, + hidden_dim=hidden_dim, + nlayers=2) + + torch.manual_seed(seed) + ds_model = SimpleModel(hidden_dim, nlayers=2) + + # Configure based on precision mode + if precision_mode == "bf16_full": + autocast_dtype = torch.bfloat16 + precision_config = { + "bf16": { + "enabled": True, + "bf16_master_weights_and_grads": True, + "bf16_optimizer_states": True + } + } + zero_config = {"stage": zero_stage} + data_dtype = torch.bfloat16 + use_grad_scaler = False + else: # fp16_master_wg + autocast_dtype = torch.float16 + precision_config = {"fp16": {"enabled": True, "fp16_master_weights_and_grads": True}} + zero_config = {"stage": zero_stage, "offload_optimizer": {"device": "cpu"}} + data_dtype = torch.float16 + use_grad_scaler = True + + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": lr + } + }, + "torch_autocast": { + "enabled": True, + "dtype": str(autocast_dtype) + }, + "zero_optimization": zero_config, + **precision_config + } + + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=ds_model, + model_parameters=ds_model.parameters()) + + data_loader = random_dataloader(model=engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=data_dtype) + batch = next(iter(data_loader)) + + # DDP with torch.autocast and optional GradScaler for fp16 + if use_grad_scaler: + scaler = torch.amp.GradScaler() + + optimizer_ddp.zero_grad() + with torch.autocast(device_type=get_accelerator().device_name(), dtype=autocast_dtype, enabled=True): + loss_ddp = model_ddp(batch[0], batch[1]) + + if use_grad_scaler: + scaler.scale(loss_ddp).backward() + scaler.step(optimizer_ddp) + scaler.update() + else: + loss_ddp.backward() + optimizer_ddp.step() + + # DeepSpeed with torch_autocast config + loss_ds = engine(batch[0], batch[1]) + engine.backward(loss_ds) + engine.step() + + # Compare parameters after the optimizer step + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + for (name_ddp, param_ddp), (name_ds, param_ds) in zip(model_ddp.named_parameters(), + engine.module.named_parameters()): + # Remove 'module.' prefix from both for comparison + name_ddp_clean = name_ddp.replace('module.', '') + name_ds_clean = name_ds.replace('module.', '') + assert name_ddp_clean == name_ds_clean, f"Parameter name mismatch: {name_ddp_clean} vs {name_ds_clean}" + + # Get full parameter for ZeRO stage 3 + if hasattr(param_ds, 'ds_status') and param_ds.ds_status == ZeroParamStatus.NOT_AVAILABLE: + with deepspeed.zero.GatheredParameters([param_ds], modifier_rank=0): + param_ds_full = param_ds.detach().clone().cpu().float() + else: + param_ds_full = param_ds.detach().clone().cpu().float() + + param_ddp_full = param_ddp.detach().clone().cpu().float() + + # Use allclose_on_all_ranks for comparison + allclose_on_all_ranks( + param_ddp_full, + param_ds_full, + rtol=1e-3, + atol=1e-3, + assert_message= + f"Parameters differ for {name_ddp_clean} at precision_mode={precision_mode}, zero_stage={zero_stage}") + + # Verify parameters have correct comm_dtype attribute for autocast + from deepspeed.runtime.torch_autocast import has_comm_dtype, get_comm_dtype + for name, param in engine.module.named_parameters(): + if "weight" in name: + # Linear layer weights should have comm_dtype set + assert has_comm_dtype(param), f"Parameter {name} should have comm_dtype attribute" + assert get_comm_dtype(param) == autocast_dtype, \ + f"Parameter {name} comm_dtype should be {autocast_dtype}, got {get_comm_dtype(param)}" + + engine.destroy() diff --git a/tests/unit/v1/multimodal/__init__.py b/tests/unit/v1/multimodal/__init__.py new file mode 100644 index 000000000000..c8d652d4dc49 --- /dev/null +++ b/tests/unit/v1/multimodal/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team diff --git a/tests/unit/v1/multimodal/test_gemma4_config.py b/tests/unit/v1/multimodal/test_gemma4_config.py new file mode 100644 index 000000000000..5d1a3ca1ad12 --- /dev/null +++ b/tests/unit/v1/multimodal/test_gemma4_config.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team + +import pytest + +transformers = pytest.importorskip("transformers") +Gemma4Config = getattr(transformers, "Gemma4Config", None) +pytestmark = pytest.mark.skipif(Gemma4Config is None, reason="Gemma4Config not available in this transformers version") + + +def test_gemma4_text_config_fallback(): + config = Gemma4Config() + assert not hasattr(config, 'num_attention_heads'), \ + "Gemma4Config top-level should not have num_attention_heads" + arch_cfg = config.get_text_config() + assert hasattr(arch_cfg, 'num_attention_heads') + assert arch_cfg.num_attention_heads > 0 + assert hasattr(arch_cfg, 'num_key_value_heads') + assert arch_cfg.num_key_value_heads > 0 + assert hasattr(arch_cfg, 'num_hidden_layers') + assert arch_cfg.num_hidden_layers > 0 + assert hasattr(arch_cfg, 'hidden_size') + assert arch_cfg.hidden_size > 0 + + +def test_gemma4_text_config_matches_text_config(): + config = Gemma4Config() + arch_cfg = config.get_text_config() + assert arch_cfg is config.text_config + assert arch_cfg.num_attention_heads == config.text_config.num_attention_heads + assert arch_cfg.num_key_value_heads == config.text_config.num_key_value_heads diff --git a/tests/unit/v1/zero/test_offload_states.py b/tests/unit/v1/zero/test_offload_states.py new file mode 100644 index 000000000000..bf51d6df2e77 --- /dev/null +++ b/tests/unit/v1/zero/test_offload_states.py @@ -0,0 +1,405 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest +from unit.simple_model import random_dataloader, SimpleModel +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum +from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state +from deepspeed.runtime.zero.offload_states import get_state_devices + +# ============================================================================== +# ZeRO-1 and ZeRO-2 TESTS +# ============================================================================== + + +def validate_hp_params_device(model, device: torch.device): + """Validates that the sharded FP32 parameters are on the specified device.""" + for p in model.optimizer.single_partition_of_fp32_groups: + assert p.device.type == device.type, f"FP32 param partition is on {p.device}, expected {device}" + + +def validate_lp_params_device(model, device: torch.device): + """Validates that the sharded LP parameters are on the specified device.""" + for p in model.parameters(): + assert p.device.type == device.type, f"LP param partition is on {p.device}, expected {device}" + + +def validate_adam_states_device(model, device: torch.device): + """Validates that the sharded Adam optimizer states are on the specified device.""" + for p in model.optimizer.single_partition_of_fp32_groups: + if p in model.optimizer.state: + for state_key in ['exp_avg', 'exp_avg_sq']: + if state_key in model.optimizer.state[p]: + state_tensor = model.optimizer.state[p][state_key] + assert state_tensor.device.type == device.type, f"Optimizer state '{state_key}' is on {state_tensor.device}, expected {device}" + + +def validate_grad_device(model, device: torch.device) -> None: + """Validates that the sharded gradients are on the specified device.""" + # This path is for before step() where gradients are in averaged_gradients + if model.optimizer.averaged_gradients: + for grad_list in model.optimizer.averaged_gradients.values(): + if grad_list is not None: + for grad_tensor in grad_list: + assert grad_tensor.device.type == device.type, f"Gradient partition in averaged_gradients is on {grad_tensor.device}, expected {device}" + else: + # This path is for after step() or if grads are not in averaged_gradients + for p in model.optimizer.single_partition_of_fp32_groups: + if p.grad is not None: + assert p.grad.device.type == device.type, f"Gradient partition on hp_param.grad is on {p.grad.device}, expected {device}" + + +def is_offload_optimizer_enabled(config_dict): + return config_dict.get("zero_optimization", {}).get("offload_optimizer", {}).get("device", None) is not None + + +def is_only_offload_optimizer_states(offloaded_states, optimizer_offload_states): + if offloaded_states is None: + return False + offload_set = set(offloaded_states) + optim_states_set = set(optimizer_offload_states) + return offload_set - optim_states_set == set() + + +def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloaded_states, pin_memory, non_blocking): + """ + This function runs a training step, offloads states, reloads them, and verifies correctness for ZeRO-1/2. + The logic is carefully structured to handle transient gradient states vs. persistent parameter/optimizer states. + """ + offload_device = OffloadDeviceEnum.cpu + offload_torch_device = torch.device(offload_device.value) + accelerator_device = torch.device(get_accelerator().current_device_name()) + optimizer_device = offload_torch_device if is_offload_optimizer_enabled(config_dict) else accelerator_device + offload_only_optimizer_states = is_only_offload_optimizer_states( + offloaded_states, [OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.hp_params]) + expect_memory_change = not (is_offload_optimizer_enabled(config_dict) and offload_only_optimizer_states) + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=param_groups, config=config_dict) + + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + # We only need one step to verify the logic + batch = next(iter(data_loader)) + + loss = model(batch[0], batch[1]) + model.backward(loss) + + # Determine if we are testing a transient state (gradients) or a persistent state + # REVERTED: Condition now only checks for lp_grads as it's the relevant transient state. + is_grad_test = offloaded_states is not None and OffloadStateTypeEnum.lp_grads in offloaded_states + + if is_grad_test: + # --- TEST PATH FOR TRANSIENT GRADIENT STATE --- + # Gradients exist only between backward() and step(). We must test them here. + grads_expected = [[g.clone().detach() for g in grad_list] + for grad_list in model.optimizer.averaged_gradients.values() if grad_list is not None] + grad_numel = sum(sum(g.numel() for g in grad_list) for grad_list in grads_expected) + + alloc_before_offload = get_accelerator().memory_allocated() + model.offload_states(include=offloaded_states, + device=offload_device, + pin_memory=pin_memory, + non_blocking=non_blocking) + alloc_after_offload = get_accelerator().memory_allocated() + + if grad_numel > 0: + assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory for grads should decrease after offload {alloc_after_offload=} < {alloc_before_offload=}" + validate_grad_device(model, offload_torch_device) + + model.reload_states() + alloc_after_reload = get_accelerator().memory_allocated() + + if grad_numel > 0: + assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory for grads should increase after reload {alloc_after_reload=} > {alloc_after_offload=}" + validate_grad_device(model, accelerator_device) + + reloaded_grads = [ + grad_list for grad_list in model.optimizer.averaged_gradients.values() if grad_list is not None + ] + assert len(grads_expected) == len(reloaded_grads), "FAIL: Number of gradient groups changed after reload" + for expected_list, reloaded_list in zip(grads_expected, reloaded_grads): + for expected_g, reloaded_g in zip(expected_list, reloaded_list): + assert torch.equal(expected_g, reloaded_g), "FAIL: Reloaded gradient data does not match original" + + model.step() + + if not is_grad_test: + # --- TEST PATH FOR PERSISTENT STATES (Params, Optimizer States) --- + # These states exist after step(), so we can test them here. + + # --- Save state snapshots before offloading for data integrity check --- + lp_params_expected = [p.clone().detach() for p in model.parameters()] + hp_params_expected = [p.clone().detach() for p in model.optimizer.single_partition_of_fp32_groups] + + adam_params_in_state_before = [ + p for p in model.optimizer.single_partition_of_fp32_groups if p in model.optimizer.state + ] + adam_exp_avg_expected = [ + model.optimizer.state[p]['exp_avg'].clone().detach() for p in adam_params_in_state_before + ] + adam_exp_avg_sq_expected = [ + model.optimizer.state[p]['exp_avg_sq'].clone().detach() for p in adam_params_in_state_before + ] + + alloc_before_offload = get_accelerator().memory_allocated() + model.offload_states(include=offloaded_states, + device=offload_device, + pin_memory=pin_memory, + non_blocking=non_blocking) + alloc_after_offload = get_accelerator().memory_allocated() + + if expect_memory_change: + assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory for persistent state {offloaded_states} should decrease after offload" + + if offloaded_states is None or OffloadStateTypeEnum.lp_params in offloaded_states: + validate_lp_params_device(model, offload_torch_device) + if offloaded_states is None or OffloadStateTypeEnum.hp_params in offloaded_states: + validate_hp_params_device(model, offload_torch_device) + if offloaded_states is None or OffloadStateTypeEnum.optim_states in offloaded_states: + validate_adam_states_device(model, offload_torch_device) + + model.reload_states() + alloc_after_reload = get_accelerator().memory_allocated() + + if expect_memory_change: + assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory for persistent state {offloaded_states} should increase after reload" + + # --- Verify restored data integrity --- + for expected, restored in zip(lp_params_expected, model.parameters()): + assert torch.equal(expected, restored), "FAIL: Reloaded LP param data does not match original" + + for expected, restored in zip(hp_params_expected, model.optimizer.single_partition_of_fp32_groups): + assert torch.equal(expected, restored), "FAIL: Reloaded HP param data does not match original" + + adam_params_in_state_after = [ + p for p in model.optimizer.single_partition_of_fp32_groups if p in model.optimizer.state + ] + assert len(adam_params_in_state_before) == len( + adam_params_in_state_after), "FAIL: Number of params in optimizer state changed after reload" + + for expected, p in zip(adam_exp_avg_expected, adam_params_in_state_after): + assert torch.equal( + expected, model.optimizer.state[p]['exp_avg']), "FAIL: Reloaded 'exp_avg' data does not match original" + for expected, p in zip(adam_exp_avg_sq_expected, adam_params_in_state_after): + assert torch.equal( + expected, + model.optimizer.state[p]['exp_avg_sq']), "FAIL: Reloaded 'exp_avg_sq' data does not match original" + + # --- FINAL VALIDATION FOR ALL TESTS --- + validate_lp_params_device(model, accelerator_device) + validate_hp_params_device(model, optimizer_device) + validate_adam_states_device(model, optimizer_device) + + assert torch.any(torch.ne(list(model.parameters())[0], 0.0)) + + +@pytest.mark.parametrize("included_state", [ + OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.hp_params, + OffloadStateTypeEnum.lp_params, None +]) +@pytest.mark.parametrize("pin_memory", [False, True]) +@pytest.mark.parametrize("non_blocking", [False, True]) +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("static_offload_optimizer", [False, True]) +class TestDynamicOffloadStatesZero12(DistributedTest): + world_size = 2 + + def test_dynamic_offload_states_zero12(self, included_state, pin_memory, non_blocking, zero_stage, + static_offload_optimizer): + hidden_dim = 1024 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": zero_stage + }, + "bf16": { + "enabled": True + } + } + if static_offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + model = SimpleModel(hidden_dim, nlayers=4) + param_groups = [{ + "params": [p for n, p in model.named_parameters() if 'bias' not in n], + "weight_decay": 0.1 + }, { + "params": [p for n, p in model.named_parameters() if 'bias' in n], + "weight_decay": 0.0 + }] + offloaded_states = None if included_state is None else [included_state] + run_model_zero12(model, param_groups, config_dict, hidden_dim, torch.bfloat16, offloaded_states, pin_memory, + non_blocking) + + +# ============================================================================== +# ZeRO-3 TESTS +# ============================================================================== + + +def validate_device(model, state_device: dict[OffloadStateTypeEnum, torch.device], offloaded_states) -> None: + + def compare_device(state) -> bool: + devices = get_state_devices(model, state) + return len(devices) == 1 and state_device[state] in devices + + for state in OffloadStateTypeEnum: + if offloaded_states is None or state in offloaded_states: + if state == OffloadStateTypeEnum.contiguous_grad_buffer and state_device[state] == torch.device("cpu"): + assert len(get_state_devices(model, + state)) == 0, f"State {state} must be removed after offload_states()" + else: + assert compare_device(state), f"State {state} is not on device {state_device[state]}" + + +def run_model_zero3(model, param_groups, config_dict, hidden_dim, dtype, offloaded_states, pin_memory, non_blocking): + # Currently we only support OffloadDeviceEnum.cpu + offload_device = OffloadDeviceEnum.cpu + offload_torch_device = torch.device(offload_device.value) + accelerator_device = torch.device(get_accelerator().current_device_name()) + optimizer_device = offload_torch_device if is_offload_optimizer_enabled(config_dict) else accelerator_device + offload_only_optimizer_states = is_only_offload_optimizer_states( + offloaded_states, + [OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_grads]) + expect_memory_change = not (is_offload_optimizer_enabled(config_dict) and offload_only_optimizer_states) + + offload_state_device: dict[OffloadStateTypeEnum, torch.device] = { + OffloadStateTypeEnum.hp_params: offload_torch_device, + OffloadStateTypeEnum.lp_params: offload_torch_device, + OffloadStateTypeEnum.optim_states: offload_torch_device, + OffloadStateTypeEnum.lp_grads: offload_torch_device, + OffloadStateTypeEnum.contiguous_grad_buffer: offload_torch_device, + } + + reload_state_device: dict[OffloadStateTypeEnum, torch.device] = { + OffloadStateTypeEnum.hp_params: optimizer_device, + OffloadStateTypeEnum.lp_params: accelerator_device, + OffloadStateTypeEnum.optim_states: optimizer_device, + OffloadStateTypeEnum.lp_grads: optimizer_device, + OffloadStateTypeEnum.contiguous_grad_buffer: accelerator_device, + } + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=param_groups, config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=10, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + hp_params_expected = [safe_get_local_fp32_param(p).clone() for p in model.parameters()] + lp_params_expected = [p.ds_tensor.clone() for p in model.parameters()] + lp_grads_expected = model.optimizer.grad_partitions_flat_buffer.clone() + adam_exp_avg_expected = [safe_get_local_optimizer_state(p, "exp_avg").clone() for p in model.parameters()] + adam_exp_avg_sq = [safe_get_local_optimizer_state(p, "exp_avg_sq").clone() for p in model.parameters()] + + # Start offloading + alloc_before_offload = get_accelerator().memory_allocated() + model.offload_states(include=offloaded_states, + device=offload_device, + pin_memory=pin_memory, + non_blocking=non_blocking) + alloc_after_offload = get_accelerator().memory_allocated() + + if expect_memory_change: + assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory should decrease after offload {alloc_after_offload=} < {alloc_before_offload=}" + validate_device(model, offload_state_device, offloaded_states) + + # Reload states + model.reload_states() + alloc_after_reload = get_accelerator().memory_allocated() + + if expect_memory_change: + assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory should increase after offload back {alloc_after_reload=} > {alloc_after_offload=}" + + # Verify restored states + hp_param_restored = [safe_get_local_fp32_param(p) for p in model.parameters()] + for hp_param_expected, hp_param_restored in zip(hp_params_expected, hp_param_restored): + assert torch.equal(hp_param_expected, hp_param_restored) + + lp_param_restored = [p.ds_tensor for p in model.parameters()] + + for lp_param_expected, lp_param_restored in zip(lp_params_expected, lp_param_restored): + assert torch.equal(lp_param_expected, lp_param_restored) + + assert torch.equal(lp_grads_expected, model.optimizer.grad_partitions_flat_buffer) + + adam_exp_avg_restored = [safe_get_local_optimizer_state(p, "exp_avg") for p in model.parameters()] + for adam_exp_avg_expected, adam_exp_avg_restored in zip(adam_exp_avg_expected, adam_exp_avg_restored): + assert torch.equal(adam_exp_avg_expected, adam_exp_avg_restored) + + adam_exp_avg_sq_restored = [safe_get_local_optimizer_state(p, "exp_avg_sq") for p in model.parameters()] + for adam_exp_avg_sq_expected, adam_exp_avg_sq_restored in zip(adam_exp_avg_sq, adam_exp_avg_sq_restored): + assert torch.equal(adam_exp_avg_sq_expected, adam_exp_avg_sq_restored) + + validate_device(model, reload_state_device, offloaded_states) + + # Needed in ZeRO 3. Not doing so can give memory leak + model.destroy() + + +@pytest.mark.parametrize("included_state", [ + OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_params, OffloadStateTypeEnum.optim_states, + OffloadStateTypeEnum.lp_grads, OffloadStateTypeEnum.contiguous_grad_buffer, None +]) +@pytest.mark.parametrize("pin_memory", [False, True]) +@pytest.mark.parametrize("non_blocking", [False, True]) +@pytest.mark.parametrize("static_offload_optimizer", [False, True]) +class TestDynamicOffloadStatesZero3(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + + def test_dynamic_offload_states_zero3(self, included_state, pin_memory, non_blocking, static_offload_optimizer): + hidden_dim = 1024 + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + } + } + config_dict["bf16"] = {"enabled": True} + if static_offload_optimizer: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=4) + + param_groups = [{ + "params": [p for n, p in model.named_parameters() if not 'bias' in n], + "weight_decay": 0.1 + }, { + "params": [p for n, p in model.named_parameters() if 'bias' in n], + "weight_decay": 0.0 + }] + offloaded_states = None if included_state is None else [included_state] + run_model_zero3(model, param_groups, config_dict, hidden_dim, torch.bfloat16, offloaded_states, pin_memory, + non_blocking) diff --git a/tests/unit/v1/zero/test_overlap_comm_record_stream.py b/tests/unit/v1/zero/test_overlap_comm_record_stream.py new file mode 100644 index 000000000000..431461703063 --- /dev/null +++ b/tests/unit/v1/zero/test_overlap_comm_record_stream.py @@ -0,0 +1,97 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from contextlib import nullcontext + +import torch + +import deepspeed.runtime.zero.stage_1_and_2 as zero_stage12 +from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + + +class _FakeTensor: + + def __init__(self): + self.recorded_streams = [] + self.copied_from = None + + def copy_(self, other): + self.copied_from = other + return self + + def record_stream(self, stream): + self.recorded_streams.append(stream) + + +class _FakeAccelerator: + + def __init__(self, resolves_data_dependency, current_device_name="cpu"): + self._resolves_data_dependency = resolves_data_dependency + self._current_device_name = current_device_name + + def resolves_data_dependency(self): + return self._resolves_data_dependency + + def stream(self, stream): + return nullcontext() + + def current_stream(self): + return object() + + def current_device_name(self): + return self._current_device_name + + def synchronize(self): + return None + + +def _build_overlap_optimizer(monkeypatch, *, resolves_data_dependency): + optimizer = DeepSpeedZeroOptimizer.__new__(DeepSpeedZeroOptimizer) + optimizer.overlap_comm = True + optimizer.reduction_stream = object() + optimizer.dp_process_group = object() + optimizer.previous_reduced_grads = {} + + allreduced = _FakeTensor() + synced = [_FakeTensor(), _FakeTensor()] + + optimizer.allreduce_bucket = lambda *args, **kwargs: allreduced + optimizer.unflatten = lambda allreduced_tensor, small_bucket: synced + + monkeypatch.setattr( + zero_stage12, + "get_accelerator", + lambda: _FakeAccelerator(resolves_data_dependency), + ) + monkeypatch.setattr(zero_stage12.dist, "get_rank", lambda group=None: 0) + return optimizer, allreduced, synced + + +def test_allreduce_and_copy_records_stream_for_overlap_comm(monkeypatch): + optimizer, allreduced, synced = _build_overlap_optimizer(monkeypatch, resolves_data_dependency=False) + bucket = [_FakeTensor(), _FakeTensor()] + + optimizer.allreduce_and_copy(bucket, torch.float16) + + assert allreduced.recorded_streams == [optimizer.reduction_stream] + for buf, expected_synced in zip(bucket, synced): + assert buf.copied_from is expected_synced + assert buf.recorded_streams == [optimizer.reduction_stream] + + +def test_allreduce_and_copy_with_multiple_ranks_records_only_local_buffers(monkeypatch): + optimizer, allreduced, synced = _build_overlap_optimizer(monkeypatch, resolves_data_dependency=False) + bucket = [_FakeTensor(), _FakeTensor()] + + optimizer.allreduce_and_copy_with_multiple_ranks( + bucket, + torch.float16, + bucket_ranks=[0, 1], + ) + + assert allreduced.recorded_streams == [optimizer.reduction_stream] + assert bucket[0].copied_from is synced[0] + assert bucket[0].recorded_streams == [optimizer.reduction_stream] + assert bucket[1].copied_from is None + assert bucket[1].recorded_streams == [] diff --git a/tests/unit/v1/zero/test_stage2_flatten_on_gpu.py b/tests/unit/v1/zero/test_stage2_flatten_on_gpu.py new file mode 100644 index 000000000000..483da0e97279 --- /dev/null +++ b/tests/unit/v1/zero/test_stage2_flatten_on_gpu.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Test that ZeRO Stage 1 and 2 use the GPU flatten path when VRAM is sufficient. +Parametrized over zero_stage (1, 2) and dtype (fp32, fp16, bf16). +""" + +import pytest +import torch +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import set_log_level_from_string +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + +_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +def _apply_dtype_to_config(config_dict, dtype): + """Set bf16/fp16 in config_dict based on dtype; skip if not supported.""" + if dtype == "bf16": + if not get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported on this accelerator") + config_dict["bf16"] = {"enabled": True} + elif dtype == "fp16": + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported on this accelerator") + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + # fp32: no half-precision block + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("dtype", ["fp32", "fp16", "bf16"], ids=["fp32", "fp16", "bf16"]) +class TestStage2FlattenOnGPU(DistributedTest): + """ZeRO-1 and ZeRO-2 with small model should flatten on GPU (sufficient VRAM).""" + + world_size = 2 # Run on 2 GPUs when available + + def test_flatten_on_gpu_path_taken(self, monkeypatch, zero_stage, dtype): + """Assert the GPU flatten path was used (not CPU flatten + move).""" + if not get_accelerator().is_available(): + pytest.skip("Accelerator not available") + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + _apply_dtype_to_config(config_dict, dtype) + + set_log_level_from_string("info") + log_messages = [] + + def mock_logger_info(msg, *args, **kwargs): + log_messages.append(msg if isinstance(msg, str) else str(msg)) + + monkeypatch.setattr("deepspeed.utils.logger.info", mock_logger_info) + + hidden_dim = 64 + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + ) + + # Small model + no CPU offload => accelerator path logs "Flattening param group ... (sufficient memory)" + accel_path_logs = [m for m in log_messages if "Flattening param group" in m and "(sufficient memory)" in m] + assert accel_path_logs, ( + f"Expected accelerator flatten path (log should contain 'Flattening param group' and '(sufficient memory)'). " + f"Captured messages: {log_messages}") + + def test_flat_buffers_on_accelerator(self, zero_stage, dtype): + """Regression: flat buffers must end up on the accelerator (not left on CPU).""" + if not get_accelerator().is_available(): + pytest.skip("Accelerator not available") + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + _apply_dtype_to_config(config_dict, dtype) + + hidden_dim = 64 + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + engine, _, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + ) + opt = engine.optimizer + assert hasattr(opt, "bit16_groups_flat"), "ZeRO-1/2 optimizer should have bit16_groups_flat" + device_type = get_accelerator().device_name() + for i, flat in enumerate(opt.bit16_groups_flat): + assert flat.device.type == device_type, (f"Flat buffer {i} must be on {device_type}, got {flat.device}") + + @pytest.mark.world_size(1) + def test_flatten_on_accelerator_training_step(self, zero_stage, dtype): + """Regression: flat buffer must be detached so inplace ops during step don't crash.""" + if not get_accelerator().is_available(): + pytest.skip("Accelerator not available") + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + _apply_dtype_to_config(config_dict, dtype) + + hidden_dim = 64 + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + engine, _, _, _ = deepspeed.initialize( + config=config_dict, + model=model, + model_parameters=model.parameters(), + ) + for flat in engine.optimizer.bit16_groups_flat: + assert flat.grad_fn is None, ("Flat buffer must be detached from autograd graph" + " to prevent inplace-modification errors during optimizer step") + + data_loader = random_dataloader(model=engine, + total_samples=8, + hidden_dim=hidden_dim, + device=engine.device, + dtype=_DTYPE_MAP[dtype]) + for batch in data_loader: + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/v1/zero/test_zero.py similarity index 59% rename from tests/unit/runtime/zero/test_zero.py rename to tests/unit/v1/zero/test_zero.py index 6e6aa5bc60fd..fde4a42bfb15 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/v1/zero/test_zero.py @@ -14,12 +14,14 @@ from torch.nn.modules.container import ModuleList from torch.nn.modules.loss import L1Loss from torch.nn.parameter import Parameter +from torch.nn.utils import skip_init -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import SimpleModel, random_dataloader import deepspeed from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint from deepspeed.runtime.zero.utils import ZeRORuntimeException @@ -52,7 +54,69 @@ def dump_state_dict(model): print(f"{name} {param.data}") -@pytest.mark.parametrize('zero_stage', [1, 2, 3]) +class TestBF16OptimizerGradReduction(DistributedTest): + world_size = 2 + + def test_boundary_microbatch_grad_is_reduced(self): + if not get_accelerator().is_bf16_supported(): + pytest.skip("bfloat16 is not supported on this accelerator") + + class ScaleModel(Module): + + def __init__(self): + super().__init__() + self.weight = Parameter(torch.ones(4)) + + def forward(self, x): + return (self.weight * x).sum() + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 2, + "zero_optimization": { + "stage": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "bf16": { + "enabled": True, + "immediate_grad_update": False, + }, + "data_types": { + "grad_accum_dtype": "fp32" + } + } + + model = ScaleModel() + engine, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + assert isinstance(engine.optimizer, BF16_Optimizer) + + rank = dist.get_rank() + rank_offset = 18 * rank + inputs = [ + torch.tensor([1, 3, 5, 7], dtype=torch.bfloat16, device=engine.device) + rank_offset, + torch.tensor([11, 13, 15, 17], dtype=torch.bfloat16, device=engine.device) + rank_offset, + ] + for i, x in enumerate(inputs): + engine.set_gradient_accumulation_boundary(i == len(inputs) - 1) + engine.backward(engine(x)) + + grad = engine.optimizer.fp32_groups_gradients_flat[0].detach().clone() + expected = torch.tensor([15, 17, 19, 21], dtype=grad.dtype, device=grad.device) + torch.testing.assert_close(grad, expected) + + gathered_grads = [torch.zeros_like(grad) for _ in range(dist.get_world_size())] + dist.all_gather(gathered_grads, grad) + torch.testing.assert_close(gathered_grads[0], gathered_grads[1]) + + engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) class TestZeroUnbalancedGradients(DistributedTest): world_size = 1 @@ -70,11 +134,11 @@ def test(self, zero_stage): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim) @@ -82,21 +146,30 @@ def test(self, zero_stage): data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) run_unbalanced_gradients(model, data_loader) + model.destroy() -# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 +# testing the fix https://github.com/deepspeedai/DeepSpeed/pull/1227 +@pytest.mark.parametrize("mics_enabled", [True, False]) class TestZero3RepeatForwardLoop(DistributedTest): world_size = 1 - def test(self, zero_stage=3): + def test(self, mics_enabled, zero_stage=3): + if mics_enabled and get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") # force all params to be partitioned by forcing threshold=0 + mics_shard_size = -1 + if mics_enabled: + mics_shard_size = self.world_size + config_dict = { "train_micro_batch_size_per_gpu": 2, "gradient_accumulation_steps": 2, "steps_per_print": 1, "zero_optimization": { "stage": zero_stage, - "stage3_param_persistence_threshold": 0 + "stage3_param_persistence_threshold": 0, + "mics_shard_size": mics_shard_size, }, "optimizer": { "type": "Adam", @@ -104,11 +177,11 @@ def test(self, zero_stage=3): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 4 class AlbertLikeModel(torch.nn.Module): @@ -134,14 +207,17 @@ def forward(self, x, y): model.backward(loss) model.step() + model.destroy() + -# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 -# also reproduces the https://github.com/microsoft/DeepSpeed/pull/1372 -@pytest.mark.parametrize('zero_stage', [2, 3]) +# testing the fix https://github.com/deepspeedai/DeepSpeed/pull/1227 +# also reproduces the https://github.com/deepspeedai/DeepSpeed/pull/1372 +@pytest.mark.parametrize("zero_stage", [2, 3]) +@pytest.mark.parametrize("freeze_params", [True, False]) class TestZeroToFP32(DistributedTest): world_size = 2 - def test_1_param_group(self, tmpdir, zero_stage): + def test_1_param_group(self, tmpdir, zero_stage, freeze_params): # XXX: ideally refactor with the 2_param_group test as 75% is the same # force all params to be partitioned by forcing threshold=0 config_dict = { @@ -150,7 +226,7 @@ def test_1_param_group(self, tmpdir, zero_stage): "steps_per_print": 1, "zero_optimization": { "stage": zero_stage, - "stage3_param_persistence_threshold": 0 + "stage3_param_persistence_threshold": 0, }, "optimizer": { "type": "Adam", @@ -158,17 +234,17 @@ def test_1_param_group(self, tmpdir, zero_stage): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} class MyModel(torch.nn.Module): - def __init__(self, hidden_dim, n_layers): + def __init__(self, hidden_dim, n_layers, freeze_params): super().__init__() - # to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that + # to reproduce https://github.com/deepspeedai/DeepSpeed/pull/1372 it is important that # the number of total elements is uneven: # (1) 4 layers of 3*(3+1)=12 elements each, 48 in total self.ll = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers)) @@ -176,6 +252,9 @@ def __init__(self, hidden_dim, n_layers): self.classifier = torch.nn.Linear(4, 1) # total 48+5=53 (uneven as desired) elements self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + if freeze_params: + self.ll[0].weight.requires_grad = False + self.ll[0].bias.requires_grad = False def forward(self, x, y): hidden = x @@ -188,9 +267,12 @@ def forward(self, x, y): world_size = dist.get_world_size() # we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2 n_layers = world_size * 2 - model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) + model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers, freeze_params=freeze_params) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + # Flush zero stage 3 cache + model.empty_partition_cache() + data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) for i, batch in enumerate(data_loader): @@ -198,6 +280,7 @@ def forward(self, x, y): model.backward(loss) model.step() + model.empty_partition_cache() model.save_checkpoint(tmpdir) # make sure all sides saved it @@ -219,14 +302,16 @@ def forward(self, x, y): fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) fp32_state_dict = fp32_model.state_dict() - #dump_state_dict(fp32_model) + # dump_state_dict(fp32_model) if dist.get_rank() == 0: for name in orig_state_dict.keys(): # float() workaround for torch<1.6 assert torch.allclose(orig_state_dict[name].float(), fp32_state_dict[name].float()) - def test_2_param_groups(self, tmpdir, zero_stage): + model.destroy() + + def test_2_param_groups(self, tmpdir, zero_stage, freeze_params): # TODO: # - need to test with multiple param groups # force all params to be partitioned by forcing threshold=0 @@ -237,7 +322,7 @@ def test_2_param_groups(self, tmpdir, zero_stage): "zero_allow_untested_optimizer": 1, "zero_optimization": { "stage": zero_stage, - "stage3_param_persistence_threshold": 0 + "stage3_param_persistence_threshold": 0, }, "optimizer": { "type": "Adam", @@ -245,18 +330,21 @@ def test_2_param_groups(self, tmpdir, zero_stage): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} class MyModel(torch.nn.Module): - def __init__(self, hidden_dim, n_layers): + def __init__(self, hidden_dim, n_layers, freeze_params): super().__init__() self.ll = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers)) self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + if freeze_params: + self.ll[0].weight.requires_grad = False + self.ll[0].bias.requires_grad = False def forward(self, x, y): hidden = x @@ -268,7 +356,7 @@ def forward(self, x, y): world_size = dist.get_world_size() n_layers = world_size * 2 - model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) + model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers, freeze_params=freeze_params) optim_groups = [ { @@ -282,10 +370,14 @@ def forward(self, x, y): ] optim = torch.optim.SGD(optim_groups, lr=0.1) - model, _, _, _ = deepspeed.initialize(model=model, - model_parameters=model.parameters(), - optimizer=optim, - config=config_dict) + model, _, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + optimizer=optim, + config=config_dict, + ) + model.empty_partition_cache() + data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) for i, batch in enumerate(data_loader): @@ -293,12 +385,13 @@ def forward(self, x, y): model.backward(loss) model.step() + model.empty_partition_cache() model.save_checkpoint(tmpdir) # make sure all sides saved it dist.barrier() - #dump_state_dict(model) + # dump_state_dict(model) orig_state_dict = {} for name, param in model.module.named_parameters(): @@ -316,13 +409,15 @@ def forward(self, x, y): fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) fp32_state_dict = fp32_model.state_dict() - #dump_state_dict(fp32_model) + # dump_state_dict(fp32_model) if dist.get_rank() == 0: for name in orig_state_dict.keys(): # float() workaround for torch<1.6 assert torch.allclose(orig_state_dict[name].float(), fp32_state_dict[name].float()) + model.destroy() + @pytest.mark.parametrize("allgather_bucket_size", [1000, 1001]) class TestIncorectAllgatherBucketSize(DistributedTest): @@ -335,7 +430,7 @@ def test(self, allgather_bucket_size, zero_stage=2): "steps_per_print": 1, "zero_optimization": { "stage": zero_stage, - "allgather_bucket_size": allgather_bucket_size + "allgather_bucket_size": allgather_bucket_size, }, "optimizer": { "type": "Adam", @@ -343,11 +438,11 @@ def test(self, allgather_bucket_size, zero_stage=2): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim) @@ -358,11 +453,11 @@ def test(self, allgather_bucket_size, zero_stage=2): model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - assert "allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str(assertinfo) + assert ("allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str(assertinfo)) class TestPartitionNcclAlignment(DistributedTest): - world_size = 4 + world_size = 2 def test(self, zero_stage=2): config_dict = { @@ -378,11 +473,11 @@ def test(self, zero_stage=2): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim) @@ -391,7 +486,8 @@ def test(self, zero_stage=2): # get nccl all-gather send buffers alignment factor nccl_start_alignment_factor = model.optimizer.nccl_start_alignment_factor - parallel_partitioned_bit16_groups = model.optimizer.parallel_partitioned_bit16_groups if zero_stage == 2 else model.optimizer.parallel_partitioned_fp16_groups + parallel_partitioned_bit16_groups = (model.optimizer.parallel_partitioned_bit16_groups + if zero_stage == 2 else model.optimizer.parallel_partitioned_fp16_groups) for data_parallel_partitions in parallel_partitioned_bit16_groups: for partition_id, partitioned_data in enumerate(data_parallel_partitions): # verify that data partition start locations are 4-byte aligned @@ -444,9 +540,14 @@ def __init__( self.loss = L1Loss(reduction="none") def forward(self, x: Tensor, y: Tensor, use_module_trace: bool, param_prefetching: bool) -> Dict[str, Tensor]: - _assert_partition_status(self, - {ZeroParamStatus.NOT_AVAILABLE, ZeroParamStatus.INFLIGHT, ZeroParamStatus.AVAILABLE} - if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE}) + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE, + } if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE}, + ) pre_layer_expected_states = { ZeroParamStatus.INFLIGHT if param_prefetching else ZeroParamStatus.NOT_AVAILABLE, @@ -471,9 +572,14 @@ def forward(self, x: Tensor, y: Tensor, use_module_trace: bool, param_prefetchin loss = self.loss(y_hat, y) - _assert_partition_status(self, - {ZeroParamStatus.NOT_AVAILABLE, ZeroParamStatus.INFLIGHT, ZeroParamStatus.AVAILABLE} - if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE}) + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE, + } if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE}, + ) return { "hidden1": hidden1, @@ -498,10 +604,12 @@ class EltwiseMultiplicationTestNetwork_NamedTuple(EltwiseMultiplicationTestNetwo def forward(self, *args, **kwargs) -> EltwiseMultiplicationNamedTuple: outputs_dicts = super().forward(*args, **kwargs) - return EltwiseMultiplicationNamedTuple(hidden1=outputs_dicts['hidden1'], - hidden2=outputs_dicts['hidden2'], - y_hat=outputs_dicts['y_hat'], - loss=outputs_dicts['loss']) + return EltwiseMultiplicationNamedTuple( + hidden1=outputs_dicts["hidden1"], + hidden2=outputs_dicts["hidden2"], + y_hat=outputs_dicts["y_hat"], + loss=outputs_dicts["loss"], + ) @staticmethod def to_dict(outputs: EltwiseMultiplicationNamedTuple) -> Dict[str, Tensor]: @@ -513,18 +621,20 @@ def to_dict(outputs: EltwiseMultiplicationNamedTuple) -> Dict[str, Tensor]: } -EltwiseMultiplication_namedtuple = namedtuple('EltwiseMultiplication_namedtuple', - ['hidden1', 'hidden2', 'y_hat', 'loss']) +EltwiseMultiplication_namedtuple = namedtuple("EltwiseMultiplication_namedtuple", + ["hidden1", "hidden2", "y_hat", "loss"]) class EltwiseMultiplicationTestNetwork_namedtuple(EltwiseMultiplicationTestNetwork_Dict): def forward(self, *args, **kwargs) -> EltwiseMultiplication_namedtuple: outputs_dicts = super().forward(*args, **kwargs) - return EltwiseMultiplication_namedtuple(hidden1=outputs_dicts['hidden1'], - hidden2=outputs_dicts['hidden2'], - y_hat=outputs_dicts['y_hat'], - loss=outputs_dicts['loss']) + return EltwiseMultiplication_namedtuple( + hidden1=outputs_dicts["hidden1"], + hidden2=outputs_dicts["hidden2"], + y_hat=outputs_dicts["y_hat"], + loss=outputs_dicts["loss"], + ) @staticmethod def to_dict(outputs: EltwiseMultiplicationNamedTuple) -> Dict[str, Tensor]: @@ -540,7 +650,12 @@ class EltwiseMultiplicationTestNetwork_Tuple(EltwiseMultiplicationTestNetwork_Di def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]: outputs_dicts = super().forward(*args, **kwargs) - return (outputs_dicts['hidden1'], outputs_dicts['hidden2'], outputs_dicts['y_hat'], outputs_dicts['loss']) + return ( + outputs_dicts["hidden1"], + outputs_dicts["hidden2"], + outputs_dicts["y_hat"], + outputs_dicts["loss"], + ) @staticmethod def to_dict(outputs: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Dict[str, Tensor]: @@ -556,7 +671,12 @@ class EltwiseMultiplicationTestNetwork_List(EltwiseMultiplicationTestNetwork_Dic def forward(self, *args, **kwargs) -> List[Tensor]: outputs_dicts = super().forward(*args, **kwargs) - return [outputs_dicts['hidden1'], outputs_dicts['hidden2'], outputs_dicts['y_hat'], outputs_dicts['loss']] + return [ + outputs_dicts["hidden1"], + outputs_dicts["hidden2"], + outputs_dicts["y_hat"], + outputs_dicts["loss"], + ] @staticmethod def to_dict(outputs: List[Tensor]) -> Dict[str, Tensor]: @@ -568,29 +688,57 @@ def to_dict(outputs: List[Tensor]) -> Dict[str, Tensor]: } -@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) -@pytest.mark.parametrize("fp16_enabled", [True, False]) -@pytest.mark.parametrize("contiguous_gradients", [True, False]) -@pytest.mark.parametrize("offload_optimizer", [True, False]) -@pytest.mark.parametrize("zero_grad", [True, False]) -@pytest.mark.parametrize("prefetching", [True, False]) -@pytest.mark.parametrize("model_class", [ - EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple, - EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple, - EltwiseMultiplicationTestNetwork_List -]) class TestZero3ParamPartitioningBase(DistributedTest): world_size = 2 - def test( + @pytest.mark.parametrize("param_persistence_threshold", [0, 10]) + def test_param_persistence_threshold(self, param_persistence_threshold): + self._test(param_persistence_threshold=param_persistence_threshold) + + @pytest.mark.parametrize("fp16_enabled", [True, False]) + def test_fp16_enabled(self, fp16_enabled): + if not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + self._test(fp16_enabled=fp16_enabled) + + @pytest.mark.parametrize("contiguous_gradients", [True, False]) + def test_contiguous_gradients(self, contiguous_gradients): + self._test(contiguous_gradients=contiguous_gradients) + + @pytest.mark.parametrize("offload_optimizer", [True, False]) + def test_offload_optimizer(self, offload_optimizer): + self._test(offload_optimizer=offload_optimizer) + + @pytest.mark.parametrize("zero_grad", [True, False]) + def test_zero_grad(self, zero_grad): + self._test(zero_grad=zero_grad) + + @pytest.mark.parametrize("prefetching", [True, False]) + def test_prefetching(self, prefetching): + self._test(prefetching=prefetching) + + @pytest.mark.parametrize("reduce_scatter", [True, False]) + def test_reduce_scatter(self, reduce_scatter): + self._test(reduce_scatter=reduce_scatter) + + @pytest.mark.parametrize("model_class", [ + EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple, + EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple, + EltwiseMultiplicationTestNetwork_List + ]) + def test_model_class(self, model_class): + self._test(model_class=model_class) + + def _test( self, - param_persistence_threshold: int, - fp16_enabled: bool, - contiguous_gradients: bool, - offload_optimizer: bool, - zero_grad: bool, - prefetching: bool, - model_class: EltwiseMultiplicationTestNetwork_Dict, + param_persistence_threshold: int = 0, + fp16_enabled: bool = False, + contiguous_gradients: bool = False, + offload_optimizer: bool = False, + zero_grad: bool = False, + prefetching: bool = False, + reduce_scatter: bool = False, + model_class: EltwiseMultiplicationTestNetwork_Dict = EltwiseMultiplicationTestNetwork_Dict, ) -> None: if offload_optimizer and not contiguous_gradients: return @@ -600,41 +748,45 @@ def test( weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] model = model_class(*weights) prefetch_bucket_size = sum([p.numel() for p in model.parameters(recurse=True)]) - cfg = { + config_dict = { "train_micro_batch_size_per_gpu": 1, "zero_optimization": { "stage": 3, "stage3_max_reuse_distance": 0, "stage3_param_persistence_threshold": param_persistence_threshold, "contiguous_gradients": contiguous_gradients, - "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0 + "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0, + "reduce_scatter": reduce_scatter, }, "optimizer": { "type": "Adam", "params": { - "lr": 1. + "lr": 1.0 } }, - "fp16": { - "enabled": fp16_enabled, - "loss_scale": 1., - } } + if fp16_enabled: + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + if offload_optimizer: - cfg["zero_optimization"]["offload_optimizer"] = { + config_dict["zero_optimization"]["offload_optimizer"] = { "device": "cpu", "pin_memory": True, } - ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + ds_engine = _ds_initialize_for_param_partitioning_testing(model, config_dict) for i, weight in enumerate(weights): weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, (i + 1) * (1 + dist.get_rank())) def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: - return torch.as_tensor(vals, - dtype=dtype or (torch.float16 if fp16_enabled else torch.float32), - device=ds_engine.device) + return torch.as_tensor( + vals, + dtype=dtype or (torch.float16 if fp16_enabled else torch.float32), + device=ds_engine.device, + ) expected_hidden1 = create_tensor([ [1, 1, 1, 1, 1], @@ -655,8 +807,16 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: for train_iter in range(3): activations = ds_engine( - x=torch.ones((m, n), dtype=torch.float16 if fp16_enabled else torch.float32, device=ds_engine.device), - y=torch.ones((m, n), dtype=torch.float16 if fp16_enabled else torch.float32, device=ds_engine.device), + x=torch.ones( + (m, n), + dtype=torch.float16 if fp16_enabled else torch.float32, + device=ds_engine.device, + ), + y=torch.ones( + (m, n), + dtype=torch.float16 if fp16_enabled else torch.float32, + device=ds_engine.device, + ), use_module_trace=train_iter > 0, param_prefetching=prefetching and train_iter > 0, ) @@ -691,21 +851,33 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: grad_multiplier = 1 if zero_grad else (train_iter + 1) if dist.get_rank() == 0: - assert torch.allclose(dloss_wrt_layer3.to(get_accelerator().device_name()), - grad_multiplier * create_tensor([2] * 8, torch.float)) - assert torch.allclose(dloss_wrt_layer2.to(get_accelerator().device_name()), - grad_multiplier * create_tensor([3 * 1] * 8, torch.float)) - assert torch.allclose(dloss_wrt_layer1.to(get_accelerator().device_name()), - grad_multiplier * create_tensor([3 * 2 * 1] * 8, torch.float)) + assert torch.allclose( + dloss_wrt_layer3.to(get_accelerator().device_name()), + grad_multiplier * create_tensor([2] * 8, torch.float), + ) + assert torch.allclose( + dloss_wrt_layer2.to(get_accelerator().device_name()), + grad_multiplier * create_tensor([3 * 1] * 8, torch.float), + ) + assert torch.allclose( + dloss_wrt_layer1.to(get_accelerator().device_name()), + grad_multiplier * create_tensor([3 * 2 * 1] * 8, torch.float), + ) elif dist.get_rank() == 1: # parameters dont split evenly across ranks so rank 1 has a zero-padded # partition - assert torch.allclose(dloss_wrt_layer3.to(get_accelerator().device_name()), - grad_multiplier * create_tensor(([8] * 7) + [0], torch.float)) - assert torch.allclose(dloss_wrt_layer2.to(get_accelerator().device_name()), - grad_multiplier * create_tensor(([6 * 2] * 7) + [0], torch.float)) - assert torch.allclose(dloss_wrt_layer1.to(get_accelerator().device_name()), - grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], torch.float)) + assert torch.allclose( + dloss_wrt_layer3.to(get_accelerator().device_name()), + grad_multiplier * create_tensor(([8] * 7) + [0], torch.float), + ) + assert torch.allclose( + dloss_wrt_layer2.to(get_accelerator().device_name()), + grad_multiplier * create_tensor(([6 * 2] * 7) + [0], torch.float), + ) + assert torch.allclose( + dloss_wrt_layer1.to(get_accelerator().device_name()), + grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], torch.float), + ) else: raise RuntimeError("test has world size of two") @@ -720,12 +892,15 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0) + ds_engine.destroy() + @pytest.mark.parametrize("init_context_manager", [True, False]) +@pytest.mark.parametrize("reduce_scatter", [True, False]) class TestZero3ParamPartitioningLargeParam(DistributedTest): - world_size = 4 + world_size = 2 - def test(self, init_context_manager: bool, param_sz: int = 8100) -> None: + def test(self, init_context_manager: bool, reduce_scatter: bool, param_sz: int = 8100) -> None: class LargeParamModel(Module): @@ -746,28 +921,31 @@ def __init__(self): def forward(self, x: Tensor) -> Tensor: return x * self.param - ds_config = { + config_dict = { "train_micro_batch_size_per_gpu": 1, "zero_optimization": { "stage": 3, "stage3_max_reuse_distance": 0, "contiguous_gradients": True, "overlap_comm": True, + "reduce_scatter": reduce_scatter, }, "optimizer": { "type": "Adam", "params": { - "lr": 1. + "lr": 1.0 } }, - "fp16": { - "enabled": True, - "loss_scale": 1., - } } - with deepspeed.zero.Init(mem_efficient_linear=False, enabled=init_context_manager): + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + with deepspeed.zero.Init(mem_efficient_linear=False, + enabled=init_context_manager, + config_dict_or_path=config_dict): model = LargeParamModel() - ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config) + ds_engine = _ds_initialize_for_param_partitioning_testing(model, config_dict) for train_iter in range(3): # test multiple iterations to cover prefetching activation: Tensor = ds_engine(torch.ones(param_sz, dtype=torch.float16, device=ds_engine.device)) @@ -775,26 +953,29 @@ def forward(self, x: Tensor) -> Tensor: partition_sz = math.ceil(param_sz / self.world_size) for rank_idx, start_idx in enumerate(range(0, param_sz, partition_sz)): activation_from_partition = activation[start_idx:start_idx + partition_sz] - assert torch.allclose(activation_from_partition, torch.full_like(activation_from_partition, rank_idx)) + assert torch.allclose( + activation_from_partition, + torch.full_like(activation_from_partition, rank_idx), + ) ds_engine.backward(activation.sum()) ds_engine.allreduce_gradients() avgd_gradients = ds_engine.optimizer.averaged_gradients assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" - weight_gradient, = avgd_gradients[0] + (weight_gradient, ) = avgd_gradients[0] expected_weight_gradient = (train_iter + 1) * torch.full_like(weight_gradient, 1) assert torch.allclose(weight_gradient, expected_weight_gradient) + ds_engine.destroy() + -@pytest.mark.parametrize("param_sz", [100, 1_000, 10_000]) -@pytest.mark.parametrize("n_layers", [100, 1_000]) @pytest.mark.parametrize("init_context_manager", [True, False]) class TestZero3ParamPartitioningManyParams(DistributedTest): - world_size = 4 + world_size = 2 - def test(self, param_sz: int, n_layers: int, init_context_manager: bool) -> None: + def test(self, init_context_manager: bool, param_sz: int = 100, n_layers: int = 100) -> None: class ManyParamModel(Module): @@ -824,7 +1005,7 @@ def forward(self, x: Tensor) -> Tensor: return activations - ds_cfg = { + config_dict = { "train_micro_batch_size_per_gpu": 1, "zero_optimization": { "stage": 3, @@ -835,27 +1016,29 @@ def forward(self, x: Tensor) -> Tensor: "optimizer": { "type": "Adam", "params": { - "lr": 1. + "lr": 1.0 } }, - "fp16": { - "enabled": True, - "loss_scale": 1., - } } - - with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=init_context_manager): + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} + + with deepspeed.zero.Init(config_dict_or_path=config_dict, + mem_efficient_linear=False, + enabled=init_context_manager): model = ManyParamModel() - ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg) + ds_engine = _ds_initialize_for_param_partitioning_testing(model, config_dict) + dtype = preferred_dtype() for _ in range(3): # test multiple iterations to cover prefetching - activations: List[Tensor] = ds_engine( - torch.ones((param_sz, ), dtype=torch.float16, device=ds_engine.device)) + activations: List[Tensor] = ds_engine(torch.ones((param_sz, ), dtype=dtype, device=ds_engine.device)) assert len(activations) == n_layers partition_sz = math.ceil(param_sz / self.world_size) - expected_activations = torch.empty(param_sz, dtype=torch.float16, device=ds_engine.device) + expected_activations = torch.empty(param_sz, dtype=dtype, device=ds_engine.device) for start_idx in range(0, param_sz, partition_sz): expected_activations[start_idx:start_idx + partition_sz] = dist.get_rank() @@ -873,9 +1056,11 @@ def forward(self, x: Tensor) -> Tensor: for layer_num, activation in enumerate(weight_gradients): pass + ds_engine.destroy() + class TestZero3InitForParentWeightInitialization(DistributedTest): - world_size = 4 + world_size = 2 def test(self): @@ -893,7 +1078,7 @@ def __init_weights(self, module): with torch.no_grad(): module.weight.fill_(1 + dist.get_rank()) - ds_cfg = { + config_dict = { "train_micro_batch_size_per_gpu": 1, "zero_optimization": { "stage": 3, @@ -904,38 +1089,59 @@ def __init_weights(self, module): "optimizer": { "type": "Adam", "params": { - "lr": 1. + "lr": 1.0 } }, - "fp16": { - "enabled": True, - "loss_scale": 1., - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0} - with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=True): + with deepspeed.zero.Init(config_dict_or_path=config_dict, mem_efficient_linear=False, enabled=True): model = ModelWhereParentInitializesChildWeights() assert model.linear.weight.ds_tensor.numel() == math.ceil(12 / self.world_size) - assert torch.allclose(model.linear.weight.ds_tensor, torch.full_like(model.linear.weight.ds_tensor, 1)) + assert torch.allclose( + model.linear.weight.ds_tensor, + torch.full_like(model.linear.weight.ds_tensor, 1), + ) -@pytest.mark.skip("not working") +""" @pytest.mark.parametrize("param_persistence_threshold", [0, 10]) @pytest.mark.parametrize("contiguous_gradients", [True, False]) @pytest.mark.parametrize("offload_optimizer", [True, False]) @pytest.mark.parametrize("zero_grad", [True, False]) @pytest.mark.parametrize("prefetching", [True, False]) -@pytest.mark.parametrize("model_class", [ - EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple, - EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple, - EltwiseMultiplicationTestNetwork_List -]) +@pytest.mark.parametrize("reduce_scatter", [True, False]) +@pytest.mark.parametrize( + "model_class", + [ + EltwiseMultiplicationTestNetwork_Dict, + EltwiseMultiplicationTestNetwork_NamedTuple, + EltwiseMultiplicationTestNetwork_namedtuple, + EltwiseMultiplicationTestNetwork_Tuple, + EltwiseMultiplicationTestNetwork_List, + ], +) +""" + + +@pytest.mark.skip("not working") class TestZero3ParamPartitioningBaseBF16(DistributedTest): world_size = 2 - def test(self, param_persistence_threshold: int, contiguous_gradients: bool, offload_optimizer: bool, - zero_grad: bool, prefetching: bool, model_class: EltwiseMultiplicationTestNetwork_Dict) -> None: + def test( + self, + param_persistence_threshold: int, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + prefetching: bool, + reduce_scatter: bool, + model_class: EltwiseMultiplicationTestNetwork_Dict, + ) -> None: if offload_optimizer and not contiguous_gradients: return @@ -951,18 +1157,19 @@ def test(self, param_persistence_threshold: int, contiguous_gradients: bool, off "stage3_max_reuse_distance": 0, "stage3_param_persistence_threshold": param_persistence_threshold, "contiguous_gradients": contiguous_gradients, - "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0 + "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0, + "reduce_scatter": reduce_scatter, }, "optimizer": { "type": "Adam", "params": { - "lr": 1. + "lr": 1.0 } }, "bf16": { "enabled": True, - "loss_scale": 1., - } + "loss_scale": 1.0, + }, } if offload_optimizer: @@ -1033,21 +1240,33 @@ def create_tensor(vals): grad_multiplier = 1 if zero_grad else (train_iter + 1) if dist.get_rank() == 0: - assert torch.allclose(dloss_wrt_layer3.to(get_accelerator().device_name()), - grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype)) - assert torch.allclose(dloss_wrt_layer2.to(get_accelerator().device_name()), - grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype)) - assert torch.allclose(dloss_wrt_layer1.to(get_accelerator().device_name()), - grad_multiplier * create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer3.to(get_accelerator().device_name()), + grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype), + ) + assert torch.allclose( + dloss_wrt_layer2.to(get_accelerator().device_name()), + grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype), + ) + assert torch.allclose( + dloss_wrt_layer1.to(get_accelerator().device_name()), + grad_multiplier * create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype), + ) elif dist.get_rank() == 1: # parameters dont split evenly across ranks so rank 1 has a zero-padded # partition - assert torch.allclose(dloss_wrt_layer3.to(get_accelerator().device_name()), - grad_multiplier * create_tensor(([8] * 7) + [0]).to(expected_grad_dtype)) - assert torch.allclose(dloss_wrt_layer2.to(get_accelerator().device_name()), - grad_multiplier * create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype)) - assert torch.allclose(dloss_wrt_layer1.to(get_accelerator().device_name()), - grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer3.to(get_accelerator().device_name()), + grad_multiplier * create_tensor(([8] * 7) + [0]).to(expected_grad_dtype), + ) + assert torch.allclose( + dloss_wrt_layer2.to(get_accelerator().device_name()), + grad_multiplier * create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype), + ) + assert torch.allclose( + dloss_wrt_layer1.to(get_accelerator().device_name()), + grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype), + ) else: raise RuntimeError("test has world size of two") @@ -1059,6 +1278,103 @@ def create_tensor(vals): ds_engine.optimizer.step() _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + ds_engine.destroy() + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +class TestParamPartitioningSkipInit(DistributedTest): + world_size = 2 + + def test(self, dtype): + + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip("{dtype} is not supported") + + config_dict = { + "train_batch_size": 4, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 3 + }, + } + + if dtype == torch.bfloat16: + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + else: + pytest.skip("bfloat16 is not supported on this accelerator") + elif dtype == torch.float16: + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + else: + pytest.skip("fp16 is not supported on this accelerator") + hidden_dim = 10 + + class SubModel(torch.nn.Module): + + def __init__(self, input_size, output_size, dropout_prob=0.5, device=None): + super(SubModel, self).__init__() + self.linear = torch.nn.Linear(input_size, output_size, device=device) + self.dropout = torch.nn.Dropout(dropout_prob) + self.module_list = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, device=device)]) + + def forward(self, x): + x = self.linear(x) + x = self.dropout(x) + x = self.module_list[0](x) + return x + + class MyModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(MyModel, self).__init__() + self.l1 = skip_init(Linear, hidden_dim, hidden_dim) + self.l2 = skip_init(SubModel, hidden_dim, hidden_dim) + self.l3 = torch.nn.Linear(hidden_dim, hidden_dim) + self.cel = torch.nn.CrossEntropyLoss() + self.l4 = skip_init(SubModel, hidden_dim, hidden_dim) + + def forward(self, x, y): + x = self.l1(x) + x = self.l2(x) + x = self.l3(x) + x = self.l4(x) + loss = self.cel(x, y) + val = [x, loss] + return val + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = MyModel(hidden_dim) + world_size = dist.get_world_size() + ds_tensor_numel = math.ceil(hidden_dim * hidden_dim / world_size) + assert model.l1.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l2.linear.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l2.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel + assert model.l3.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l4.linear.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l4.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + loss = loss[1] + model.backward(loss) + model.step() + + model.destroy() + class TestZeroOffloadStage1(DistributedTest): world_size = 2 @@ -1074,16 +1390,17 @@ def test(self): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 1, "offload_optimizer": { "device": "cpu" } - } + }, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -1095,12 +1412,16 @@ def test(self): model.backward(loss) model.step() + model.destroy() + -@pytest.mark.parametrize('return_type', [tuple, list, dict]) +@pytest.mark.parametrize("return_type", [tuple, list, dict]) class TestZero3DictFwd(DistributedTest): world_size = 1 def test(self, return_type): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -1110,13 +1431,14 @@ def test(self, return_type): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 3 - } + }, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 class MyModel(torch.nn.Module): @@ -1130,7 +1452,7 @@ def forward(self, x, y): x = self.l1(x) loss = self.cel(x, y) if return_type == dict: - val = {'a': x, 'loss': loss, 'b': 1, 'c': None} + val = {"a": x, "loss": loss, "b": 1, "c": None} elif return_type == list: val = [x, loss] elif return_type == tuple: @@ -1139,7 +1461,7 @@ def forward(self, x, y): raise NotImplementedError return val - with deepspeed.zero.Init(): + with deepspeed.zero.Init(config=config_dict): model = MyModel(hidden_dim) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) @@ -1148,18 +1470,25 @@ def forward(self, x, y): for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) if return_type == dict: - loss = loss['loss'] + loss = loss["loss"] else: loss = loss[1] model.backward(loss) model.step() + model.destroy() + -@pytest.mark.parametrize('zero_stage', [1, 2, 3]) +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) class TestZeroAdamOptimizerStepCount(DistributedTest): world_size = 1 def test(self, zero_stage): + # We verify trhee conditions: + # 1. global_steps starts at 0 + # 2. All subgroups have the same step count + # 3. The global step count is the same as the step count of the first subgroup + # force all params to be partitioned by forcing threshold=0 config_dict = { "train_micro_batch_size_per_gpu": 2, @@ -1176,11 +1505,11 @@ def test(self, zero_stage): "lr": 1e-3 } }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} hidden_dim = 4 model = SimpleModel(hidden_dim=hidden_dim, nlayers=12) @@ -1189,30 +1518,40 @@ def test(self, zero_stage): model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) - for i, batch in enumerate(data_loader): + assert model.global_steps == 0 + + for batch in data_loader: loss = model(batch[0], batch[1]) model.backward(loss) + + is_gradient_accumulation_boundary = model.is_gradient_accumulation_boundary() model.step() - step_counts = [] - if zero_stage == 3: - for sub_group_id, _ in enumerate(optimizer.fp16_groups): - fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] - state = optimizer.optimizer.state[fp32_param] - step_counts.append(state['step']) - assert all(step == step_counts[0] for step in step_counts) - elif zero_stage == 1 or zero_stage == 2: - for param_group in optimizer.optimizer.param_groups: - for param in param_group['params']: - state = optimizer.optimizer.state[param] - step_counts.append(state['step']) + if is_gradient_accumulation_boundary: + step_counts = [] + + if zero_stage == 3: + for sub_group_id, _ in enumerate(optimizer.fp16_groups): + fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] + state = optimizer.optimizer.state[fp32_param] + step_counts.append(state["step"]) + elif zero_stage == 1 or zero_stage == 2: + for param_group in optimizer.optimizer.param_groups: + for param in param_group["params"]: + state = optimizer.optimizer.state[param] + step_counts.append(state["step"]) + assert all(step == step_counts[0] for step in step_counts) + assert model.global_steps == step_counts[0] + + model.destroy() +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) class TestZeroFrozenWeights(DistributedTest): - world_size = 1 + world_size = 2 - def test(self): + def test(self, zero_stage): config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -1222,13 +1561,14 @@ def test(self): "lr": 1e-4 } }, - "fp16": { - "enabled": True - }, "zero_optimization": { - "stage": 3 - } + "stage": zero_stage + }, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 class MyModel(torch.nn.Module): @@ -1252,7 +1592,7 @@ def forward(self, x, y): val = (x, loss) return val - with deepspeed.zero.Init(config_dict_or_path=config_dict): + with deepspeed.zero.Init(config_dict_or_path=config_dict, enabled=zero_stage == 3): model = MyModel(hidden_dim) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) @@ -1264,8 +1604,10 @@ def forward(self, x, y): model.backward(loss) model.step() + model.destroy() + -@pytest.mark.parametrize('force_ds_optim', [True, False]) +@pytest.mark.parametrize("force_ds_optim", [True, False]) class TestZeroOffloadOptim(DistributedTest): world_size = 1 @@ -1274,9 +1616,6 @@ def test(self, force_ds_optim): "train_batch_size": 4, "gradient_accumulation_steps": 2, "steps_per_print": 1, - "fp16": { - "enabled": True - }, "zero_optimization": { "stage": 1, "offload_optimizer": { @@ -1285,6 +1624,10 @@ def test(self, force_ds_optim): }, "zero_force_ds_cpu_optimizer": force_ds_optim, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -1298,7 +1641,7 @@ def test(self, force_ds_optim): model, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict) -@pytest.mark.parametrize('training', [True, False]) +@pytest.mark.parametrize("training", [True, False]) class TestZeroPartitionCache(DistributedTest): world_size = 1 @@ -1306,15 +1649,15 @@ def test_training_partition_cache(self, training): hidden_dim = 10 config_dict = { "train_batch_size": 2, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - }, "zero_optimization": { "stage": 3, - "stage3_param_persistence_threshold": hidden_dim - } + "stage3_param_persistence_threshold": hidden_dim, + }, } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} if training: config_dict["optimizer"] = {"type": "Adam"} @@ -1323,12 +1666,12 @@ def test_training_partition_cache(self, training): model, _, _, _ = deepspeed.initialize(model=model, config=config_dict) - dtype = torch.half - data_loader = random_dataloader(model=model, - total_samples=6, - hidden_dim=hidden_dim, - device=model.device, - dtype=dtype) + data_loader = random_dataloader( + model=model, + total_samples=6, + hidden_dim=hidden_dim, + device=model.device, + ) for _, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) @@ -1342,3 +1685,141 @@ def test_training_partition_cache(self, training): model.empty_partition_cache() assert sum([p.numel() for p in model.parameters()]) == 0 + + model.destroy() + + +@pytest.mark.parametrize("use_client_optimizer", [True, False]) +@pytest.mark.parametrize("empty_weight_group", [True, False]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +class TestEmptyParameterGroup(DistributedTest): + world_size = 1 + + def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_group): + if dtype == torch.float16 and not get_accelerator().is_fp16_supported(): + pytest.skip("fp16 is not supported") + model = SimpleModel(hidden_dim=4, nlayers=4) + param_groups = [ + { + "params": [] if empty_weight_group else [l.weight for l in model.linears], + "weight_decay": 0.01, + }, + { + "params": [l.bias for l in model.linears] if empty_weight_group else [], + "weight_decay": 0.0 + }, + ] + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 0, + }, + "fp16": { + "enabled": dtype == torch.float16, + }, + "bf16": { + "enabled": dtype == torch.bfloat16 + } + } + + if use_client_optimizer: + optimizer = torch.optim.AdamW(param_groups, lr=0.1) + model_parameters = model.parameters() + else: + config_dict["optimizer"] = {"type": "adamw"} + optimizer = None + model_parameters = param_groups + + model, _, _, _ = deepspeed.initialize( + model=model, + model_parameters=model_parameters, + optimizer=optimizer, + config=config_dict, + ) + + model.destroy() + + +class TestZero3SwitchModes(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0]) + def test(self, prefetch_ratio, zero_stage=3): + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio) + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "zero_optimization": { + "stage": zero_stage, + "stage3_prefetch_bucket_size": prefetch_bucket_size + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) + + for _ in range(3): + model.train() + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.eval() + with torch.no_grad(): + for batch in data_loader: + loss = model(batch[0], batch[1]) + + model.destroy() + + +# Avoid overwriting client module id +# https://github.com/deepspeedai/DeepSpeed/issues/6772 +class TestZero3ClientModuleID(DistributedTest): + world_size = 2 + + def test_client_module_id(self): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + }, + "zero_optimization": { + "stage": 3 + }, + } + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.id = 3 # ID arbitrary client usage, e.g. GPU placement + self.fc = Linear(128, 128) + + def forward(self, x): + return self.fc(x) + + model = MyModel() + pre_init_m_id = model.id + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + post_init_m_id = model.id + assert pre_init_m_id == post_init_m_id + model.destroy() diff --git a/tests/unit/v1/zero/test_zero2_offload_multi_backward.py b/tests/unit/v1/zero/test_zero2_offload_multi_backward.py new file mode 100644 index 000000000000..0c62b98630d7 --- /dev/null +++ b/tests/unit/v1/zero/test_zero2_offload_multi_backward.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Regression tests for ZeRO-1/2 + cpu_offload with multiple engine.backward() +calls per optimizer step (ga_steps=1, driven via set_gradient_accumulation_boundary). +""" + +import pytest +import torch +import deepspeed + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.accelerator import get_accelerator + + +def _base_config(zero_stage, gradient_accumulation_steps=1, cpu_offload=False): + config_dict = { + "train_batch_size": gradient_accumulation_steps, + "gradient_accumulation_steps": gradient_accumulation_steps, + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "zero_force_ds_cpu_optimizer": False, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3, + }, + }, + } + if cpu_offload: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"} + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + return config_dict + + +def _init_engine(config_dict, hidden_dim, seed=42): + torch.manual_seed(seed) + model = SimpleModel(hidden_dim, nlayers=2) + engine, _, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=config_dict, + ) + return engine + + +def _capture_params(engine): + return {name: p.detach().float().cpu().clone() for name, p in engine.module.named_parameters()} + + +def _assert_params_match(ref, test, label, tol=5e-5): + for name in ref: + max_diff = (ref[name] - test[name]).abs().max().item() + assert max_diff < tol, f"{label}: {name} differs by {max_diff:.3e}" + + +def _run_multi_backward(config_dict, hidden_dim, num_chunks, num_steps=1, seed=42): + engine = _init_engine(config_dict, hidden_dim, seed=seed) + data_loader = random_dataloader( + model=engine, + total_samples=num_chunks * num_steps, + hidden_dim=hidden_dim, + device=engine.device, + ) + batches = list(data_loader) + for step_idx in range(num_steps): + step_batches = batches[step_idx * num_chunks:(step_idx + 1) * num_chunks] + for i, batch in enumerate(step_batches): + loss = engine(batch[0], batch[1]) + engine.set_gradient_accumulation_boundary(i == num_chunks - 1) + engine.backward(loss) + engine.step() + params = _capture_params(engine) + engine.destroy() + return params + + +def _run_ga_microsteps(config_dict, hidden_dim, total_microsteps, seed=42): + engine = _init_engine(config_dict, hidden_dim, seed=seed) + data_loader = random_dataloader( + model=engine, + total_samples=total_microsteps, + hidden_dim=hidden_dim, + device=engine.device, + ) + for batch in data_loader: + loss = engine(batch[0], batch[1]) + engine.backward(loss) + engine.step() + params = _capture_params(engine) + engine.destroy() + return params + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +class TestZeroOffloadMultiBackward(DistributedTest): + world_size = 1 + + def test_multi_backward_matches_no_offload(self, zero_stage): + hidden_dim = 8 + num_chunks = 4 + ref = _run_multi_backward(_base_config(zero_stage, cpu_offload=False), hidden_dim, num_chunks) + test = _run_multi_backward(_base_config(zero_stage, cpu_offload=True), hidden_dim, num_chunks) + _assert_params_match(ref, test, label=f"ZeRO-{zero_stage} N=4") + + def test_single_backward_unchanged(self, zero_stage): + hidden_dim = 8 + ref = _run_multi_backward(_base_config(zero_stage, cpu_offload=False), hidden_dim, num_chunks=1) + test = _run_multi_backward(_base_config(zero_stage, cpu_offload=True), hidden_dim, num_chunks=1) + _assert_params_match(ref, test, label=f"ZeRO-{zero_stage} N=1") + + def test_multi_backward_across_multiple_steps(self, zero_stage): + hidden_dim = 8 + ref = _run_multi_backward(_base_config(zero_stage, cpu_offload=False), hidden_dim, num_chunks=3, num_steps=3) + test = _run_multi_backward(_base_config(zero_stage, cpu_offload=True), hidden_dim, num_chunks=3, num_steps=3) + _assert_params_match(ref, test, label=f"ZeRO-{zero_stage} 3x3") + + def test_single_backward_allocates_no_cpu_accumulator(self, zero_stage): + hidden_dim = 8 + engine = _init_engine(_base_config(zero_stage, cpu_offload=True), hidden_dim) + batch = next( + iter(random_dataloader(model=engine, total_samples=1, hidden_dim=hidden_dim, device=engine.device))) + loss = engine(batch[0], batch[1]) + engine.set_gradient_accumulation_boundary(True) + engine.backward(loss) + engine.step() + populated = len(engine.optimizer.accumulated_grads_in_cpu) + engine.destroy() + assert populated == 0, f"ZeRO-{zero_stage}: ga=1+N=1 populated accumulated_grads_in_cpu ({populated} entries)" + + def test_ga_greater_than_one_offload_unchanged(self, zero_stage): + hidden_dim = 8 + ga = 4 + ref = _run_ga_microsteps(_base_config(zero_stage, gradient_accumulation_steps=ga, cpu_offload=False), + hidden_dim, + total_microsteps=ga) + test = _run_ga_microsteps(_base_config(zero_stage, gradient_accumulation_steps=ga, cpu_offload=True), + hidden_dim, + total_microsteps=ga) + _assert_params_match(ref, test, label=f"ZeRO-{zero_stage} ga=4") diff --git a/tests/unit/v1/zero/test_zero_autocast.py b/tests/unit/v1/zero/test_zero_autocast.py new file mode 100644 index 000000000000..e7cb0059ce35 --- /dev/null +++ b/tests/unit/v1/zero/test_zero_autocast.py @@ -0,0 +1,212 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from copy import deepcopy + +import pytest + +import torch +from torch.nn.parallel import DistributedDataParallel as DDP + +from unit.common import DistributedTest, enable_determinism, allclose_on_all_ranks +from unit.simple_model import SimpleModel +from unit.util import bf16_required_version_check + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero import GatheredParameters +from deepspeed.runtime.torch_autocast import PARAM_COMM_DTYPE_ATTR_NAME, get_comm_dtype + + +def cls_to_qualname(cls): + return f"{cls.__module__}.{cls.__name__}" + + +class SimpleModelWithLayerNorm(torch.nn.Module): + + def __init__(self, hidden_dim, nlayers=1): + super(SimpleModelWithLayerNorm, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(nlayers)]) + self.norm = torch.nn.LayerNorm(hidden_dim) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + x = self.linears[0](x) + x = self.norm(x) + return self.cross_entropy_loss(x, y) + + +def step_amp(enabled, baseline_model, baseline_optimizer, target_engine, dtype, enable_autocast_outside, + baseline_scaler, step, x, y, expect_match): + device_type = get_accelerator().device_name() + + # Runs the forward pass with autocasting. + with torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled): + baseline_optimizer.zero_grad() + baseline_loss = baseline_model(x, y) + + baseline_scaler.scale(baseline_loss).backward() + baseline_scaler.step(baseline_optimizer) + baseline_scaler.update() + + # We don't need torch.autocast here in real applications, but want to test the behavior of nested autocast. + with torch.autocast(device_type=device_type, dtype=dtype, enabled=enable_autocast_outside): + target_loss = target_engine(x, y) + + # reduce-scatter in `dtype` makes a difference in the loss. + if step <= 1 and expect_match: + allclose_on_all_ranks(baseline_loss, target_loss) + + target_engine.backward(target_loss) + target_engine.step() + + +@enable_determinism(123) +def compare_loss(model_cls, + enable, + zero_stage, + model_dtype, + dtype, + autocast_conf, + enable_autocast_outside, + lower_precision_safe_modules, + expect_match=True): + iteration = 5 + hidden_dim = 10 + lr = 0.001 + + if dtype == torch.bfloat16 and not bf16_required_version_check(): + raise ValueError( + "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "torch_autocast": autocast_conf, + } + + model = model_cls(hidden_dim) + model.to(model_dtype) + + deepspeed.init_distributed(dist_backend='nccl') + + i = get_accelerator().current_device() + device = get_accelerator().current_device_name() + baseline_model = DDP(deepcopy(model).to(device=device, dtype=torch.float32), device_ids=[i], output_device=i) + baseline_optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=lr, weight_decay=0.0) + baseline_scaler = torch.amp.GradScaler() + + stage_3_enabled = config_dict["zero_optimization"]["stage"] == 3 + if stage_3_enabled: + # Trick to avoid conversion to fp32 in Init() while also avoiding deepspeed's mixed precision + # Ideally Init() should have a flag to avoid conversion to fp32 + import copy + config_for_init = copy.deepcopy(config_dict) + if model_dtype == torch.float16: + config_for_init["fp16"] = {"enabled": True} + elif model_dtype == torch.bfloat16: + config_for_init["bf16"] = {"enabled": True} + + with deepspeed.zero.Init(config_dict_or_path=config_for_init): + target_model = model_cls(hidden_dim) + with GatheredParameters(target_model.parameters(), modifier_rank=0): + for p1, p2 in zip(target_model.parameters(), model.parameters()): + p1.data.copy_(p2.data) + else: + target_model = deepcopy(model) + + ds_optimizer = torch.optim.Adam(target_model.parameters(), lr=lr) + target_engine, _, _, _ = deepspeed.initialize(config=config_dict, model=target_model, optimizer=ds_optimizer) + train_batch_size = config_dict["train_micro_batch_size_per_gpu"] + + xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=torch.float32) for _ in range(iteration)] + ys = [torch.randn_like(x) for x in xs] + + for i, (x, y) in enumerate(zip(xs, ys)): + step_amp(enable, baseline_model, baseline_optimizer, target_engine, dtype, enable_autocast_outside, + baseline_scaler, i, x, y, expect_match) + + for module in target_engine.modules(): + for p in module.parameters(recurse=False): + if module.__class__ in lower_precision_safe_modules and autocast_conf["enabled"]: + assert hasattr( + p, PARAM_COMM_DTYPE_ATTR_NAME + ), f"A module is in the lower precision safe list, but param does not have autocast_dtype: {module.__class__.__name__}" + assert get_comm_dtype( + p + ) == dtype, f"dtype of a module in the lower precision safe list is not set to {dtype}: {module.__class__.__name__}" + else: + assert not hasattr( + p, PARAM_COMM_DTYPE_ATTR_NAME + ), f"A module is not in the lower precision safe list, but param has autocast_dtype: {module.__class__.__name__}" + assert get_comm_dtype( + p + ) == model_dtype, f"comm dtype doesn't match module dtype though the module is not in lower precision list" + target_engine.destroy() + + +@pytest.mark.parametrize("enable", [True]) +class TestZeroAutoCast(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test(self, enable, zero_stage, dtype): + lower_precision_safe_modules = [torch.nn.Linear] + autocast_conf = {"enabled": enable, "dtype": str(dtype)} + + compare_loss(SimpleModel, enable, zero_stage, torch.float32, dtype, autocast_conf, False, + lower_precision_safe_modules) + + @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_safe_modules_conf(self, enable, zero_stage, dtype): + lower_precision_safe_modules = [torch.nn.Linear] + autocast_conf = { + "enabled": enable, + "dtype": str(dtype), + "lower_precision_safe_modules": [cls_to_qualname(cls) for cls in lower_precision_safe_modules] + } + + # The model has both lower precision safe and unsafe modules. + compare_loss(SimpleModelWithLayerNorm, enable, zero_stage, torch.float32, dtype, autocast_conf, False, + lower_precision_safe_modules) + + @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + def test_nested_autocast(self, enable, zero_stage, dtype): + lower_precision_safe_modules = [torch.nn.Linear] + autocast_conf = { + "enabled": False, + "dtype": str(dtype), + } + + # torch.autocast is disabled in DeepSpeed engine + compare_loss(SimpleModelWithLayerNorm, + enable, + zero_stage, + torch.float32, + dtype, + autocast_conf, + True, + lower_precision_safe_modules, + expect_match=False) + + @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + def test_lower_precision_model(self, enable, zero_stage, dtype): + lower_precision_safe_modules = [torch.nn.Linear] + autocast_conf = { + "enabled": enable, + "dtype": str(dtype), + } + + # Use the same dtype for model as autocast dtype + compare_loss(SimpleModelWithLayerNorm, enable, zero_stage, dtype, dtype, autocast_conf, True, + lower_precision_safe_modules, False) diff --git a/tests/unit/v1/zero/test_zero_cpu_offload_grad_accum.py b/tests/unit/v1/zero/test_zero_cpu_offload_grad_accum.py new file mode 100644 index 000000000000..7d12e8981c34 --- /dev/null +++ b/tests/unit/v1/zero/test_zero_cpu_offload_grad_accum.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Regression test for https://github.com/deepspeedai/DeepSpeed/pull/7967 + +In ZeRO-2 CPU-offload mode with gradient_accumulation_steps > 1, +`async_accumulate_grad_in_cpu_via_gpu` only copied gradients to CPU +when micro_step_id == 0. For micro_step_id > 0, gradients were +accumulated on GPU but never copied back to CPU, causing +accumulated_grads_in_cpu to stay frozen and the gradient norm to be +underestimated. +""" + +import torch +import deepspeed + +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.accelerator import get_accelerator + + +def _cpu_grad_norm(engine): + total_norm_sq = 0.0 + for grad in engine.optimizer.accumulated_grads_in_cpu.values(): + total_norm_sq += grad.float().norm(2).item()**2 + return total_norm_sq**0.5 + + +class TestZero2CPUOffloadGradAccumNorm(DistributedTest): + world_size = 1 + + def test(self): + gradient_accumulation_steps = 4 + hidden_dim = 10 + + config_dict = { + "train_batch_size": gradient_accumulation_steps, + "gradient_accumulation_steps": gradient_accumulation_steps, + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + }, + }, + "zero_force_ds_cpu_optimizer": False, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3, + }, + }, + } + + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + + torch.manual_seed(42) + model = SimpleModel(hidden_dim, nlayers=2) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + model, _, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=config_dict, + ) + + data_loader = random_dataloader( + model=model, + total_samples=gradient_accumulation_steps, + hidden_dim=hidden_dim, + device=model.device, + ) + + norms = [] + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + norms.append(_cpu_grad_norm(model)) + + model.destroy() + + assert norms[0] > 0, "accumulated_grads_in_cpu should be non-zero after first backward" + for i in range(1, len(norms)): + assert norms[i] != norms[0], (f"accumulated_grads_in_cpu norm did not change after micro-step {i}: " + f"norm[0]={norms[0]:.6f}, norm[{i}]={norms[i]:.6f}. " + "Gradients were not copied back to CPU (PR #7967 regression).") diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py new file mode 100644 index 000000000000..b0a1dfc2be6c --- /dev/null +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression: ZeRO-3 linear autograd.Function must work with torch.func transforms. + +ZeRO Stage 3 uses ``LinearFunctionForZeroStage3`` (via ``zero3_linear_wrap``) as +the memory-efficient linear path. After ``deepspeed.initialize``, global +``torch.nn.functional.linear`` is often the built-in again, so tests call +``zero3_linear_wrap`` directly-the same ``autograd.Function`` as when the patch +is active. Legacy ``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward +raises on strict functorch builds:: + + RuntimeError: In order to use an autograd.Function with functorch + transforms ... it must override the setup_context staticmethod. +""" + +import pytest +import torch +import torch.nn as nn + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.linear import zero3_linear_wrap + +from unit.common import DistributedTest + + +def _zero3_functorch_config(): + config = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 2147483647, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 0, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + }, + }, + } + acc = get_accelerator() + if acc.is_bf16_supported(): + config["bf16"] = {"enabled": True} + elif acc.is_fp16_supported(): + config["fp16"] = {"enabled": True, "initial_scale_power": 8} + return config + + +class TestZeroFunctorchLinearRegression(DistributedTest): + """``torch.func.grad_and_value`` over ``zero3_linear_wrap`` / LinearFunctionForZeroStage3.""" + + world_size = 1 + + def test_grad_and_value_over_patched_functional_linear(self): + if not hasattr(torch, "func"): + pytest.skip("torch.func not available") + + model = nn.Linear(8, 8, bias=True) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + + device = engine.device + dtype = engine.module.weight.dtype + weight = torch.randn(8, 8, device=device, dtype=dtype, requires_grad=True) + inp = torch.randn(2, 8, device=device, dtype=dtype, requires_grad=True) + + with torch.enable_grad(): + probe = zero3_linear_wrap(inp, weight, None) + assert "LinearFunctionForZeroStage3" in type(probe.grad_fn).__name__ + + def loss_fn(w, x): + return zero3_linear_wrap(x, w, None).sum() + + grads, value = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp) + assert torch.isfinite(value) + assert grads[0] is not None and torch.isfinite(grads[0]).all() + assert grads[1] is not None and torch.isfinite(grads[1]).all() + + +class TestZeroLinearAutocast(DistributedTest): + """Verify autocast state is correctly propagated through forward and backward.""" + + world_size = 1 + + def _run_forward_backward(self, device, use_autocast, dtype=None): + """Run zero3_linear_wrap forward+backward, optionally inside autocast.""" + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + bias = torch.randn(4, device=device, dtype=torch.float32, requires_grad=True) + + if use_autocast: + with torch.amp.autocast(device_type=device.type, dtype=dtype): + out = zero3_linear_wrap(inp, weight, bias) + else: + out = zero3_linear_wrap(inp, weight, bias) + + loss = out.sum() + loss.backward() + return out, weight.grad, inp.grad, bias.grad + + def test_backward_without_autocast(self): + """Backward without autocast should produce float32 gradients.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=False) + assert out.dtype == torch.float32 + assert w_grad.dtype == torch.float32 + assert i_grad.dtype == torch.float32 + assert b_grad.dtype == torch.float32 + + def test_backward_with_autocast(self): + """Backward with autocast should produce float32 gradients (autocast only affects forward).""" + acc = get_accelerator() + if acc.is_bf16_supported(): + amp_dtype = torch.bfloat16 + elif acc.is_fp16_supported(): + amp_dtype = torch.float16 + else: + pytest.skip("No half-precision support") + + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=True, dtype=amp_dtype) + # Forward output should be in reduced precision + assert out.dtype == amp_dtype + # Gradients accumulate in float32 (master weights) + assert w_grad.dtype == torch.float32 + assert i_grad.dtype == torch.float32 + assert b_grad.dtype == torch.float32 + + def test_no_autocast_leak_into_backward(self): + """When forward runs without autocast, an outer autocast during backward must not affect gradient dtype.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + acc = get_accelerator() + if acc.is_bf16_supported(): + amp_dtype = torch.bfloat16 + elif acc.is_fp16_supported(): + amp_dtype = torch.float16 + else: + pytest.skip("No half-precision support") + + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + + # Forward WITHOUT autocast + out = zero3_linear_wrap(inp, weight, None) + assert out.dtype == torch.float32 + + # Backward WITH an outer autocast region -- should NOT affect gradient computation + # because setup_context captured _fwd_used_autocast=False + with torch.amp.autocast(device_type=device.type, dtype=amp_dtype): + out.sum().backward() + + assert weight.grad.dtype == torch.float32 + assert inp.grad.dtype == torch.float32 + + def test_setup_context_stores_autocast_attrs(self): + """setup_context must store _fwd_used_autocast and _dtype on ctx.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + + # Without autocast: setup_context must record that forward did not use autocast + out = zero3_linear_wrap(inp, weight, None) + grad_fn = out.grad_fn + assert hasattr(grad_fn, "_fwd_used_autocast") + assert grad_fn._fwd_used_autocast is False + assert hasattr(grad_fn, "_dtype") + out.sum().backward() + assert torch.isfinite(weight.grad).all() + + +class TestLinearFunctionVmap(DistributedTest): + """``LinearFunctionForZeroStage3`` must accept ``torch.func.vmap`` directly.""" + + world_size = 1 + + def test_vmap_over_linear_function(self): + from deepspeed.runtime.zero.linear import LinearFunctionForZeroStage3 + device = get_accelerator().device_name() + weight = torch.randn(4, 8, device=device, requires_grad=True) + bias = torch.randn(4, device=device, requires_grad=True) + xs = torch.randn(3, 8, device=device) + y = torch.func.vmap(lambda xi: LinearFunctionForZeroStage3.apply(xi, weight, bias).sum())(xs) + ref = torch.func.vmap(lambda xi: (xi @ weight.t() + bias).sum())(xs) + assert torch.allclose(y, ref, atol=1e-5) diff --git a/tests/unit/v1/zero/test_zero_hook_count_regression.py b/tests/unit/v1/zero/test_zero_hook_count_regression.py new file mode 100644 index 000000000000..057afc458d2c --- /dev/null +++ b/tests/unit/v1/zero/test_zero_hook_count_regression.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression tests for count_used_parameters_in_backward() call count. + +Verifies fix for https://github.com/deepspeedai/DeepSpeed/issues/7885: +count_used_parameters_in_backward() was called once per gradient hook +(O(n) calls per backward) instead of once per backward phase (O(1) +for non-reentrant, O(p) for reentrant with p phases). +""" + +import pytest +import torch +from unittest.mock import patch + +import deepspeed +from deepspeed.accelerator import get_accelerator +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + + +def get_config_dict(zero_stage): + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + + if zero_stage == 3: + config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0 + + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + return config_dict + + +class TestHookCountRegression(DistributedTest): + """Test that count_used_parameters_in_backward is not called per-hook.""" + world_size = 2 + + @pytest.mark.parametrize("zero_stage", [2, 3]) + def test_non_reentrant_single_count_call(self, zero_stage): + """Non-reentrant backward should call count_used_parameters_in_backward exactly once.""" + hidden_dim = 16 + model = SimpleModel(hidden_dim) + config = get_config_dict(zero_stage) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) + + data_loader = random_dataloader(model=engine, total_samples=4, hidden_dim=hidden_dim, device=engine.device) + + # Determine the correct module path to patch based on stage + if zero_stage == 2: + patch_target = "deepspeed.runtime.zero.stage_1_and_2.count_used_parameters_in_backward" + else: + patch_target = "deepspeed.runtime.zero.stage3.count_used_parameters_in_backward" + + call_counts = [] + + for batch in data_loader: + with patch(patch_target, wraps=deepspeed.runtime.utils.count_used_parameters_in_backward) as mock_count: + loss = engine(batch[0], batch[1]) + engine.backward(loss) + call_counts.append(mock_count.call_count) + engine.step() + break + + # Non-reentrant: exactly 1 call per backward + assert call_counts[0] == 1, (f"Expected exactly 1 call to count_used_parameters_in_backward " + f"per backward, got {call_counts[0]}") + + @pytest.mark.parametrize("zero_stage", [2, 3]) + def test_training_step_succeeds_after_fix(self, zero_stage): + """Verify a full training step produces a finite loss after the caching fix.""" + hidden_dim = 16 + model = SimpleModel(hidden_dim) + config = get_config_dict(zero_stage) + engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) + + data_loader = random_dataloader(model=engine, total_samples=8, hidden_dim=hidden_dim, device=engine.device) + + losses = [] + for i, batch in enumerate(data_loader): + loss = engine(batch[0], batch[1]) + assert torch.isfinite(loss), f"Loss is not finite at step {i}: {loss.item()}" + losses.append(loss.item()) + engine.backward(loss) + engine.step() + if i >= 1: + break + + assert len(losses) >= 2, "Expected at least 2 training steps" diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py new file mode 100644 index 000000000000..b094c22c2fb5 --- /dev/null +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -0,0 +1,1434 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed.comm as dist +import deepspeed +from torch.nn.parallel import DistributedDataParallel as DDP + +from unit.common import DistributedTest, preferred_dtype, allclose_on_all_ranks +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import safe_get_full_grad + + +class SimpleNonScalarModel(torch.nn.Module): + """Model that returns non-scalar output for testing tensor.backward(grad)""" + + def __init__(self, hidden_dim): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + # Returns non-scalar output + x = self.linear1(x) + x = self.linear2(x) + return x + + +class SimpleOutputModel(torch.nn.Module): + """Model that returns output without computing loss""" + + def __init__(self, hidden_dim): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def get_config_dict(zero_stage, gradient_accumulation_steps=1): + """Helper to create config dict with common settings""" + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": gradient_accumulation_steps, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + + if zero_stage == 3: + # For ZeRO-3, force partitioning of all parameters + config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0 + + if get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + elif get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + return config_dict + + +def collect_gradients_safe(model): + """Collect gradients from model parameters using safe_get_full_grad API""" + grads = {} + for name, param in model.named_parameters(): + if param.requires_grad: + grad = safe_get_full_grad(param) + if grad is not None: + # Remove 'module.' prefix if present (DeepSpeed wraps the model) + clean_name = name.replace('module.', '') + grads[clean_name] = grad.detach().clone().cpu() + return grads + + +def initialize_distributed(): + deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) + device = get_accelerator().current_device_name() + rank = get_accelerator().current_device() + dtype = preferred_dtype() + return device, rank, dtype + + +def create_ddp_model(model_class, device, rank, dtype, seed=42, lr=1e-3, **model_kwargs): + torch.manual_seed(seed) + model = model_class(**model_kwargs) + model = model.to(device=device, dtype=dtype) + model = DDP(model, device_ids=[rank], output_device=rank) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + return model, optimizer + + +def create_deepspeed_engine(model_class, zero_stage, seed=42, gradient_accumulation_steps=1, **model_kwargs): + torch.manual_seed(seed) + model = model_class(**model_kwargs) + + config = get_config_dict(zero_stage, gradient_accumulation_steps=gradient_accumulation_steps) + engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + return engine + + +def create_deepspeed_engine_from_model(model, zero_stage, gradient_accumulation_steps=1): + config = get_config_dict(zero_stage, gradient_accumulation_steps=gradient_accumulation_steps) + engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + return engine + + +def setup_models_and_engines(model_class, zero_stage, seed=42, lr=1e-3, gradient_accumulation_steps=1, **model_kwargs): + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model + model_ddp, optimizer_ddp = create_ddp_model(model_class, device, rank, dtype, seed=seed, lr=lr, **model_kwargs) + + # Create DeepSpeed engine + model_engine = create_deepspeed_engine(model_class, + zero_stage, + seed=seed, + gradient_accumulation_steps=gradient_accumulation_steps, + **model_kwargs) + + return model_ddp, optimizer_ddp, model_engine, device, dtype + + +def collect_ddp_gradients(model_ddp): + """Collect gradients from DDP model""" + grads = {} + for name, param in model_ddp.named_parameters(): + if param.grad is not None: + clean_name = name.replace('module.', '') + grads[clean_name] = param.grad.detach().clone().cpu() + return grads + + +def compare_gradients(grads_ddp, grads_ds, step_info=""): + """Compare gradients between DDP and DeepSpeed. + + Uses PyTorch's default tolerances for the tensor dtype (e.g., for bfloat16: + rtol=1.6e-2, atol=1e-5). The 2-layer model keeps differences small enough + to pass with default tolerances even after multiple optimizer steps. + """ + step_suffix = f" at {step_info}" if step_info else "" + assert len(grads_ddp) == len(grads_ds), \ + f"Different number of parameters with gradients{step_suffix}: DDP={len(grads_ddp)}, DeepSpeed={len(grads_ds)}" + + for name in grads_ddp.keys(): + assert name in grads_ds, f"Parameter {name} missing in DeepSpeed gradients{step_suffix}" + grad_ddp = grads_ddp[name] + grad_ds = grads_ds[name] + # If dtypes differ, convert ds to match ddp's dtype + if grad_ds.dtype != grad_ddp.dtype: + grad_ds = grad_ds.to(grad_ddp.dtype) + # Use PyTorch's default tolerances for the dtype + allclose_on_all_ranks(grad_ddp, grad_ds, assert_message=f"Gradients differ for parameter {name}{step_suffix}") + + +def collect_ddp_parameters(model_ddp): + """Collect parameters from DDP model""" + params = {} + for name, param in model_ddp.named_parameters(): + clean_name = name.replace('module.', '') + params[clean_name] = param.detach().clone().cpu() + return params + + +def collect_deepspeed_parameters(model_engine, zero_stage): + """Collect parameters from DeepSpeed engine (handles ZeRO-3 gathering)""" + params = {} + for name, param in model_engine.named_parameters(): + clean_name = name.replace('module.', '') + if zero_stage == 3: + with deepspeed.zero.GatheredParameters([param], modifier_rank=None): + params[clean_name] = param.detach().clone().cpu() + else: + params[clean_name] = param.detach().clone().cpu() + return params + + +def compare_parameters(params_ddp, params_ds, step_info=""): + """Compare parameters between DDP and DeepSpeed""" + step_suffix = f" at {step_info}" if step_info else "" + assert len(params_ddp) == len(params_ds), \ + f"Parameter count mismatch{step_suffix}: DDP={len(params_ddp)}, DeepSpeed={len(params_ds)}" + + for name in params_ddp.keys(): + assert name in params_ds, f"Parameter {name} missing in DeepSpeed model{step_suffix}" + # Convert to fp32 for comparison in case of dtype mismatch + params_ddp_fp32 = params_ddp[name].float() + params_ds_fp32 = params_ds[name].float() + allclose_on_all_ranks(params_ddp_fp32, + params_ds_fp32, + assert_message=f"Parameter {name} mismatch{step_suffix}") + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +class TestZeroUserBackwardBasic(DistributedTest): + """Test basic functionality of user backward (loss.backward()) by comparing with PyTorch DDP""" + world_size = 2 + + def test_loss_backward_matches_ddp(self, zero_stage): + """Test that DeepSpeed loss.backward() produces same gradients as PyTorch DDP""" + hidden_dim = 4 + + # Create DDP and DeepSpeed models + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim, + nlayers=2) + + # Create data + data_loader = random_dataloader(model=model_engine, total_samples=8, hidden_dim=hidden_dim, device=device) + + # Run one training step with both models + batch = next(iter(data_loader)) + + # DDP: forward and backward + optimizer_ddp.zero_grad() + loss_ddp = model_ddp(batch[0], batch[1]) + loss_ddp.backward() + grads_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward and backward + loss_ds = model_engine(batch[0], batch[1]) + loss_ds.backward() + grads_ds = collect_gradients_safe(model_engine) + + # Compare gradients + compare_gradients(grads_ddp, grads_ds) + + model_engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +class TestZeroUserBackwardNonScalar(DistributedTest): + """Test non-scalar backward support""" + world_size = 2 + + def test_non_scalar_backward(self, zero_stage): + """Test that tensor.backward(grad) works correctly by comparing with PyTorch DDP""" + hidden_dim = 4 + batch_size = 2 + + # Create DDP and DeepSpeed models + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines( + model_class=SimpleNonScalarModel, zero_stage=zero_stage, hidden_dim=hidden_dim) + + # Create input data + torch.manual_seed(123) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + + # DDP: forward and non-scalar backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x) + grad_output = torch.ones_like(output_ddp) + output_ddp.backward(grad_output) + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward and non-scalar backward + output_deepspeed = model_engine(x) + grad_output_ds = torch.ones_like(output_deepspeed) + output_deepspeed.backward(grad_output_ds) + deepspeed_grads = collect_gradients_safe(model_engine) + + # Compare gradients + compare_gradients(ddp_grads, deepspeed_grads, "after non-scalar backward") + + # Run optimizer step + optimizer_ddp.step() + model_engine.step() + + # Collect and compare parameters after step + ddp_params = collect_ddp_parameters(model_ddp) + deepspeed_params = collect_deepspeed_parameters(model_engine, zero_stage) + compare_parameters(ddp_params, deepspeed_params, "after non-scalar backward") + + model_engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +class TestZeroUserBackwardGradAccumulation(DistributedTest): + """Test gradient accumulation with user backward""" + world_size = 2 + + def test_grad_accumulation(self, zero_stage): + """Test that gradient accumulation works correctly with loss.backward() by comparing with DDP""" + hidden_dim = 4 + gradient_accumulation_steps = 4 + + # Create DDP and DeepSpeed models with gradient accumulation + model_ddp, optimizer_ddp, model_engine, device, _ = setup_models_and_engines( + model_class=SimpleModel, + zero_stage=zero_stage, + gradient_accumulation_steps=gradient_accumulation_steps, + hidden_dim=hidden_dim, + nlayers=2) + + # Create data + data_loader = random_dataloader(model=model_engine, total_samples=16, hidden_dim=hidden_dim, device=device) + + # Run training with gradient accumulation + for i, batch in enumerate(data_loader): + # DDP: Manual gradient accumulation + loss_ddp = model_ddp(batch[0], batch[1]) + (loss_ddp / gradient_accumulation_steps).backward() + + # DeepSpeed: Built-in gradient accumulation + loss_ds = model_engine(batch[0], batch[1]) + loss_ds.backward() + + # Compare gradients at accumulation boundary + if model_engine.is_gradient_accumulation_boundary(): + grads_ddp = collect_ddp_gradients(model_ddp) + grads_ds = collect_gradients_safe(model_engine) + compare_gradients(grads_ddp, grads_ds, f"step {i}") + + # Step both optimizers + optimizer_ddp.step() + optimizer_ddp.zero_grad() + + # Step DeepSpeed (handles gradient accumulation internally) + model_engine.step() + + model_engine.destroy() + + def test_grad_accumulation_scale_wrt_gas_false(self, zero_stage): + """Test that scale_wrt_gas=False disables gradient scaling by accumulation steps. + + When scale_wrt_gas=False is passed to engine.backward(), gradients should NOT be + scaled by gradient_accumulation_steps. This is useful when users want to handle + gradient scaling themselves (e.g., using Hugging Face Accelerate). + """ + hidden_dim = 4 + gradient_accumulation_steps = 4 + + # Create DDP and DeepSpeed models with gradient accumulation + model_ddp, optimizer_ddp, model_engine, device, _ = setup_models_and_engines( + model_class=SimpleModel, + zero_stage=zero_stage, + gradient_accumulation_steps=gradient_accumulation_steps, + hidden_dim=hidden_dim, + nlayers=2) + + # Create data + data_loader = random_dataloader(model=model_engine, total_samples=16, hidden_dim=hidden_dim, device=device) + + # Run training with gradient accumulation but WITHOUT scaling by GAS + for i, batch in enumerate(data_loader): + # DDP: Do NOT divide by GAS (since we're testing scale_wrt_gas=False) + loss_ddp = model_ddp(batch[0], batch[1]) + loss_ddp.backward() + + # DeepSpeed: Use scale_wrt_gas=False to disable gradient scaling + loss_ds = model_engine(batch[0], batch[1]) + model_engine.backward(loss_ds, scale_wrt_gas=False) + + # Compare gradients at accumulation boundary + if model_engine.is_gradient_accumulation_boundary(): + grads_ddp = collect_ddp_gradients(model_ddp) + grads_ds = collect_gradients_safe(model_engine) + compare_gradients(grads_ddp, grads_ds, f"step {i} with scale_wrt_gas=False") + + # Step both optimizers + optimizer_ddp.step() + optimizer_ddp.zero_grad() + + # Step DeepSpeed (handles gradient accumulation internally) + model_engine.step() + + model_engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +class TestZeroUserBackwardMultipleEngines(DistributedTest): + """Test multiple DeepSpeed engines with combined loss without manual _backward_epilogue()""" + world_size = 2 + + def test_multiple_engines_combined_loss(self, zero_stage): + """Test that multiple engines work with combined loss.backward() without manual _backward_epilogue() + + This test compares the behavior with PyTorch DDP baseline to ensure correctness. + """ + hidden_dim = 4 + batch_size = 2 + num_models = 3 + lr = 1e-3 + + # Initialize distributed + device, rank, dtype = initialize_distributed() + + # Create DDP baseline models + ddp_models = [] + ddp_optimizers = [] + for i in range(num_models): + model, optimizer = create_ddp_model(SimpleModel, + device, + rank, + dtype, + seed=42 + i, + lr=lr, + hidden_dim=hidden_dim, + nlayers=2) + ddp_models.append(model) + ddp_optimizers.append(optimizer) + + # Create multiple DeepSpeed engines with identical initialization + model_engines = [] + for i in range(num_models): + engine = create_deepspeed_engine(SimpleModel, zero_stage, seed=42 + i, hidden_dim=hidden_dim, nlayers=2) + model_engines.append(engine) + + # Create same input for all models + torch.manual_seed(123) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP baseline: compute losses and combined backward + for optimizer in ddp_optimizers: + optimizer.zero_grad() + + ddp_losses = [] + for model in ddp_models: + loss = model(x, y) + ddp_losses.append(loss) + + ddp_combined_loss = sum(l / (i + 1) for i, l in enumerate(ddp_losses)) + ddp_combined_loss.backward() + + # Collect DDP gradients for each model + ddp_grads_per_model = [collect_ddp_gradients(model) for model in ddp_models] + + # DeepSpeed: compute losses and combined backward WITHOUT manual _backward_epilogue() + ds_losses = [engine(x, y) for engine in model_engines] + ds_combined_loss = sum(l / (i + 1) for i, l in enumerate(ds_losses)) + ds_combined_loss.backward() + + # Collect DeepSpeed gradients for each engine and compare with DDP + for engine_idx, engine in enumerate(model_engines): + ds_grads = collect_gradients_safe(engine) + ddp_grads = ddp_grads_per_model[engine_idx] + assert len(ds_grads) > 0, f"Engine {engine_idx} has no gradients after combined_loss.backward()" + compare_gradients(ddp_grads, ds_grads, f"Engine {engine_idx}") + + # Step all DDP models + for optimizer in ddp_optimizers: + optimizer.step() + optimizer.zero_grad() + + # Step all DeepSpeed engines + for engine in model_engines: + engine.step() + engine.optimizer.zero_grad() + + # Run another iteration to ensure everything still works + torch.manual_seed(456) + x2 = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP second iteration + ddp_losses2 = [model(x2, y2) for model in ddp_models] + ddp_combined_loss2 = sum(l / (i + 1) for i, l in enumerate(ddp_losses2)) + ddp_combined_loss2.backward() + ddp_grads_per_model2 = [collect_ddp_gradients(model) for model in ddp_models] + + # DeepSpeed second iteration + ds_losses2 = [engine(x2, y2) for engine in model_engines] + ds_combined_loss2 = sum(l / (i + 1) for i, l in enumerate(ds_losses2)) + ds_combined_loss2.backward() + + # Verify gradients again and compare with DDP + for engine_idx, engine in enumerate(model_engines): + ds_grads = collect_gradients_safe(engine) + ddp_grads = ddp_grads_per_model2[engine_idx] + assert len(ds_grads) > 0, f"Engine {engine_idx} has no gradients in second iteration" + compare_gradients(ddp_grads, ds_grads, f"Engine {engine_idx} (iter 2)") + + # Step both + for optimizer in ddp_optimizers: + optimizer.step() + + for engine in model_engines: + engine.step() + + # Cleanup + for engine in model_engines: + engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +class TestZeroUserBackwardSeparateLoss(DistributedTest): + """Test using separate loss functions""" + world_size = 2 + + def test_separate_loss_function(self, zero_stage): + """Test that separate loss function works correctly by comparing with PyTorch DDP""" + hidden_dim = 4 + batch_size = 2 + + # Create DDP and DeepSpeed models + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim) + + # Define loss function separately + loss_fn = torch.nn.CrossEntropyLoss() + + # Create input data + torch.manual_seed(456) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP: forward, loss, backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x) + loss_ddp = loss_fn(output_ddp, y) + loss_ddp.backward() + grads_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward, loss, backward + output_ds = model_engine(x) + loss_ds = loss_fn(output_ds, y) + loss_ds.backward() + grads_ds = collect_gradients_safe(model_engine) + + # Compare gradients + compare_gradients(grads_ddp, grads_ds) + + model_engine.destroy() + + +class LeafModuleModel(torch.nn.Module): + """Model with ModuleList that uses all parameters - for testing leaf module compatibility""" + + def __init__(self, hidden_dim): + super().__init__() + # ModuleList where all branches are used in forward pass + self.branches = torch.nn.ModuleList([ + torch.nn.Linear(hidden_dim, hidden_dim), + torch.nn.Linear(hidden_dim, hidden_dim), + ]) + self.final_layer = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x, y): + # Use all branches - add their outputs together + x = self.branches[0](x) + self.branches[1](x) + x = self.final_layer(x) + loss = torch.nn.functional.cross_entropy(x, y) + return loss + + +class LeafNonScalarModel(torch.nn.Module): + """Leaf module model that returns non-scalar output""" + + def __init__(self, hidden_dim): + super().__init__() + self.branches = torch.nn.ModuleList([ + torch.nn.Linear(hidden_dim, hidden_dim), + torch.nn.Linear(hidden_dim, hidden_dim), + ]) + + def forward(self, x): + # Use all branches - returns non-scalar output + return self.branches[0](x) + self.branches[1](x) + + +@pytest.mark.parametrize("zero_stage", [3]) +class TestZeroUserBackwardLeafModule(DistributedTest): + """Test leaf module behavior during backward passes in ZeRO Stage 3""" + world_size = 2 + + def test_leaf_module_backward(self, zero_stage): + """Test that leaf modules work correctly with user backward by comparing with PyTorch DDP + + This test validates that the leaf_module_count and backward hooks are correctly + handled in create_reduce_and_remove_grad_hooks. + """ + from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module + + hidden_dim = 4 + batch_size = 2 + lr = 1e-3 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model + model_ddp, optimizer_ddp = create_ddp_model(LeafModuleModel, + device, + rank, + dtype, + seed=42, + lr=lr, + hidden_dim=hidden_dim) + + # Create DeepSpeed model and mark leaf modules BEFORE initialization + torch.manual_seed(42) + model_deepspeed = LeafModuleModel(hidden_dim=hidden_dim) + leaf_modules = set_z3_leaf_modules(model_deepspeed, [torch.nn.ModuleList]) + assert len(leaf_modules) == 1, "Expected exactly one ModuleList to be marked as leaf" + assert z3_leaf_module(model_deepspeed.branches), "ModuleList should be marked as leaf module" + + # Initialize DeepSpeed engine from the prepared model + model_engine = create_deepspeed_engine_from_model(model_deepspeed, zero_stage) + + # Verify leaf_module_count was set correctly + assert len(model_engine.optimizer.leaf_parameters) == 1, \ + "Expected 1 leaf module in optimizer.leaf_parameters" + + # Create input data + torch.manual_seed(123) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP: forward and backward + optimizer_ddp.zero_grad() + loss_ddp = model_ddp(x, y) + loss_ddp.backward() + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward and backward with leaf module + loss_deepspeed = model_engine(x, y) + loss_deepspeed.backward() + deepspeed_grads = collect_gradients_safe(model_engine) + + # Compare gradients + compare_gradients(ddp_grads, deepspeed_grads, "with leaf modules") + + model_engine.destroy() + + def test_leaf_module_non_scalar_backward(self, zero_stage): + """Test that leaf modules work correctly with non-scalar backward (tensor.backward(grad)) + + This specifically tests the interaction between leaf modules and non-scalar backward. + """ + from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module + + hidden_dim = 4 + batch_size = 2 + lr = 1e-3 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model + model_ddp, optimizer_ddp = create_ddp_model(LeafNonScalarModel, + device, + rank, + dtype, + seed=42, + lr=lr, + hidden_dim=hidden_dim) + + # Create DeepSpeed model and mark leaf modules BEFORE initialization + torch.manual_seed(42) + model_deepspeed = LeafNonScalarModel(hidden_dim=hidden_dim) + leaf_modules = set_z3_leaf_modules(model_deepspeed, [torch.nn.ModuleList]) + assert len(leaf_modules) == 1, "Expected exactly one ModuleList to be marked as leaf" + assert z3_leaf_module(model_deepspeed.branches), "ModuleList should be marked as leaf module" + + # Initialize DeepSpeed engine from the prepared model + model_engine = create_deepspeed_engine_from_model(model_deepspeed, zero_stage) + + # Verify leaf_module_count was set correctly + assert len(model_engine.optimizer.leaf_parameters) == 1, \ + "Expected 1 leaf module in optimizer.leaf_parameters" + + # Create input data + torch.manual_seed(123) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + + # DDP: forward and non-scalar backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x) + grad_output = torch.ones_like(output_ddp) + output_ddp.backward(grad_output) + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward and non-scalar backward with leaf module + output_deepspeed = model_engine(x) + grad_output_ds = torch.ones_like(output_deepspeed) + output_deepspeed.backward(grad_output_ds) + deepspeed_grads = collect_gradients_safe(model_engine) + + # Compare gradients + compare_gradients(ddp_grads, deepspeed_grads, "in leaf module non-scalar backward") + + model_engine.destroy() + + +@pytest.mark.sequential +class TestZeroUserBackwardScaleErrorDetection(DistributedTest): + """Test error detection for missing scale() with fp16 in single-process setup""" + world_size = 1 # Use single process to avoid distributed deadlock issues + + def test_error_when_backward_without_scale_sequential(self): + """Test that error is raised when calling backward() without scale() with fp16""" + if not get_accelerator().is_fp16_supported(): + pytest.skip("Test requires fp16 support") + + hidden_dim = 4 + zero_stage = 1 # Use ZeRO stage 1 for simplicity + + # Initialize distributed + device, _, _ = initialize_distributed() + + # Create engine with fp16 - requires scaling + torch.manual_seed(42) + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + + config = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + model_engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + + # Verify needs_scaler is True + from deepspeed.runtime.base_optimizer import ZeROOptimizer + assert isinstance(model_engine.optimizer, ZeROOptimizer) + assert model_engine.optimizer.needs_scaler(), "fp16 should require scaling" + + # Create data + data_loader = random_dataloader(model=model_engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.float16) + batch = next(iter(data_loader)) + + loss = model_engine(batch[0], batch[1]) + + # Calling backward() without scale() should raise RuntimeError + with pytest.raises(RuntimeError, match="Loss scaling is required"): + loss.backward() + + model_engine.destroy() + + +@pytest.mark.parametrize("zero_stage", [1, 3]) +class TestZeroUserBackwardWithScale(DistributedTest): + """Test engine.scale() method for manual backward passes with loss scaling""" + world_size = 2 + + def test_scale_backward_matches_engine_backward(self, zero_stage): + """Test that engine.scale(loss).backward() produces same gradients as engine.backward(loss)""" + hidden_dim = 4 + + # Create DeepSpeed engines with same seed + model_engine1 = create_deepspeed_engine(model_class=SimpleModel, + zero_stage=zero_stage, + seed=42, + hidden_dim=hidden_dim, + nlayers=2) + model_engine2 = create_deepspeed_engine(model_class=SimpleModel, + zero_stage=zero_stage, + seed=42, + hidden_dim=hidden_dim, + nlayers=2) + + # Create data + device = get_accelerator().current_device_name() + data_loader = random_dataloader(model=model_engine1, total_samples=8, hidden_dim=hidden_dim, device=device) + batch = next(iter(data_loader)) + + # Model 1: use engine.backward(loss) + loss1 = model_engine1(batch[0], batch[1]) + model_engine1.backward(loss1) + grads1 = collect_gradients_safe(model_engine1) + + # Model 2: use engine.scale(loss).backward() + loss2 = model_engine2(batch[0], batch[1]) + scaled_loss = model_engine2.scale(loss2) + scaled_loss.backward() + grads2 = collect_gradients_safe(model_engine2) + + # Compare gradients - they should be identical + compare_gradients(grads1, grads2, "comparing engine.backward vs engine.scale().backward()") + + model_engine1.destroy() + model_engine2.destroy() + + def test_scale_backward_matches_ddp(self, zero_stage): + """Test that engine.scale(loss).backward() produces same gradients as DDP""" + hidden_dim = 4 + + # Create DDP and DeepSpeed models + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim, + nlayers=2) + + # Create data + data_loader = random_dataloader(model=model_engine, total_samples=8, hidden_dim=hidden_dim, device=device) + batch = next(iter(data_loader)) + + # DDP: forward and backward + optimizer_ddp.zero_grad() + loss_ddp = model_ddp(batch[0], batch[1]) + loss_ddp.backward() + grads_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward and scale + backward + loss_ds = model_engine(batch[0], batch[1]) + scaled_loss = model_engine.scale(loss_ds) + scaled_loss.backward() + grads_ds = collect_gradients_safe(model_engine) + + # Compare gradients + compare_gradients(grads_ddp, grads_ds, "comparing DDP vs engine.scale().backward()") + + model_engine.destroy() + + def test_scale_with_gradient_accumulation(self, zero_stage): + """Test that engine.scale() works correctly with gradient accumulation""" + hidden_dim = 4 + gradient_accumulation_steps = 4 + + # Create models with gradient accumulation + model_ddp, optimizer_ddp, model_engine, device, _ = setup_models_and_engines( + model_class=SimpleModel, + zero_stage=zero_stage, + gradient_accumulation_steps=gradient_accumulation_steps, + hidden_dim=hidden_dim, + nlayers=2) + + # Create data + data_loader = random_dataloader(model=model_engine, total_samples=16, hidden_dim=hidden_dim, device=device) + + # Run gradient accumulation steps + for i, batch in enumerate(data_loader): + # DDP: manual gradient accumulation + loss_ddp = model_ddp(batch[0], batch[1]) + # Scale by GAS for DDP to match DeepSpeed behavior + (loss_ddp / gradient_accumulation_steps).backward() + + # DeepSpeed: use scale() with built-in gradient accumulation + # Note: scale() only applies loss scaler, NOT GAS. DeepSpeed handles GAS internally + # via engine.step(), so we do NOT manually divide by GAS here. + loss_ds = model_engine(batch[0], batch[1]) + scaled_loss = model_engine.scale(loss_ds) + scaled_loss.backward() + + # Compare gradients at accumulation boundary + if model_engine.is_gradient_accumulation_boundary(): + grads_ddp = collect_ddp_gradients(model_ddp) + grads_ds = collect_gradients_safe(model_engine) + compare_gradients(grads_ddp, grads_ds, f"step {i}") + + # Step both optimizers + optimizer_ddp.step() + optimizer_ddp.zero_grad() + + # Step DeepSpeed (handles gradient accumulation internally) + model_engine.step() + + model_engine.destroy() + + def test_needs_scaler_with_fp16(self, zero_stage): + """Test that needs_scaler() correctly identifies when scaling is required with fp16""" + if not get_accelerator().is_fp16_supported(): + pytest.skip("Test requires fp16 support for gradient scaling") + + hidden_dim = 4 + + # Initialize distributed first + device, _, _ = initialize_distributed() + + # Create engine with fp16 explicitly to test gradient scaling requirement + torch.manual_seed(42) + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + + config = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + # Explicitly enable fp16 to test gradient scaling requirement + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + if zero_stage == 3: + config["zero_optimization"]["stage3_param_persistence_threshold"] = 0 + + model_engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + + # Verify that the optimizer correctly reports it needs scaling with fp16 + from deepspeed.runtime.base_optimizer import ZeROOptimizer + assert isinstance(model_engine.optimizer, ZeROOptimizer), "Optimizer should be ZeROOptimizer" + assert model_engine.optimizer.needs_scaler(), "fp16 configuration should require gradient scaling" + + # Verify scale() method works correctly + data_loader = random_dataloader(model=model_engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.float16) + batch = next(iter(data_loader)) + loss = model_engine(batch[0], batch[1]) + + # Should be able to use scale() method and get a valid scaled tensor + scaled_loss = model_engine.scale(loss) + assert scaled_loss is not None, "scale() should return a scaled loss tensor" + assert scaled_loss.requires_grad, "scaled loss should require grad" + + model_engine.destroy() + + def test_needs_scaler_with_bf16(self, zero_stage): + """Test that needs_scaler() correctly identifies that bf16 does NOT require scaling""" + if not get_accelerator().is_bf16_supported(): + pytest.skip("Test requires bf16 support") + + hidden_dim = 4 + + # Initialize distributed first + device, _, _ = initialize_distributed() + + # Create engine with bf16 to verify scaling is NOT required + torch.manual_seed(42) + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + + config = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + # Use bf16 which does NOT require gradient scaling + "bf16": { + "enabled": True + } + } + + if zero_stage == 3: + config["zero_optimization"]["stage3_param_persistence_threshold"] = 0 + + model_engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + + # Verify that the optimizer correctly reports it does NOT need scaling with bf16 + from deepspeed.runtime.base_optimizer import ZeROOptimizer + assert isinstance(model_engine.optimizer, ZeROOptimizer), "Optimizer should be ZeROOptimizer" + assert not model_engine.optimizer.needs_scaler(), "bf16 configuration should NOT require gradient scaling" + + # Verify that loss.backward() can be called directly without scale() for bf16 + data_loader = random_dataloader(model=model_engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.bfloat16) + batch = next(iter(data_loader)) + loss = model_engine(batch[0], batch[1]) + + # With bf16, should be able to call backward directly (no scaling required) + loss.backward() + + # Collect gradients to verify backward completed successfully + grads = collect_gradients_safe(model_engine) + assert len(grads) > 0, "Expected gradients to be computed" + + model_engine.destroy() + + def test_error_when_backward_without_scale_fp16(self, zero_stage): + """Test that calling backward() without scale() raises an error with fp16""" + if not get_accelerator().is_fp16_supported(): + pytest.skip("Test requires fp16 support for gradient scaling") + + hidden_dim = 4 + + # Initialize distributed first + device, _, _ = initialize_distributed() + + # Create engine with fp16 + torch.manual_seed(42) + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + + config = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + if zero_stage == 3: + config["zero_optimization"]["stage3_param_persistence_threshold"] = 0 + + model_engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + + # Verify needs_scaler is True + assert model_engine.optimizer.needs_scaler(), "fp16 should require scaling" + + # Create data + data_loader = random_dataloader(model=model_engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.float16) + batch = next(iter(data_loader)) + + loss = model_engine(batch[0], batch[1]) + + # Try to call backward without scale - should raise RuntimeError + error_raised = False + try: + loss.backward() + except RuntimeError as e: + if "Loss scaling is required" in str(e): + error_raised = True + else: + raise # Re-raise if it's a different error + + # If the test completes (doesn't hang), verify error was raised + if error_raised: + # Success - error was properly detected + pass + else: + # If no error was raised, this is a problem (or it hung and timed out) + pytest.fail("Expected RuntimeError about loss scaling, but backward completed without error") + + model_engine.destroy() + + def test_scale_validates_scalar_loss(self, zero_stage): + """Test that scale() validates the input is a scalar loss tensor""" + hidden_dim = 4 + + model_engine = create_deepspeed_engine(model_class=SimpleNonScalarModel, + zero_stage=zero_stage, + seed=42, + hidden_dim=hidden_dim) + + device = get_accelerator().current_device_name() + dtype = preferred_dtype() + torch.manual_seed(123) + x = torch.randn(2, hidden_dim, device=device, dtype=dtype) + + # Forward to get non-scalar output + output = model_engine(x) + + # Trying to scale a non-scalar tensor should raise an assertion error + with pytest.raises(AssertionError, match="scalar tensor"): + model_engine.scale(output) + + model_engine.destroy() + + def test_scale_with_torch_autocast(self, zero_stage): + """Test that scale() works correctly with torch.autocast and fp16""" + if not get_accelerator().is_fp16_supported(): + pytest.skip("FP16 not supported on this accelerator") + + hidden_dim = 4 + + # Initialize distributed first + device, _, _ = initialize_distributed() + + # Create engine with fp16 config to test gradient scaling + torch.manual_seed(42) + model = SimpleModel(hidden_dim=hidden_dim, nlayers=2) + + config = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + # Enable fp16 to test gradient scaling (bf16 doesn't use gradient scaling) + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + if zero_stage == 3: + config["zero_optimization"]["stage3_param_persistence_threshold"] = 0 + + model_engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters()) + + # Create data with fp16 dtype to match the config + data_loader = random_dataloader(model=model_engine, + total_samples=8, + hidden_dim=hidden_dim, + device=device, + dtype=torch.float16) + batch = next(iter(data_loader)) + + # Forward and use scale() + loss = model_engine(batch[0], batch[1]) + scaled_loss = model_engine.scale(loss) + + # Should be able to call backward + scaled_loss.backward() + + # Collect gradients to verify they exist + grads = collect_gradients_safe(model_engine) + assert len(grads) > 0, "Expected gradients to be computed" + + model_engine.destroy() + + +class NonCheckpointedModel(torch.nn.Module): + """Model without gradient checkpointing, used as reference for comparison.""" + + def __init__(self, hidden_dim): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.functional.relu(x) + x = self.linear2(x) + return x + + +class CheckpointedModel(torch.nn.Module): + """Model that uses gradient checkpointing with configurable use_reentrant setting. + + This model is designed to test the interaction between ZeRO-3 and gradient + checkpointing with both reentrant (use_reentrant=True) and non-reentrant + (use_reentrant=False) modes. + + Uses 2 layers to minimize numerical divergence from bfloat16 precision + accumulation over multiple optimizer steps. + """ + + def __init__(self, hidden_dim, use_reentrant=True): + super().__init__() + self.use_reentrant = use_reentrant + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def _checkpointed_block(self, x): + """Block that will be checkpointed""" + x = self.linear1(x) + x = torch.nn.functional.relu(x) + return x + + def forward(self, x): + # Use gradient checkpointing on the first block + if self.training: + from torch.utils.checkpoint import checkpoint + x = checkpoint(self._checkpointed_block, x, use_reentrant=self.use_reentrant) + else: + x = self._checkpointed_block(x) + x = self.linear2(x) + return x + + +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) +@pytest.mark.parametrize("use_reentrant", [True, False]) +class TestZeroUserBackwardWithCheckpointing(DistributedTest): + """Test ZeRO with gradient checkpointing and non-scalar backward. + + This test class validates the interaction between: + 1. ZeRO parameter partitioning (stages 1 and 3) + 2. Gradient checkpointing (both reentrant and non-reentrant modes) + 3. Non-scalar backward (tensor.backward(gradient=...)) + + Both use_reentrant=True and use_reentrant=False are supported with ZeRO. + Note: When using use_reentrant=True, input tensors should have requires_grad=True + for proper gradient computation through the checkpointed region. + """ + world_size = 2 + + def test_checkpointed_non_scalar_backward(self, zero_stage, use_reentrant): + """Test that gradient checkpointing works with ZeRO and non-scalar backward. + + Verifies that tensor.backward(gradient=...) works correctly with ZeRO + and gradient checkpointing in both reentrant and non-reentrant modes. + """ + hidden_dim = 8 + batch_size = 2 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model for reference (no checkpointing issues with DDP) + torch.manual_seed(42) + model_ddp = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = model_ddp.to(device=device, dtype=dtype) + model_ddp = DDP(model_ddp, device_ids=[rank], output_device=rank) + optimizer_ddp = torch.optim.Adam(model_ddp.parameters(), lr=1e-3) + + # Create DeepSpeed model with ZeRO-3 + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + + config = get_config_dict(zero_stage) + model_engine, _, _, _ = deepspeed.initialize(config=config, + model=model_ds, + model_parameters=model_ds.parameters()) + + # Create input data - use separate tensors for DDP and DeepSpeed to avoid + # memory sharing issues during parallel test execution + torch.manual_seed(123) + x_ddp = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + + # DDP: forward and non-scalar backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x_ddp) + grad_output = torch.ones_like(output_ddp) + output_ddp.backward(grad_output) + get_accelerator().synchronize() # Ensure CUDA ops complete + dist.barrier() # Ensure all ranks complete gradient sync + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed with ZeRO-3: forward and non-scalar backward + # This is the pattern used in disaggregated training + # Create fresh tensor with same seed for reproducibility + torch.manual_seed(123) + x_ds = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + output_ds = model_engine(x_ds) + grad_output_ds = torch.ones_like(output_ds) + + # Non-scalar backward with gradient checkpointing + output_ds.backward(grad_output_ds) + + # Synchronize device before collecting gradients. ZeRO-3 uses async operations + # on separate streams for gradient reduction. With use_reentrant=True checkpointing, + # we need to ensure all operations complete before reading gradient data. + get_accelerator().synchronize() + dist.barrier() # Ensure all ranks complete backward before collecting gradients + + # Collect and verify gradients + ds_grads = collect_gradients_safe(model_engine) + + # Verify gradients were computed + assert len(ds_grads) > 0, \ + f"No gradients computed with use_reentrant={use_reentrant} and ZeRO-3" + + # Compare gradients with DDP reference + compare_gradients(ddp_grads, ds_grads, f"with checkpointing use_reentrant={use_reentrant}") + + # Run optimizer step to verify full training loop works + model_engine.step() + + model_engine.destroy() + + def test_checkpointed_scalar_backward(self, zero_stage, use_reentrant): + """Test that gradient checkpointing works with ZeRO and scalar backward. + + Verifies that scalar loss.backward() works correctly with ZeRO and + gradient checkpointing in both reentrant and non-reentrant modes. + """ + hidden_dim = 8 + batch_size = 2 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model for reference + torch.manual_seed(42) + model_ddp = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = model_ddp.to(device=device, dtype=dtype) + model_ddp = DDP(model_ddp, device_ids=[rank], output_device=rank) + optimizer_ddp = torch.optim.Adam(model_ddp.parameters(), lr=1e-3) + + # Create DeepSpeed model with ZeRO-3 + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + + config = get_config_dict(zero_stage) + model_engine, _, _, _ = deepspeed.initialize(config=config, + model=model_ds, + model_parameters=model_ds.parameters()) + + # Create input data - use separate tensors for DDP and DeepSpeed to avoid + # memory sharing issues during parallel test execution + torch.manual_seed(123) + x_ddp = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + y = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP: forward with scalar loss and backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x_ddp) + loss_ddp = torch.nn.functional.cross_entropy(output_ddp, y) + loss_ddp.backward() + get_accelerator().synchronize() # Ensure CUDA ops complete + dist.barrier() # Ensure all ranks complete gradient sync + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed with ZeRO-3: forward with scalar loss and backward + # Create fresh tensor with same seed for reproducibility + torch.manual_seed(123) + x_ds = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + output_ds = model_engine(x_ds) + loss_ds = torch.nn.functional.cross_entropy(output_ds, y) + + loss_ds.backward() + + # Synchronize device before collecting gradients. ZeRO-3 uses async operations + # on separate streams for gradient reduction. With use_reentrant=True checkpointing, + # we need to ensure all operations complete before reading gradient data. + get_accelerator().synchronize() + dist.barrier() # Ensure all ranks complete backward before collecting gradients + + # Collect and verify gradients + ds_grads = collect_gradients_safe(model_engine) + + # Verify gradients were computed + assert len(ds_grads) > 0, \ + f"No gradients computed with scalar loss, use_reentrant={use_reentrant}" + + # Compare gradients with DDP reference + compare_gradients(ddp_grads, ds_grads, f"scalar loss with checkpointing use_reentrant={use_reentrant}") + + model_engine.destroy() + + def test_checkpointed_multiple_backward(self, zero_stage, use_reentrant): + """Test multiple backward passes with checkpointing and ZeRO. + + Verifies that consecutive training iterations work correctly with + gradient checkpointing. Compares gradients with DDP at all iterations + to verify correctness. Uses PyTorch Adam for both to ensure fair comparison. + """ + hidden_dim = 8 + batch_size = 2 + num_iterations = 3 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model for reference with PyTorch Adam + torch.manual_seed(42) + model_ddp = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = model_ddp.to(device=device, dtype=dtype) + model_ddp = DDP(model_ddp, device_ids=[rank], output_device=rank) + optimizer_ddp = torch.optim.Adam(model_ddp.parameters(), lr=1e-3) + + # Create DeepSpeed model WITH checkpointing, using PyTorch Adam + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + optimizer_ds = torch.optim.Adam(model_ds.parameters(), lr=1e-3) + config = get_config_dict(zero_stage) + model_engine, _, _, _ = deepspeed.initialize(config=config, + model=model_ds, + model_parameters=model_ds.parameters(), + optimizer=optimizer_ds) + + for iteration in range(num_iterations): + # Use same random seed for both models + torch.manual_seed(123 + iteration) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + + # DDP: forward and backward + optimizer_ddp.zero_grad() + x_ddp = x.clone().detach().requires_grad_(True) + output_ddp = model_ddp(x_ddp) + output_ddp.backward(torch.ones_like(output_ddp)) + get_accelerator().synchronize() + dist.barrier() + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed: forward and backward + x_ds = x.clone().detach().requires_grad_(True) + output_ds = model_engine(x_ds) + output_ds.backward(torch.ones_like(output_ds)) + get_accelerator().synchronize() + dist.barrier() + ds_grads = collect_gradients_safe(model_engine) + + # Verify gradients were computed + assert len(ds_grads) > 0, \ + f"No gradients at iteration {iteration} with use_reentrant={use_reentrant}" + + # Compare gradients with DDP - using same optimizer so should match closely + # Small differences at later iterations are expected due to bfloat16 precision + compare_gradients(ddp_grads, ds_grads, f"iteration {iteration} with use_reentrant={use_reentrant}") + + # Run optimizer steps on both models + optimizer_ddp.step() + model_engine.step() + + model_engine.destroy() diff --git a/version.txt b/version.txt index ac39a106c485..41915c799477 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.9.0 +0.19.1